De batchnormalisatielaag van Keras is verbroken PlatoBlockchain Data Intelligence. Verticaal zoeken. Ai.

De batch-normalisatielaag van Keras is verbroken

UPDATE: Helaas werd mijn Pull-Request to Keras die het gedrag van de Batch Normalization-laag veranderde niet geaccepteerd. Je kunt de details lezen hier. Voor degenen onder u die dapper genoeg zijn om te knoeien met aangepaste implementaties, kunt u de code vinden in mijn branch. Ik onderhoud het misschien en voeg het samen met de nieuwste stabiele versie van Keras (2.1.6, 2.2.2 en 2.2.4) zolang ik het gebruik, maar geen beloftes.

De meeste mensen die bij Deep Learning werken, hebben er gebruik van gemaakt of hebben er van gehoord Keras. Voor degenen onder u die dat niet hebben gedaan, het is een geweldige bibliotheek die de onderliggende Deep Learning-frameworks zoals TensorFlow, Theano en CNTK abstraheert en een API op hoog niveau voor het trainen van ANNs. Het is gemakkelijk te gebruiken, maakt snelle prototyping mogelijk en heeft een vriendelijke actieve gemeenschap. Ik gebruik het al geruime tijd en draag regelmatig bij aan het project en ik raad het zeker aan aan iedereen die aan Deep Learning wil werken.

Hoewel Keras mijn leven gemakkelijker heeft gemaakt, ben ik vaak gebeten door het vreemde gedrag van de Batch Normalization-laag. Het standaardgedrag is in de loop van de tijd veranderd, maar het veroorzaakt nog steeds problemen voor veel gebruikers en als gevolg daarvan zijn er verschillende gerelateerd openstaande kwesties op Github. In deze blogpost zal ik proberen een case te bouwen waarom de BatchNormalization-laag van Keras niet goed werkt met Transfer Learning, ik zal de code geven die het probleem oplost en ik zal voorbeelden geven met de resultaten van de stuk.

In de onderstaande subsecties geef ik een inleiding over hoe Transfer Learning wordt gebruikt in Deep Learning, wat is de Batch Normalization-laag, hoe learnining_phase werkt en hoe Keras het BN-gedrag in de loop van de tijd heeft veranderd. Als je deze al kent, kun je veilig direct naar sectie 2 springen.

1.1 Het gebruik van Transfer Learning is cruciaal voor Deep Learning

Een van de redenen waarom Deep Learning in het verleden werd bekritiseerd, is dat het te veel gegevens vereist. Dit is niet altijd het geval; er zijn verschillende technieken om deze beperking aan te pakken, waaronder Transfer Learning.

Stel dat u aan een Computer Vision-toepassing werkt en u een classifier wilt bouwen die katten van honden onderscheidt. Je hebt eigenlijk geen miljoenen afbeeldingen van katten / honden nodig om het model te trainen. In plaats daarvan kunt u een vooraf opgeleide classificator gebruiken en de topconvoluties verfijnen met minder gegevens. Het idee erachter is dat, aangezien het vooraf getrainde model op afbeeldingen paste, de onderste windingen functies zoals lijnen, randen en andere nuttige patronen kunnen herkennen, wat betekent dat u de gewichten ervan kunt gebruiken als goede initialisatiewaarden of het netwerk gedeeltelijk opnieuw kunt trainen met uw gegevens .
De batchnormalisatielaag van Keras is verbroken PlatoBlockchain Data Intelligence. Verticaal zoeken. Ai.
Keras wordt geleverd met verschillende vooraf getrainde modellen en eenvoudig te gebruiken voorbeelden voor het verfijnen van modellen. U kunt meer lezen op de documentatie.

1.2 Wat is de batchnormalisatielaag?

De Batch Normalization-laag is in 2014 geïntroduceerd door Ioffe en Szegedy. Het lost het verdwijnende gradiëntprobleem op door de output van de vorige laag te standaardiseren, het versnelt de training door het aantal vereiste iteraties te verminderen en het maakt de training van diepere neurale netwerken mogelijk. Uitleggen hoe het precies werkt, valt buiten het bestek van dit bericht, maar ik raad u ten zeerste aan om het te lezen origineel papier. Een te eenvoudige verklaring is dat het de input schaalt door het gemiddelde ervan af te trekken en te delen met zijn standaarddeviatie; het kan ook leren om de transformatie indien nodig ongedaan te maken.
De batchnormalisatielaag van Keras is verbroken PlatoBlockchain Data Intelligence. Verticaal zoeken. Ai.

