Stratul de normalizare batch al Keras este rupt PlatoBlockchain Data Intelligence. Căutare verticală. Ai.

Stratul de normalizare a lotului Keras este spart

UPDATE: Din păcate, cererea mea de tragere către Keras care a schimbat comportamentul stratului de normalizare a lotului nu a fost acceptată. Puteți citi detaliile aici. Pentru aceia dintre voi care sunt suficient de curajoși să se încurce cu implementările personalizate, puteți găsi codul în ramura mea. Aș putea să-l întrețin și să îl îmbin cu cea mai recentă versiune stabilă a Keras (2.1.6, 2.2.2 și 2.2.4) atâta timp cât îl folosesc, dar fără promisiuni.

Majoritatea oamenilor care lucrează în Deep Learning au folosit sau au auzit de Keras. Pentru aceia dintre voi care nu au făcut-o, este o bibliotecă grozavă care retrage cadrele de bază de Deep Learning, cum ar fi TensorFlow, Theano și CNTK și oferă o API de nivel înalt pentru formarea ANN-urilor. Este ușor de utilizat, permite crearea rapidă de prototipuri și are o comunitate activă prietenoasă. Îl folosesc foarte mult și contribuiesc periodic la proiect de ceva timp și îl recomand cu siguranță oricui dorește să lucreze la Deep Learning.

Chiar dacă Keras mi-a făcut viața mai ușoară, de multe ori am fost mușcat de comportamentul ciudat al stratului de normalizare a loturilor. Comportamentul său implicit s-a schimbat de-a lungul timpului, cu toate acestea, încă provoacă probleme multor utilizatori și, ca urmare, există mai multe probleme deschise pe Github. În această postare pe blog, voi încerca să construiesc un caz pentru ce stratul BatchNormalization al lui Keras nu se joacă bine cu Transfer Learning, voi oferi codul care rezolvă problema și voi da exemple cu rezultatele plasture.

În subsecțiunile de mai jos, ofer o introducere despre modul în care Transfer Learning este utilizat în Deep Learning, ce este stratul de normalizare a loturilor, cum funcționează learning_phase și cum Keras a schimbat comportamentul BN de-a lungul timpului. Dacă le știți deja, puteți sări în siguranță direct la secțiunea 2.

1.1 Utilizarea Transfer Learning este crucială pentru Deep Learning

Unul dintre motivele pentru care Deep Learning a fost criticat în trecut este că necesită prea multe date. Acest lucru nu este întotdeauna adevărat; Există mai multe tehnici pentru a aborda această limitare, dintre care una este Transfer Learning.

Să presupunem că lucrați la o aplicație Computer Vision și doriți să construiți un clasificator care să distingă pisicile de câini. De fapt, nu aveți nevoie de milioane de imagini de pisică/câine pentru a antrena modelul. În schimb, puteți utiliza un clasificator pre-antrenat și puteți regla fin convoluțiile de top cu mai puține date. Ideea din spatele lui este că, deoarece modelul pre-antrenat s-a potrivit pe imagini, convoluțiile de jos pot recunoaște caracteristici precum linii, margini și alte modele utile, ceea ce înseamnă că îi puteți folosi greutățile fie ca valori bune de inițializare, fie reantrenați parțial rețeaua cu datele dvs. .
Stratul de normalizare batch al Keras este rupt PlatoBlockchain Data Intelligence. Căutare verticală. Ai.
Keras vine cu mai multe modele pre-antrenate și exemple ușor de utilizat despre cum să reglați modelele. Puteți citi mai multe pe documentaţie.

1.2 Ce este stratul Normalizare lot?

Stratul de normalizare a loturilor a fost introdus în 2014 de Ioffe și Szegedy. Abordează problema gradientului care dispare prin standardizarea ieșirii stratului anterior, accelerează antrenamentul prin reducerea numărului de iterații necesare și permite antrenamentul rețelelor neuronale mai profunde. Explicarea exactă a modului în care funcționează este dincolo de scopul acestei postări, dar vă încurajez cu tărie să citiți hârtie originală. O explicație suprasimplificată este că redimensionează intrarea prin scăderea mediei acesteia și prin împărțirea cu abaterea sa standard; de asemenea, poate învăța să anuleze transformarea dacă este necesar.
Stratul de normalizare batch al Keras este rupt PlatoBlockchain Data Intelligence. Căutare verticală. Ai.

1.3 Ce este faza_învățare în Keras?

Unele straturi funcționează diferit în timpul antrenamentului și al modului de inferență. Cele mai notabile exemple sunt standardizarea lotului și straturile de abandon. În cazul BN, în timpul antrenamentului folosim media și varianța mini-lot pentru a redimensiona intrarea. Pe de altă parte, în timpul inferenței folosim media mobilă și varianța care au fost estimate în timpul antrenamentului.

Keras știe în ce mod să ruleze, deoarece are un mecanism încorporat numit faza_învățare. Faza de învățare controlează dacă rețeaua este în modul tren sau test. Dacă nu este setată manual de utilizator, în timpul fit() rețeaua rulează cu learning_phase=1 (mod tren). În timpul producerii de predicții (de exemplu, când apelăm metodele predict() & evaluate() sau la pasul de validare a fit()), rețeaua rulează cu learning_phase=0 (mod test). Chiar dacă nu este recomandat, utilizatorul poate, de asemenea, să schimbe static learning_phase la o anumită valoare, dar acest lucru trebuie să se întâmple înainte ca orice model sau tensor să fie adăugat în grafic. Dacă learning_phase este setat static, Keras va fi blocat în orice mod selectat de utilizator.

1.4 Cum a implementat Keras normalizarea loturilor de-a lungul timpului?

Keras a schimbat de mai multe ori comportamentul normalizării loturilor, dar cea mai recentă actualizare semnificativă a avut loc în Keras 2.1.3. Înainte de v2.1.3, când stratul BN a fost înghețat (antrenabil = Fals), a continuat să-și actualizeze statisticile loturilor, ceva care a provocat dureri de cap epice utilizatorilor săi.

Aceasta nu a fost doar o politică ciudată, a fost de fapt greșită. Imaginați-vă că există un strat BN între circumvoluții; dacă stratul este înghețat, nu ar trebui să apară modificări. Dacă actualizăm parțial greutățile și straturile următoare sunt, de asemenea, înghețate, nu vor avea niciodată șansa de a se adapta la actualizările statisticilor mini-loturi, ducând la o eroare mai mare. Din fericire, începând cu versiunea 2.1.3, atunci când un strat BN este înghețat, acesta nu își mai actualizează statisticile. Dar este suficient? Nu dacă utilizați Transfer Learning.

Mai jos descriu exact care este problema și schițez implementarea tehnică pentru rezolvarea acesteia. De asemenea, ofer câteva exemple pentru a arăta efectele asupra preciziei modelului înainte și după plasture este aplicat.

2.1 Descrierea tehnică a problemei

Problema cu implementarea actuală a Keras este că atunci când un strat BN este înghețat, acesta continuă să folosească statisticile mini-loturi în timpul antrenamentului. Cred că o abordare mai bună atunci când BN-ul este înghețat este să folosiți media mobilă și varianța pe care le-a învățat în timpul antrenamentului. De ce? Din aceleași motive pentru care statisticile mini-loturi nu ar trebui actualizate atunci când stratul este înghețat: poate duce la rezultate slabe deoarece straturile următoare nu sunt antrenate corespunzător.