1.3 Wat is de leerfase in Keras?

Sommige lagen werken anders tijdens de training en inferentie-modus. De meest opvallende voorbeelden zijn de batchnormalisatie en de uitvallagen. In het geval van BN gebruiken we tijdens de training het gemiddelde en de variantie van de minibatch om de invoer te herschalen. Aan de andere kant gebruiken we tijdens inferentie het voortschrijdend gemiddelde en de variantie die tijdens de training werd geschat.

Keras weet in welke modus moet worden uitgevoerd omdat het een ingebouwd mechanisme heeft dat wordt genoemd leerfase. De leerfase bepaalt of het netwerk zich in de trein- of testmodus bevindt. Als het niet handmatig door de gebruiker is ingesteld, werkt het netwerk tijdens fit () met learning_phase = 1 (train-modus). Tijdens het maken van voorspellingen (bijvoorbeeld wanneer we de predict () & evalu () methodes aanroepen of bij de validatiestap van de fit ()) draait het netwerk met learning_phase = 0 (test mode). Hoewel het niet wordt aanbevolen, is de gebruiker ook in staat om de learning_phase statisch te wijzigen in een specifieke waarde, maar dit moet gebeuren voordat een model of tensor in de grafiek wordt toegevoegd. Als de learning_phase statisch is ingesteld, wordt Keras vergrendeld in de modus die de gebruiker heeft geselecteerd.

1.4 Hoe heeft Keras na verloop van tijd batchnormalisatie geïmplementeerd?

Keras heeft het gedrag van batchnormalisatie verschillende keren gewijzigd, maar de meest recente belangrijke update vond plaats in Keras 2.1.3. Vóór v2.1.3 toen de BN-laag werd bevroren (trainable = False), bleef het zijn batchstatistieken bijwerken, iets dat epische hoofdpijn veroorzaakte bij zijn gebruikers.

Dit was niet alleen een raar beleid, het was eigenlijk ook verkeerd. Stel je voor dat er een BN-laag bestaat tussen windingen; als de laag bevroren is, zouden er geen veranderingen aan moeten gebeuren. Als we de gewichten gedeeltelijk bijwerken en de volgende lagen ook worden bevroren, krijgen ze nooit de kans om zich aan te passen aan de updates van de minibatch-statistieken, wat leidt tot een hogere fout. Gelukkig vanaf versie 2.1.3, wanneer een BN-laag is bevroren, worden de statistieken niet meer bijgewerkt. Maar is dat genoeg? Niet als je Transfer Learning gebruikt.

Hieronder beschrijf ik precies wat het probleem is en schets ik de technische implementatie om het op te lossen. Ik geef ook een paar voorbeelden om de effecten op de nauwkeurigheid van het model voor en na de stuk is toegepast.

2.1 Technische beschrijving van het probleem

Het probleem met de huidige implementatie van Keras is dat wanneer een BN-laag wordt bevroren, deze tijdens de training de mini-batchstatistieken blijft gebruiken. Ik geloof dat een betere benadering wanneer de BN bevroren is, het bewegende gemiddelde en de variantie die het tijdens de training heeft geleerd, gebruikt. Waarom? Om dezelfde redenen waarom de minibatchstatistieken niet moeten worden bijgewerkt wanneer de laag is bevroren: dit kan tot slechte resultaten leiden omdat de volgende lagen niet goed zijn opgeleid.