Să presupunem că construiți un model de computer Vision, dar nu aveți suficiente date, așa că decideți să utilizați unul dintre CNN-urile Keras pre-instruite și să îl reglați fin. Din păcate, făcând acest lucru, nu obțineți garanții că media și variația noului set de date în interiorul straturilor BN vor fi similare cu cele ale setului de date original. Rețineți că în acest moment, în timpul antrenamentului, rețeaua dvs. va folosi întotdeauna statisticile mini-loturi fie că stratul BN este înghețat, fie nu; de asemenea, în timpul inferenței, veți folosi statisticile învățate anterior ale straturilor BN înghețate. Ca urmare, dacă reglați fin straturile superioare, greutățile acestora vor fi ajustate la media/varianța nou set de date. Cu toate acestea, în timpul inferenței, ei vor primi date care sunt scalate diferit deoarece media/varianta a original va fi utilizat setul de date.
Stratul de normalizare batch al Keras este rupt PlatoBlockchain Data Intelligence. Căutare verticală. Ai.
Mai sus ofer o arhitectură simplistă (și nerealistă) în scopuri demonstrative. Să presupunem că reglam fin modelul de la Convoluția k+1 până în partea de sus a rețelei (partea dreaptă) și păstrăm înghețată partea de jos (partea stângă). În timpul antrenamentului, toate straturile BN de la 1 la k vor folosi media/varianța datelor tale de antrenament. Acest lucru va avea efecte negative asupra ReLU-urilor înghețate dacă media și varianța fiecărui BN nu sunt apropiate de cele învățate în timpul pre-antrenamentului. De asemenea, va face ca restul rețelei (de la CONV k+1 și mai târziu) să fie antrenat cu intrări care au scale diferite în comparație cu ceea ce va primi în timpul inferenței. În timpul antrenamentului, rețeaua dvs. se poate adapta la aceste schimbări, cu toate acestea, în momentul în care treceți în modul de predicție, Keras va folosi diferite statistici de standardizare, ceea ce va accelera distribuția intrărilor din straturile următoare, ducând la rezultate slabe.

2.2 Cum puteți detecta dacă sunteți afectat?

O modalitate de a-l detecta este să setați static faza de învățare a Keras la 1 (modul tren) și la 0 (modul test) și să vă evaluați modelul în fiecare caz. Dacă există o diferență semnificativă de precizie pe același set de date, sunteți afectat de problemă. Merită subliniat faptul că, datorită modului în care mecanismul learning_phase este implementat în Keras, de obicei nu este sfătuit să se încurce cu el. Modificările în faza_învățare nu vor avea niciun efect asupra modelelor care sunt deja compilate și utilizate; după cum puteți vedea în exemplele din subsecțiunile următoare, cel mai bun mod de a face acest lucru este să începeți cu o sesiune curată și să schimbați faza_învățare înainte ca orice tensor să fie definit în grafic.

O altă modalitate de a detecta problema în timp ce lucrați cu clasificatoare binari este să verificați acuratețea și AUC. Dacă precizia este aproape de 50%, dar AUC este aproape de 1 (și, de asemenea, observați diferențe între modul tren/test pe același set de date), este posibil ca probabilitățile să fie depășite din cauza statisticilor BN. În mod similar, pentru regresie puteți utiliza MSE și corelația lui Spearman pentru a o detecta.

2.3 Cum îl putem remedia?

Cred că problema poate fi rezolvată dacă straturile BN înghețate sunt de fapt doar atât: blocate permanent în modul de testare. Din punct de vedere al implementării, indicatorul care poate fi antrenat trebuie să facă parte din graficul de calcul, iar comportamentul BN-ului trebuie să depindă nu numai de faza_învățare, ci și de valoarea proprietății antrenabile. Puteți găsi detalii despre implementarea mea pe Github.

Prin aplicarea corecției de mai sus, atunci când un strat BN este înghețat, acesta nu va mai folosi statisticile mini-loturi, ci le va folosi pe cele învățate în timpul antrenamentului. Ca urmare, nu va exista nicio discrepanță între modurile de antrenament și de testare, ceea ce duce la o precizie sporită. Evident, atunci când stratul BN nu este înghețat, va continua să folosească statisticile mini-loturi în timpul antrenamentului.

2.4 Evaluarea efectelor plasturelui

Chiar dacă am scris recent implementarea de mai sus, ideea din spatele acesteia este puternic testată pe probleme din lumea reală folosind diverse soluții care au același efect. De exemplu, discrepanța dintre modurile de antrenament și testare poate fi evitată prin împărțirea rețelei în două părți (înghețat și neînghețat) și efectuarea antrenamentului în cache (trecerea datelor prin modelul înghețat o dată și apoi folosirea lor pentru a antrena rețeaua dezghețată). Cu toate acestea, deoarece „credeți-mă, am mai făcut asta înainte” de obicei nu are nicio importanță, mai jos vă ofer câteva exemple care arată efectele noii implementări în practică.

Iată câteva puncte importante despre experiment:

  1. Voi folosi o cantitate mică de date pentru a supraadapta în mod intenționat modelul și voi antrena și valida modelul pe același set de date. Procedând astfel, mă aștept la o precizie aproape perfectă și o performanță identică pe setul de date tren/validare.
  2. Dacă în timpul validării obțin o acuratețe semnificativ mai mică pe același set de date, voi avea o indicație clară că politica actuală BN afectează negativ performanța modelului în timpul inferenței.
  3. Orice preprocesare va avea loc în afara generatoarelor. Acest lucru este făcut pentru a rezolva o eroare care a fost introdusă în v2.1.5 (remediat în prezent la v2.1.6 viitoare și cel mai recent master).
  4. Vom forța Keras să folosească diferite faze de învățare în timpul evaluării. Dacă descoperim diferențe între acuratețea raportată, vom ști că suntem afectați de politica actuală BN.

Codul pentru experiment este prezentat mai jos:

import numpy as np
from keras.datasets import cifar10
from scipy.misc import imresize

from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.models import Model, load_model
from keras.layers import Dense, Flatten
from keras import backend as K


seed = 42
epochs = 10
records_per_class = 100

# We take only 2 classes from CIFAR10 and a very small sample to intentionally overfit the model.
# We will also use the same data for train/test and expect that Keras will give the same accuracy.
(x, y), _ = cifar10.load_data()

def filter_resize(category):
   # We do the preprocessing here instead in the Generator to get around a bug on Keras 2.1.5.
   return [preprocess_input(imresize(img, (224,224)).astype('float')) for img in x[y.flatten()==category][:records_per_class]]

x = np.stack(filter_resize(3)+filter_resize(5))
records_per_class = x.shape[0] // 2
y = np.array([[1,0]]*records_per_class + [[0,1]]*records_per_class)


# We will use a pre-trained model and finetune the top layers.
np.random.seed(seed)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)

for layer in model.layers[:140]:
   layer.trainable = False

for layer in model.layers[140:]:
   layer.trainable = True

model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(ImageDataGenerator().flow(x, y, seed=42), epochs=epochs, validation_data=ImageDataGenerator().flow(x, y, seed=42))

# Store the model on disk
model.save('tmp.h5')


# In every test we will clear the session and reload the model to force Learning_Phase values to change.
print('DYNAMIC LEARNING_PHASE')
K.clear_session()
model = load_model('tmp.h5')
# This accuracy should match exactly the one of the validation set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))


print('STATIC LEARNING_PHASE = 0')
K.clear_session()
K.set_learning_phase(0)
model = load_model('tmp.h5')
# Again the accuracy should match the above.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))


print('STATIC LEARNING_PHASE = 1')
K.clear_session()
K.set_learning_phase(1)
model = load_model('tmp.h5')
# The accuracy will be close to the one of the training set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))

Să verificăm rezultatele pe Keras v2.1.5:

Epoch 1/10
1/7 [===>..........................] - ETA: 25s - loss: 0.8751 - acc: 0.5312
2/7 [=======>......................] - ETA: 11s - loss: 0.8594 - acc: 0.4531
3/7 [===========>..................] - ETA: 7s - loss: 0.8398 - acc: 0.4688 
4/7 [================>.............] - ETA: 4s - loss: 0.8467 - acc: 0.4844
5/7 [====================>.........] - ETA: 2s - loss: 0.7904 - acc: 0.5437
6/7 [========================>.....] - ETA: 1s - loss: 0.7593 - acc: 0.5625
7/7 [==============================] - 12s 2s/step - loss: 0.7536 - acc: 0.5744 - val_loss: 0.6526 - val_acc: 0.6650