Stel dat u een Computer Vision-model bouwt, maar u beschikt niet over voldoende gegevens, dus besluit u een van de vooraf opgeleide CNN's van Keras te gebruiken en deze af te stemmen. Helaas krijgt u hierdoor geen garanties dat het gemiddelde en de variantie van uw nieuwe dataset binnen de BN-lagen vergelijkbaar zijn met die van de oorspronkelijke dataset. Onthoud dat uw netwerk op dit moment tijdens de training altijd gebruik zal maken van de mini-batch statistieken of de BN-laag is bevroren of niet; ook tijdens inferentie gebruik je de eerder geleerde statistieken van de bevroren BN-lagen. Als gevolg hiervan, als u de bovenste lagen verfijnt, worden hun gewichten aangepast aan het gemiddelde / de variantie van de nieuwe dataset. Desalniettemin zullen ze tijdens de gevolgtrekking gegevens ontvangen die zijn geschaald anders omdat het gemiddelde / de variantie van de origineel dataset zal worden gebruikt.
De batchnormalisatielaag van Keras is verbroken PlatoBlockchain Data Intelligence. Verticaal zoeken. Ai.
Hierboven geef ik een simplistische (en onrealistische) architectuur voor demonstratiedoeleinden. Laten we aannemen dat we het model vanaf Convolution k + 1 tot aan de bovenkant van het netwerk (rechterkant) verfijnen en de onderkant (linkerkant) bevroren houden. Tijdens de training gebruiken alle BN-lagen van 1 tot k het gemiddelde / de variantie van uw trainingsgegevens. Dit zal negatieve effecten hebben op de bevroren ReLU's als het gemiddelde en de variantie op elke BN niet in de buurt komen van die welke tijdens de pre-training zijn geleerd. Het zorgt er ook voor dat de rest van het netwerk (van CONV k + 1 en later) getraind wordt met inputs die verschillende schalen hebben in vergelijking met wat tijdens inferentie wordt ontvangen. Tijdens de training kan uw netwerk zich aanpassen aan deze veranderingen, maar op het moment dat u overschakelt naar de voorspellingsmodus, gebruikt Keras verschillende standaardisatiestatistieken, iets dat de distributie van de inputs van de volgende lagen versnelt, wat leidt tot slechte resultaten.

2.2 Hoe kunt u detecteren of u bent getroffen?

Een manier om het te detecteren is door de leerfase van Keras statisch in te stellen op 1 (treinmodus) en op 0 (testmodus) en uw model telkens te evalueren. Als er een significant verschil in nauwkeurigheid is in dezelfde dataset, wordt u door het probleem getroffen. Het is de moeite waard om erop te wijzen dat het, vanwege de manier waarop het learning_phase-mechanisme in Keras is geïmplementeerd, doorgaans niet wordt aangeraden ermee te rotzooien. Wijzigingen in de leerfase hebben geen effect op modellen die al zijn samengesteld en gebruikt; zoals u kunt zien in de voorbeelden in de volgende subsecties, is de beste manier om dit te doen, te beginnen met een schone sessie en de leerfase te wijzigen voordat een tensor in de grafiek wordt gedefinieerd.

Een andere manier om het probleem op te sporen tijdens het werken met binaire classificaties, is door de nauwkeurigheid en de AUC te controleren. Als de nauwkeurigheid bijna 50% is, maar de AUC bijna 1 (en u ziet ook verschillen tussen trein / testmodus op dezelfde dataset), kan het zijn dat de kansen buiten de schaal vallen vanwege de BN-statistieken. Evenzo kunt u voor regressie de correlatie van MSE en Spearman gebruiken om het te detecteren.

2.3 Hoe kunnen we dit oplossen?

Ik geloof dat het probleem kan worden opgelost als de bevroren BN-lagen eigenlijk precies dat zijn: permanent vergrendeld in testmodus. Qua uitvoering moet de trainbare vlag deel uitmaken van de computergrafiek en het gedrag van de BN moet niet alleen afhangen van de leerfase, maar ook van de waarde van de trainbare eigenschap. U kunt de details van mijn implementatie vinden op GitHub.

Door de bovenstaande oplossing toe te passen, zal een BN-laag die wordt bevroren niet langer de mini-batchstatistieken gebruiken, maar de statistieken die tijdens de training zijn geleerd. Als gevolg hiervan zal er geen discrepantie zijn tussen training- en testmodi, wat leidt tot grotere nauwkeurigheid. Het is duidelijk dat wanneer de BN-laag niet is bevroren, deze tijdens de training de mini-batchstatistieken blijft gebruiken.