Epoch 2/10
1/7 [===>..........................] - ETA: 4s - loss: 0.3881 - acc: 0.8125
2/7 [=======>......................] - ETA: 3s - loss: 0.3945 - acc: 0.7812
3/7 [===========>..................] - ETA: 2s - loss: 0.3956 - acc: 0.8229
4/7 [================>.............] - ETA: 1s - loss: 0.4223 - acc: 0.8047
5/7 [====================>.........] - ETA: 1s - loss: 0.4483 - acc: 0.7812
6/7 [========================>.....] - ETA: 0s - loss: 0.4325 - acc: 0.7917
7/7 [==============================] - 8s 1s/step - loss: 0.4095 - acc: 0.8089 - val_loss: 0.4722 - val_acc: 0.7700

Epoch 3/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2246 - acc: 0.9375
2/7 [=======>......................] - ETA: 3s - loss: 0.2167 - acc: 0.9375
3/7 [===========>..................] - ETA: 2s - loss: 0.2260 - acc: 0.9479
4/7 [================>.............] - ETA: 2s - loss: 0.2179 - acc: 0.9375
5/7 [====================>.........] - ETA: 1s - loss: 0.2356 - acc: 0.9313
6/7 [========================>.....] - ETA: 0s - loss: 0.2392 - acc: 0.9427
7/7 [==============================] - 8s 1s/step - loss: 0.2288 - acc: 0.9456 - val_loss: 0.4282 - val_acc: 0.7800

Epoch 4/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2183 - acc: 0.9688
2/7 [=======>......................] - ETA: 3s - loss: 0.1899 - acc: 0.9844
3/7 [===========>..................] - ETA: 2s - loss: 0.1887 - acc: 0.9792
4/7 [================>.............] - ETA: 1s - loss: 0.1995 - acc: 0.9531
5/7 [====================>.........] - ETA: 1s - loss: 0.1932 - acc: 0.9625
6/7 [========================>.....] - ETA: 0s - loss: 0.1819 - acc: 0.9688
7/7 [==============================] - 8s 1s/step - loss: 0.1743 - acc: 0.9747 - val_loss: 0.3778 - val_acc: 0.8400

Epoch 5/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0973 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0828 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0851 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0897 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0928 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0936 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.1337 - acc: 0.9838 - val_loss: 0.3916 - val_acc: 0.8100

Epoch 6/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0747 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0852 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0812 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0831 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0779 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0766 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.0813 - acc: 1.0000 - val_loss: 0.3637 - val_acc: 0.8550

Epoch 7/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2478 - acc: 0.8750
2/7 [=======>......................] - ETA: 2s - loss: 0.1966 - acc: 0.9375
3/7 [===========>..................] - ETA: 2s - loss: 0.1528 - acc: 0.9583
4/7 [================>.............] - ETA: 1s - loss: 0.1300 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.1193 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.1196 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.1084 - acc: 0.9838 - val_loss: 0.3546 - val_acc: 0.8600

Epoch 8/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0539 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0900 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0815 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0740 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0700 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0701 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.0695 - acc: 1.0000 - val_loss: 0.3269 - val_acc: 0.8600

Epoch 9/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0306 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0377 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0898 - acc: 0.9583
4/7 [================>.............] - ETA: 1s - loss: 0.0773 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0742 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0708 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0659 - acc: 0.9838 - val_loss: 0.3604 - val_acc: 0.8600

Epoch 10/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0354 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0381 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0354 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0828 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0791 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0794 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0704 - acc: 0.9838 - val_loss: 0.3615 - val_acc: 0.8600

DYNAMIC LEARNING_PHASE
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 0
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 1
[0.025861846953630446, 1.0]