2.4 Beoordeling van de effecten van de pleister

Hoewel ik de bovenstaande implementatie onlangs heb geschreven, wordt het idee erachter zwaar getest op problemen in de echte wereld met behulp van verschillende oplossingen die hetzelfde effect hebben. De discrepantie tussen training- en testmodi kan bijvoorbeeld worden voorkomen door het netwerk in twee delen te splitsen (bevroren en niet-bevroren) en gecachte training uit te voeren (gegevens eenmalig door het bevroren model te sturen en ze vervolgens te gebruiken om het niet-bevroren netwerk te trainen). Desalniettemin, omdat het 'geloof me dat ik dit eerder heb gedaan' doorgaans niet van belang is, geef ik hieronder een paar voorbeelden die de effecten van de nieuwe implementatie in de praktijk laten zien.

Hier volgen enkele belangrijke punten over het experiment:

  1. Ik zal een kleine hoeveelheid gegevens gebruiken om opzettelijk het model te overtreffen en ik zal het model trainen en valideren op dezelfde dataset. Door dit te doen, verwacht ik een bijna perfecte nauwkeurigheid en identieke prestaties op de trein / validatiedataset.
  2. Als ik tijdens validatie een significant lagere nauwkeurigheid krijg op dezelfde dataset, zal ik een duidelijke indicatie hebben dat het huidige BN-beleid de prestaties van het model tijdens inferentie negatief beïnvloedt.
  3. Eventuele voorverwerking vindt plaats buiten de generatoren. Dit wordt gedaan om een ​​bug te omzeilen die in v2.1.5 is geïntroduceerd (momenteel opgelost in de aanstaande v2.1.6 en de nieuwste master).
  4. We zullen Keras dwingen om tijdens de evaluatie verschillende leerfasen te gebruiken. Als we verschillen zien tussen de gerapporteerde nauwkeurigheid, weten we dat we worden beïnvloed door het huidige BN-beleid.

De code voor het experiment wordt hieronder weergegeven:

import numpy as np
from keras.datasets import cifar10
from scipy.misc import imresize

from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.models import Model, load_model
from keras.layers import Dense, Flatten
from keras import backend as K


seed = 42
epochs = 10
records_per_class = 100

# We take only 2 classes from CIFAR10 and a very small sample to intentionally overfit the model.
# We will also use the same data for train/test and expect that Keras will give the same accuracy.
(x, y), _ = cifar10.load_data()

def filter_resize(category):
   # We do the preprocessing here instead in the Generator to get around a bug on Keras 2.1.5.
   return [preprocess_input(imresize(img, (224,224)).astype('float')) for img in x[y.flatten()==category][:records_per_class]]

x = np.stack(filter_resize(3)+filter_resize(5))
records_per_class = x.shape[0] // 2
y = np.array([[1,0]]*records_per_class + [[0,1]]*records_per_class)


# We will use a pre-trained model and finetune the top layers.
np.random.seed(seed)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)

for layer in model.layers[:140]:
   layer.trainable = False

for layer in model.layers[140:]:
   layer.trainable = True

model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(ImageDataGenerator().flow(x, y, seed=42), epochs=epochs, validation_data=ImageDataGenerator().flow(x, y, seed=42))

# Store the model on disk
model.save('tmp.h5')


# In every test we will clear the session and reload the model to force Learning_Phase values to change.
print('DYNAMIC LEARNING_PHASE')
K.clear_session()
model = load_model('tmp.h5')
# This accuracy should match exactly the one of the validation set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))


print('STATIC LEARNING_PHASE = 0')
K.clear_session()
K.set_learning_phase(0)
model = load_model('tmp.h5')
# Again the accuracy should match the above.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))


print('STATIC LEARNING_PHASE = 1')
K.clear_session()
K.set_learning_phase(1)
model = load_model('tmp.h5')
# The accuracy will be close to the one of the training set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))

Laten we de resultaten op Keras v2.1.5 bekijken:

Epoch 1/10
1/7 [===>..........................] - ETA: 25s - loss: 0.8751 - acc: 0.5312
2/7 [=======>......................] - ETA: 11s - loss: 0.8594 - acc: 0.4531
3/7 [===========>..................] - ETA: 7s - loss: 0.8398 - acc: 0.4688 
4/7 [================>.............] - ETA: 4s - loss: 0.8467 - acc: 0.4844
5/7 [====================>.........] - ETA: 2s - loss: 0.7904 - acc: 0.5437
6/7 [========================>.....] - ETA: 1s - loss: 0.7593 - acc: 0.5625
7/7 [==============================] - 12s 2s/step - loss: 0.7536 - acc: 0.5744 - val_loss: 0.6526 - val_acc: 0.6650

Epoch 2/10
1/7 [===>..........................] - ETA: 4s - loss: 0.3881 - acc: 0.8125
2/7 [=======>......................] - ETA: 3s - loss: 0.3945 - acc: 0.7812
3/7 [===========>..................] - ETA: 2s - loss: 0.3956 - acc: 0.8229
4/7 [================>.............] - ETA: 1s - loss: 0.4223 - acc: 0.8047
5/7 [====================>.........] - ETA: 1s - loss: 0.4483 - acc: 0.7812
6/7 [========================>.....] - ETA: 0s - loss: 0.4325 - acc: 0.7917
7/7 [==============================] - 8s 1s/step - loss: 0.4095 - acc: 0.8089 - val_loss: 0.4722 - val_acc: 0.7700

Epoch 3/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2246 - acc: 0.9375
2/7 [=======>......................] - ETA: 3s - loss: 0.2167 - acc: 0.9375
3/7 [===========>..................] - ETA: 2s - loss: 0.2260 - acc: 0.9479
4/7 [================>.............] - ETA: 2s - loss: 0.2179 - acc: 0.9375
5/7 [====================>.........] - ETA: 1s - loss: 0.2356 - acc: 0.9313
6/7 [========================>.....] - ETA: 0s - loss: 0.2392 - acc: 0.9427
7/7 [==============================] - 8s 1s/step - loss: 0.2288 - acc: 0.9456 - val_loss: 0.4282 - val_acc: 0.7800

Epoch 4/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2183 - acc: 0.9688
2/7 [=======>......................] - ETA: 3s - loss: 0.1899 - acc: 0.9844
3/7 [===========>..................] - ETA: 2s - loss: 0.1887 - acc: 0.9792
4/7 [================>.............] - ETA: 1s - loss: 0.1995 - acc: 0.9531
5/7 [====================>.........] - ETA: 1s - loss: 0.1932 - acc: 0.9625
6/7 [========================>.....] - ETA: 0s - loss: 0.1819 - acc: 0.9688
7/7 [==============================] - 8s 1s/step - loss: 0.1743 - acc: 0.9747 - val_loss: 0.3778 - val_acc: 0.8400

Epoch 5/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0973 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0828 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0851 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0897 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0928 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0936 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.1337 - acc: 0.9838 - val_loss: 0.3916 - val_acc: 0.8100

Epoch 6/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0747 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0852 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0812 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0831 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0779 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0766 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.0813 - acc: 1.0000 - val_loss: 0.3637 - val_acc: 0.8550

Epoch 7/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2478 - acc: 0.8750
2/7 [=======>......................] - ETA: 2s - loss: 0.1966 - acc: 0.9375
3/7 [===========>..................] - ETA: 2s - loss: 0.1528 - acc: 0.9583
4/7 [================>.............] - ETA: 1s - loss: 0.1300 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.1193 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.1196 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.1084 - acc: 0.9838 - val_loss: 0.3546 - val_acc: 0.8600

Epoch 8/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0539 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0900 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0815 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0740 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0700 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0701 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.0695 - acc: 1.0000 - val_loss: 0.3269 - val_acc: 0.8600

Epoch 9/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0306 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0377 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0898 - acc: 0.9583
4/7 [================>.............] - ETA: 1s - loss: 0.0773 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0742 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0708 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0659 - acc: 0.9838 - val_loss: 0.3604 - val_acc: 0.8600

Epoch 10/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0354 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0381 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0354 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0828 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0791 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0794 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0704 - acc: 0.9838 - val_loss: 0.3615 - val_acc: 0.8600