După cum putem vedea mai sus, în timpul antrenamentului modelul învață foarte bine datele și obține pe setul de antrenament o acuratețe aproape perfectă. Tot la sfârșitul fiecărei iterații, în timp ce evaluăm modelul pe același set de date, obținem diferențe semnificative de pierdere și precizie. Rețineți că nu ar trebui să primim acest lucru; am supraadaptat în mod intenționat modelul pe setul de date specific, iar seturile de date de antrenament/validare sunt identice.

După finalizarea antrenamentului, evaluăm modelul folosind 3 configurații diferite de fază de învățare: Dinamic, Static = 0 (modul de testare) și Static = 1 (modul de antrenament). După cum putem vedea, primele două configurații vor oferi rezultate identice în ceea ce privește pierderea și acuratețea, iar valoarea lor se potrivește cu acuratețea raportată a modelului pe setul de validare din ultima iterație. Cu toate acestea, odată ce trecem la modul de antrenament, observăm o discrepanță (îmbunătățire) masivă. De ce asta? După cum am spus mai devreme, ponderile rețelei sunt reglate așteptând să primească date scalate cu media/varianța datelor de antrenament. Din păcate, acele statistici sunt diferite de cele stocate în straturile BN. Deoarece straturile BN au fost înghețate, aceste statistici nu au fost niciodată actualizate. Această discrepanță între valorile statisticilor BN duce la deteriorarea preciziei în timpul inferenței.

Să vedem ce se întâmplă odată ce aplicăm plasture:

Epoch 1/10
1/7 [===>..........................] - ETA: 26s - loss: 0.9992 - acc: 0.4375
2/7 [=======>......................] - ETA: 12s - loss: 1.0534 - acc: 0.4375
3/7 [===========>..................] - ETA: 7s - loss: 1.0592 - acc: 0.4479 
4/7 [================>.............] - ETA: 4s - loss: 0.9618 - acc: 0.5000
5/7 [====================>.........] - ETA: 2s - loss: 0.8933 - acc: 0.5250
6/7 [========================>.....] - ETA: 1s - loss: 0.8638 - acc: 0.5417
7/7 [==============================] - 13s 2s/step - loss: 0.8357 - acc: 0.5570 - val_loss: 0.2414 - val_acc: 0.9450

Epoch 2/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2331 - acc: 0.9688
2/7 [=======>......................] - ETA: 2s - loss: 0.3308 - acc: 0.8594
3/7 [===========>..................] - ETA: 2s - loss: 0.3986 - acc: 0.8125
4/7 [================>.............] - ETA: 1s - loss: 0.3721 - acc: 0.8281
5/7 [====================>.........] - ETA: 1s - loss: 0.3449 - acc: 0.8438
6/7 [========================>.....] - ETA: 0s - loss: 0.3168 - acc: 0.8646
7/7 [==============================] - 9s 1s/step - loss: 0.3165 - acc: 0.8633 - val_loss: 0.1167 - val_acc: 0.9950

Epoch 3/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2457 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.2592 - acc: 0.9688
3/7 [===========>..................] - ETA: 2s - loss: 0.2173 - acc: 0.9688
4/7 [================>.............] - ETA: 1s - loss: 0.2122 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.2003 - acc: 0.9688
6/7 [========================>.....] - ETA: 0s - loss: 0.1896 - acc: 0.9740
7/7 [==============================] - 9s 1s/step - loss: 0.1835 - acc: 0.9773 - val_loss: 0.0678 - val_acc: 1.0000

Epoch 4/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2051 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.1652 - acc: 0.9844
3/7 [===========>..................] - ETA: 2s - loss: 0.1423 - acc: 0.9896
4/7 [================>.............] - ETA: 1s - loss: 0.1289 - acc: 0.9922
5/7 [====================>.........] - ETA: 1s - loss: 0.1225 - acc: 0.9938
6/7 [========================>.....] - ETA: 0s - loss: 0.1149 - acc: 0.9948
7/7 [==============================] - 9s 1s/step - loss: 0.1060 - acc: 0.9955 - val_loss: 0.0455 - val_acc: 1.0000

Epoch 5/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0769 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0846 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0797 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0736 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0914 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0858 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0808 - acc: 1.0000 - val_loss: 0.0346 - val_acc: 1.0000

Epoch 6/10
1/7 [===>..........................] - ETA: 1s - loss: 0.1267 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.1039 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0893 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0780 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0758 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0789 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0738 - acc: 1.0000 - val_loss: 0.0248 - val_acc: 1.0000

Epoch 7/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0344 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0385 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0467 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0445 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0446 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0429 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0421 - acc: 1.0000 - val_loss: 0.0202 - val_acc: 1.0000

Epoch 8/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0319 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0300 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0320 - acc: 1.0000
4/7 [================>.............] - ETA: 2s - loss: 0.0307 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0303 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0291 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0358 - acc: 1.0000 - val_loss: 0.0167 - val_acc: 1.0000

Epoch 9/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0246 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0255 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0258 - acc: 1.0000
4/7 [================>.............] - ETA: 2s - loss: 0.0250 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0252 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0260 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0327 - acc: 1.0000 - val_loss: 0.0143 - val_acc: 1.0000

Epoch 10/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0251 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0228 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0217 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0249 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0244 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0239 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0290 - acc: 1.0000 - val_loss: 0.0127 - val_acc: 1.0000

DYNAMIC LEARNING_PHASE
[0.012697912137955427, 1.0]

STATIC LEARNING_PHASE = 0
[0.012697912137955427, 1.0]

STATIC LEARNING_PHASE = 1
[0.01744014158844948, 1.0]

În primul rând, observăm că rețeaua converge semnificativ mai rapid și obține o precizie perfectă. De asemenea, vedem că nu mai există o discrepanță în ceea ce privește acuratețea atunci când comutăm între diferitele valori ale fazei de învățare.

2.5 Cum funcționează patch-ul pe un set de date real?

Deci, cum funcționează plasturele la un experiment mai realist? Să folosim ResNet50 pre-antrenat de la Keras (se potrivește inițial pe imagenet), să eliminăm stratul de clasificare superior și să-l reglam fin cu și fără plasture și să comparăm rezultatele. Pentru date, vom folosi CIFAR10 (diviziunea standard tren/test furnizată de Keras) și vom redimensiona imaginile la 224×224 pentru a le face compatibile cu dimensiunea de intrare a ResNet50.

Vom face 10 epoci pentru a antrena stratul de clasificare superior folosind RSMprop și apoi vom face alte 5 pentru a regla totul după al 139-lea strat folosind SGD(lr=1e-4, momentum=0.9). Fără plasture modelul nostru atinge o precizie de 87.44%. Folosind patch-ul, obținem o precizie de 92.36%, cu aproape 5 puncte mai mare.

2.6 Ar trebui să aplicăm aceeași corecție altor straturi, cum ar fi Dropout?

Normalizarea lotului nu este singurul strat care funcționează diferit între modurile de tren și de testare. De asemenea, abandonul și variantele sale au același efect. Ar trebui să aplicăm aceeași politică tuturor acestor straturi? Nu cred (deși mi-ar plăcea să aud părerile tale despre asta). Motivul este că Dropout este folosit pentru a evita supraadaptarea, astfel blocarea permanentă în modul de predicție în timpul antrenamentului i-ar învinge scopul. Ce crezi?

Cred cu tărie că această discrepanță trebuie rezolvată în Keras. Am văzut efecte și mai profunde (de la 100% până la 50% precizie) în aplicațiile din lumea reală cauzate de această problemă. eu planifică să trimită trimis deja a PR lui Keras cu soluția și sperăm că va fi acceptată.

Dacă v-a plăcut această postare pe blog, vă rugăm să acordați un moment pentru a o distribui pe Facebook sau Twitter. 🙂

Timestamp-ul:

Mai mult de la Datumbox