DYNAMIC LEARNING_PHASE
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 0
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 1
[0.025861846953630446, 1.0]

Zoals we hierboven kunnen zien, leert het model tijdens de training heel goed de gegevens en behaalt het op de trainingsset een bijna perfecte nauwkeurigheid. Toch krijgen we aan het einde van elke iteratie, terwijl we het model evalueren op dezelfde dataset, aanzienlijke verschillen in verlies en nauwkeurigheid. Merk op dat we dit niet zouden moeten krijgen; we hebben opzettelijk het model op de specifieke dataset voorzien en de trainings / validatiedatasets zijn identiek.

Nadat de training is voltooid, evalueren we het model met behulp van 3 verschillende learning_phase-configuraties: Dynamic, Static = 0 (testmodus) en Static = 1 (trainingsmodus). Zoals we kunnen zien, zullen de eerste twee configuraties identieke resultaten opleveren in termen van verlies en nauwkeurigheid en hun waarde komt overeen met de gerapporteerde nauwkeurigheid van het model op de validatieset in de laatste iteratie. Desalniettemin zien we een enorme discrepantie (verbetering) zodra we overschakelen naar de trainingsmodus. Waarom is dat zo? Zoals we eerder zeiden, zijn de gewichten van het netwerk afgestemd in de verwachting dat er gegevens worden ontvangen die zijn geschaald met de gemiddelde / variantie van de trainingsgegevens. Helaas verschillen die statistieken van de statistieken die zijn opgeslagen in de BN-lagen. Omdat de BN-lagen bevroren waren, werden deze statistieken nooit bijgewerkt. Deze discrepantie tussen de waarden van de BN-statistieken leidt tot een verslechtering van de nauwkeurigheid tijdens inferentie.

Laten we eens kijken wat er gebeurt als we de toepassen stuk:

Epoch 1/10
1/7 [===>..........................] - ETA: 26s - loss: 0.9992 - acc: 0.4375
2/7 [=======>......................] - ETA: 12s - loss: 1.0534 - acc: 0.4375
3/7 [===========>..................] - ETA: 7s - loss: 1.0592 - acc: 0.4479 
4/7 [================>.............] - ETA: 4s - loss: 0.9618 - acc: 0.5000
5/7 [====================>.........] - ETA: 2s - loss: 0.8933 - acc: 0.5250
6/7 [========================>.....] - ETA: 1s - loss: 0.8638 - acc: 0.5417
7/7 [==============================] - 13s 2s/step - loss: 0.8357 - acc: 0.5570 - val_loss: 0.2414 - val_acc: 0.9450

Epoch 2/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2331 - acc: 0.9688
2/7 [=======>......................] - ETA: 2s - loss: 0.3308 - acc: 0.8594
3/7 [===========>..................] - ETA: 2s - loss: 0.3986 - acc: 0.8125
4/7 [================>.............] - ETA: 1s - loss: 0.3721 - acc: 0.8281
5/7 [====================>.........] - ETA: 1s - loss: 0.3449 - acc: 0.8438
6/7 [========================>.....] - ETA: 0s - loss: 0.3168 - acc: 0.8646
7/7 [==============================] - 9s 1s/step - loss: 0.3165 - acc: 0.8633 - val_loss: 0.1167 - val_acc: 0.9950

Epoch 3/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2457 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.2592 - acc: 0.9688
3/7 [===========>..................] - ETA: 2s - loss: 0.2173 - acc: 0.9688
4/7 [================>.............] - ETA: 1s - loss: 0.2122 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.2003 - acc: 0.9688
6/7 [========================>.....] - ETA: 0s - loss: 0.1896 - acc: 0.9740
7/7 [==============================] - 9s 1s/step - loss: 0.1835 - acc: 0.9773 - val_loss: 0.0678 - val_acc: 1.0000

Epoch 4/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2051 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.1652 - acc: 0.9844
3/7 [===========>..................] - ETA: 2s - loss: 0.1423 - acc: 0.9896
4/7 [================>.............] - ETA: 1s - loss: 0.1289 - acc: 0.9922
5/7 [====================>.........] - ETA: 1s - loss: 0.1225 - acc: 0.9938
6/7 [========================>.....] - ETA: 0s - loss: 0.1149 - acc: 0.9948
7/7 [==============================] - 9s 1s/step - loss: 0.1060 - acc: 0.9955 - val_loss: 0.0455 - val_acc: 1.0000

Epoch 5/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0769 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0846 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0797 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0736 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0914 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0858 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0808 - acc: 1.0000 - val_loss: 0.0346 - val_acc: 1.0000

Epoch 6/10
1/7 [===>..........................] - ETA: 1s - loss: 0.1267 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.1039 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0893 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0780 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0758 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0789 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0738 - acc: 1.0000 - val_loss: 0.0248 - val_acc: 1.0000

Epoch 7/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0344 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0385 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0467 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0445 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0446 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0429 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0421 - acc: 1.0000 - val_loss: 0.0202 - val_acc: 1.0000

Epoch 8/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0319 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0300 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0320 - acc: 1.0000
4/7 [================>.............] - ETA: 2s - loss: 0.0307 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0303 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0291 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0358 - acc: 1.0000 - val_loss: 0.0167 - val_acc: 1.0000

Epoch 9/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0246 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0255 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0258 - acc: 1.0000
4/7 [================>.............] - ETA: 2s - loss: 0.0250 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0252 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0260 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0327 - acc: 1.0000 - val_loss: 0.0143 - val_acc: 1.0000

Epoch 10/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0251 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0228 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0217 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0249 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0244 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0239 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0290 - acc: 1.0000 - val_loss: 0.0127 - val_acc: 1.0000

DYNAMIC LEARNING_PHASE
[0.012697912137955427, 1.0]

STATIC LEARNING_PHASE = 0
[0.012697912137955427, 1.0]

STATIC LEARNING_PHASE = 1
[0.01744014158844948, 1.0]

Allereerst zien we dat het netwerk aanzienlijk sneller convergeert en een perfecte nauwkeurigheid bereikt. We zien ook dat er geen discrepantie meer is in termen van nauwkeurigheid wanneer we schakelen tussen verschillende leerfasewaarden.

2.5 Hoe presteert de patch op een echte dataset?

Dus hoe presteert de patch in een realistischer experiment? Laten we Keras 'vooraf getrainde ResNet50 gebruiken (die oorspronkelijk op imagenet past), de bovenste classificatielaag verwijderen en deze met en zonder de patch afstemmen en de resultaten vergelijken. Voor gegevens gebruiken we CIFAR10 (de standaard trein / test-splitsing van Keras) en we verkleinen de afbeeldingen naar 224 × 224 om ze compatibel te maken met de ingangsgrootte van de ResNet50.

We zullen 10 tijdperken doen om de bovenste classificatielaag te trainen met RSMprop en dan zullen we nog 5 doen om alles na de 139e laag te verfijnen met SGD (lr = 1e-4, momentum = 0.9). Zonder de patch behaalt ons model een nauwkeurigheid van 87.44%. Met de patch krijgen we een nauwkeurigheid van 92.36%, bijna 5 punten hoger.

2.6 Moeten we dezelfde fix toepassen op andere lagen zoals Dropout?

Batchnormalisatie is niet de enige laag die anders werkt tussen trein- en testmodi. Drop-out en zijn varianten hebben ook hetzelfde effect. Moeten we op al deze lagen hetzelfde beleid toepassen? Ik geloof van niet (ook al zou ik graag je mening hierover horen). De reden is dat Dropout wordt gebruikt om overfitting te voorkomen, dus het permanent vergrendelen van de voorspellingsmodus tijdens training zou het doel ervan tenietdoen. Wat denk je?

Ik ben er sterk van overtuigd dat deze discrepantie in Keras moet worden opgelost. Ik heb nog diepere effecten gezien (van 100% tot 50% nauwkeurigheid) in real-world applicaties die door dit probleem worden veroorzaakt. ik plan om te verzenden al een PR naar Keras met de oplossing en hopelijk wordt het geaccepteerd.

Als je deze blogpost leuk vond, neem dan even de tijd om hem te delen op Facebook of Twitter. 🙂

Tijdstempel:

Meer van Datumbox