Il livello di normalizzazione batch di Keras è interrotto da PlatoBlockchain Data Intelligence. Ricerca verticale. Ai.

Il livello di normalizzazione batch di Keras è rotto

AGGIORNAMENTO: Sfortunatamente la mia richiesta pull a Keras che ha modificato il comportamento del livello di normalizzazione batch non è stata accettata. Puoi leggere i dettagli qui. Per quelli di voi che sono abbastanza coraggiosi da pasticciare con le implementazioni personalizzate, è possibile trovare il codice in il mio ramo. Potrei mantenerlo e unirlo all'ultima versione stabile di Keras (2.1.6, 2.2.2 ed 2.2.4) per tutto il tempo in cui lo uso ma nessuna promessa.

La maggior parte delle persone che lavorano nel Deep Learning lo ha utilizzato o ne ha sentito parlare Keras. Per quelli di voi che non l'hanno fatto, è un'ottima libreria che astrae i framework di Deep Learning sottostanti come TensorFlow, Theano e CNTK e fornisce un API di alto livello per la formazione delle ANN. È facile da usare, consente una rapida prototipazione e ha una comunità attiva amichevole. Lo uso molto e contribuisco periodicamente al progetto da un po 'di tempo e lo consiglio vivamente a chiunque voglia lavorare sul Deep Learning.

Anche se Keras mi ha reso la vita più facile, molte volte sono stato morso dallo strano comportamento del livello di normalizzazione batch. Il suo comportamento predefinito è cambiato nel tempo, tuttavia causa ancora problemi a molti utenti e di conseguenza ce ne sono diversi correlati questioni aperte su GitHub. In questo post del blog, cercherò di creare un caso per cui il livello BatchNormalization di Keras non funziona correttamente con Transfer Learning, fornirò il codice che risolve il problema e fornirò esempi con i risultati del patch.

Nelle sottosezioni seguenti, fornisco un'introduzione su come il Transfer Learning viene utilizzato nel Deep Learning, cos'è il livello di normalizzazione batch, come funziona learningining_phase e come Keras ha modificato il comportamento di BN nel tempo. Se li conosci già, puoi saltare direttamente alla sezione 2.

1.1 L'utilizzo del Transfer Learning è fondamentale per il Deep Learning

Uno dei motivi per cui il Deep Learning è stato criticato in passato è che richiede troppi dati. Questo non è sempre vero; Esistono diverse tecniche per affrontare questa limitazione, una delle quali è il Transfer Learning.

Supponiamo che tu stia lavorando a un'applicazione di visione artificiale e desideri creare un classificatore che distingua i gatti dai cani. In realtà non hai bisogno di milioni di immagini di gatti / cani per addestrare il modello. È invece possibile utilizzare un classificatore pre-addestrato e mettere a punto le principali convoluzioni con meno dati. L'idea alla base è che, poiché il modello pre-addestrato era adatto alle immagini, le convoluzioni inferiori possono riconoscere caratteristiche come linee, bordi e altri modelli utili, il che significa che puoi utilizzare i suoi pesi come buoni valori di inizializzazione o riqualificare parzialmente la rete con i tuoi dati .
Il livello di normalizzazione batch di Keras è interrotto da PlatoBlockchain Data Intelligence. Ricerca verticale. Ai.
Keras viene fornito con diversi modelli pre-addestrati ed esempi facili da usare su come mettere a punto i modelli. Puoi leggere di più su documentazione.

1.2 Cos'è il livello di normalizzazione batch?

Il livello di normalizzazione batch è stato introdotto nel 2014 da Ioffe e Szegedy. Affronta il problema del gradiente di fuga standardizzando l'output del livello precedente, accelera l'addestramento riducendo il numero di iterazioni richieste e consente l'addestramento di reti neurali più profonde. Spiegare esattamente come funziona va oltre lo scopo di questo post, ma ti incoraggio vivamente a leggere il carta originale. Una spiegazione semplificata è che ridimensiona l'input sottraendo la sua media e dividendo con la sua deviazione standard; può anche imparare ad annullare la trasformazione, se necessario.
Il livello di normalizzazione batch di Keras è interrotto da PlatoBlockchain Data Intelligence. Ricerca verticale. Ai.

1.3 Cos'è learning_phase in Keras?

Alcuni livelli funzionano in modo diverso durante l'addestramento e la modalità inferenza. Gli esempi più notevoli sono la normalizzazione batch e i livelli di esclusione. Nel caso di BN, durante l'addestramento utilizziamo la media e la varianza del mini-batch per riscalare l'input. D'altra parte, durante l'inferenza utilizziamo la media mobile e la varianza stimata durante l'addestramento.

Keras sa in quale modalità eseguire perché ha un meccanismo incorporato chiamato fase_apprendimento. La fase di apprendimento controlla se la rete è in modalità treno o test. Se non è impostato manualmente dall'utente, durante fit () la rete funziona con learning_phase = 1 (modalità train). Durante la produzione di previsioni (ad esempio quando chiamiamo i metodi prediction () e assess () o nella fase di convalida di fit ()) la rete viene eseguita con learning_phase = 0 (modalità test). Anche se non è consigliato, l'utente è anche in grado di modificare staticamente learning_phase su un valore specifico, ma ciò deve avvenire prima che qualsiasi modello o tensore venga aggiunto nel grafico. Se learning_phase è impostato staticamente, Keras sarà bloccato in qualsiasi modalità selezionata dall'utente.

1.4 In che modo Keras ha implementato la normalizzazione in batch nel tempo?

Keras ha modificato più volte il comportamento della normalizzazione batch, ma l'aggiornamento significativo più recente è avvenuto in Keras 2.1.3. Prima della v2.1.3, quando il livello BN era congelato (trainable = False), continuava ad aggiornare le sue statistiche batch, cosa che causava epici mal di testa ai suoi utenti.

Questa non era solo una politica strana, in realtà era sbagliata. Immagina che esista uno strato BN tra le convoluzioni; se il livello è congelato, non dovrebbero essere apportate modifiche. Se aggiorniamo parzialmente i suoi pesi e anche i livelli successivi vengono congelati, non avranno mai la possibilità di adattarsi agli aggiornamenti delle statistiche del mini-batch che portano a un errore maggiore. Per fortuna a partire dalla versione 2.1.3, quando un layer BN è congelato non aggiorna più le sue statistiche. Ma è abbastanza? No, se stai usando Transfer Learning.

Di seguito descrivo esattamente qual è il problema e abbozzo l'implementazione tecnica per risolverlo. Fornisco anche alcuni esempi per mostrare gli effetti sulla precisione del modello prima e dopo il patch viene applicato.

2.1 Descrizione tecnica del problema

Il problema con l'attuale implementazione di Keras è che quando un livello BN viene congelato, continua a utilizzare le statistiche del mini-batch durante l'addestramento. Credo che un approccio migliore quando il BN è congelato sia usare la media mobile e la varianza che ha appreso durante l'allenamento. Perché? Per gli stessi motivi per cui le statistiche del mini-batch non dovrebbero essere aggiornate quando il livello è congelato: può portare a risultati scadenti perché i livelli successivi non sono addestrati correttamente.

Supponiamo che tu stia costruendo un modello di Visione artificiale ma non hai abbastanza dati, quindi decidi di utilizzare una delle CNN pre-addestrate di Keras e perfezionarla. Sfortunatamente, così facendo non ottieni alcuna garanzia che la media e la varianza del tuo nuovo set di dati all'interno dei livelli BN saranno simili a quelli del set di dati originale. Ricorda che al momento, durante l'addestramento, la tua rete utilizzerà sempre le statistiche mini-batch sia che il layer BN sia congelato o meno; anche durante l'inferenza userete le statistiche apprese in precedenza dei layer BN congelati. Di conseguenza, se si ottimizzano i livelli superiori, i loro pesi verranno adeguati alla media / varianza del nuovi set di dati. Tuttavia, durante l'inferenza riceveranno dati che vengono scalati diversamente perché la media / varianza di i verrà utilizzato il set di dati.
Il livello di normalizzazione batch di Keras è interrotto da PlatoBlockchain Data Intelligence. Ricerca verticale. Ai.
Sopra fornisco un'architettura semplicistica (e non realistica) a scopo dimostrativo. Supponiamo di mettere a punto il modello dalla Convoluzione k + 1 fino alla parte superiore della rete (lato destro) e di mantenere congelata la parte inferiore (lato sinistro). Durante l'allenamento tutti i livelli BN da 1 a k useranno la media / varianza dei dati di allenamento. Ciò avrà effetti negativi sulle ReLU congelate se la media e la varianza su ciascun BN non sono vicine a quelle apprese durante il pre-allenamento. Inoltre, il resto della rete (da CONV k + 1 e versioni successive) verrà addestrato con input che hanno scale diverse rispetto a ciò che riceverà durante l'inferenza. Durante l'addestramento la tua rete può adattarsi a questi cambiamenti, tuttavia nel momento in cui passi alla modalità di previsione, Keras utilizzerà diverse statistiche di standardizzazione, qualcosa che accelererà la distribuzione degli input dei livelli successivi portando a scarsi risultati.

2.2 Come puoi rilevare se sei interessato?

Un modo per rilevarlo è impostare staticamente la fase di apprendimento di Keras su 1 (modalità treno) e 0 (modalità test) e valutare il tuo modello in ogni caso. Se c'è una differenza significativa nell'accuratezza sullo stesso set di dati, sei interessato dal problema. Vale la pena sottolineare che, a causa del modo in cui il meccanismo learning_phase è implementato in Keras, in genere non è consigliabile manipolarlo. Le modifiche su learning_phase non avranno effetto sui modelli che sono già compilati e utilizzati; come puoi vedere negli esempi nelle sottosezioni successive, il modo migliore per farlo è iniziare con una sessione pulita e cambiare learning_phase prima che qualsiasi tensore sia definito nel grafico.

Un altro modo per rilevare il problema mentre si lavora con i classificatori binari è controllare l'accuratezza e l'AUC. Se la precisione è vicina al 50% ma l'AUC è vicina a 1 (e si osservano anche differenze tra la modalità treno / test sullo stesso set di dati), potrebbe essere che le probabilità siano fuori scala a causa delle statistiche BN. Allo stesso modo, per la regressione è possibile utilizzare MSE e la correlazione di Spearman per rilevarla.

2.3 Come possiamo risolverlo?

Credo che il problema possa essere risolto se i livelli BN congelati sono in realtà solo questo: bloccati in modo permanente in modalità test. Dal punto di vista dell'implementazione, il flag addestrabile deve essere parte del grafo computazionale e il comportamento di BN deve dipendere non solo dalla fase_apprendimento ma anche dal valore della proprietà addestrabile. Puoi trovare i dettagli della mia implementazione su Github.

Applicando la correzione di cui sopra, quando un livello BN viene congelato, non utilizzerà più le statistiche mini-batch ma utilizzerà invece quelle apprese durante l'addestramento. Di conseguenza, non ci saranno discrepanze tra le modalità di allenamento e di prova che portano a una maggiore precisione. Ovviamente quando il layer BN non è congelato, continuerà a utilizzare le statistiche del mini-batch durante l'allenamento.

2.4 Valutazione degli effetti della patch

Anche se di recente ho scritto l'implementazione di cui sopra, l'idea alla base è stata ampiamente testata su problemi del mondo reale utilizzando varie soluzioni alternative che hanno lo stesso effetto. Ad esempio, la discrepanza tra le modalità di addestramento e di test può essere evitata suddividendo la rete in due parti (congelata e non congelata) ed eseguendo l'addestramento nella cache (passando i dati attraverso il modello congelato una volta e quindi utilizzandoli per addestrare la rete non congelata). Tuttavia, poiché il "credimi, l'ho già fatto" in genere non ha alcun peso, di seguito fornisco alcuni esempi che mostrano gli effetti della nuova implementazione nella pratica.

Ecco alcuni punti importanti sull'esperimento:

  1. Userò una piccola quantità di dati per sovradattare intenzionalmente il modello e addestrerò e convaliderò il modello sullo stesso set di dati. In questo modo, mi aspetto una precisione quasi perfetta e prestazioni identiche sul set di dati di treno / convalida.
  2. Se durante la convalida ottengo un'accuratezza significativamente inferiore sullo stesso set di dati, avrò una chiara indicazione che l'attuale politica BN influisce negativamente sulle prestazioni del modello durante l'inferenza.
  3. Qualsiasi preelaborazione avverrà al di fuori di Generators. Questo viene fatto per aggirare un bug che è stato introdotto nella v2.1.5 (attualmente risolto sulla prossima v2.1.6 e l'ultimo master).
  4. Costringeremo Keras a utilizzare diverse fasi di apprendimento durante la valutazione. Se individuiamo differenze tra l'accuratezza riportata, sapremo di essere influenzati dall'attuale politica di BN.

Il codice per l'esperimento è mostrato di seguito:

import numpy as np
from keras.datasets import cifar10
from scipy.misc import imresize

from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.models import Model, load_model
from keras.layers import Dense, Flatten
from keras import backend as K


seed = 42
epochs = 10
records_per_class = 100

# We take only 2 classes from CIFAR10 and a very small sample to intentionally overfit the model.
# We will also use the same data for train/test and expect that Keras will give the same accuracy.
(x, y), _ = cifar10.load_data()

def filter_resize(category):
   # We do the preprocessing here instead in the Generator to get around a bug on Keras 2.1.5.
   return [preprocess_input(imresize(img, (224,224)).astype('float')) for img in x[y.flatten()==category][:records_per_class]]

x = np.stack(filter_resize(3)+filter_resize(5))
records_per_class = x.shape[0] // 2
y = np.array([[1,0]]*records_per_class + [[0,1]]*records_per_class)


# We will use a pre-trained model and finetune the top layers.
np.random.seed(seed)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)

for layer in model.layers[:140]:
   layer.trainable = False

for layer in model.layers[140:]:
   layer.trainable = True

model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(ImageDataGenerator().flow(x, y, seed=42), epochs=epochs, validation_data=ImageDataGenerator().flow(x, y, seed=42))

# Store the model on disk
model.save('tmp.h5')


# In every test we will clear the session and reload the model to force Learning_Phase values to change.
print('DYNAMIC LEARNING_PHASE')
K.clear_session()
model = load_model('tmp.h5')
# This accuracy should match exactly the one of the validation set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))


print('STATIC LEARNING_PHASE = 0')
K.clear_session()
K.set_learning_phase(0)
model = load_model('tmp.h5')
# Again the accuracy should match the above.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))


print('STATIC LEARNING_PHASE = 1')
K.clear_session()
K.set_learning_phase(1)
model = load_model('tmp.h5')
# The accuracy will be close to the one of the training set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))

Controlliamo i risultati su Keras v2.1.5:

Epoch 1/10
1/7 [===>..........................] - ETA: 25s - loss: 0.8751 - acc: 0.5312
2/7 [=======>......................] - ETA: 11s - loss: 0.8594 - acc: 0.4531
3/7 [===========>..................] - ETA: 7s - loss: 0.8398 - acc: 0.4688 
4/7 [================>.............] - ETA: 4s - loss: 0.8467 - acc: 0.4844
5/7 [====================>.........] - ETA: 2s - loss: 0.7904 - acc: 0.5437
6/7 [========================>.....] - ETA: 1s - loss: 0.7593 - acc: 0.5625
7/7 [==============================] - 12s 2s/step - loss: 0.7536 - acc: 0.5744 - val_loss: 0.6526 - val_acc: 0.6650

Epoch 2/10
1/7 [===>..........................] - ETA: 4s - loss: 0.3881 - acc: 0.8125
2/7 [=======>......................] - ETA: 3s - loss: 0.3945 - acc: 0.7812
3/7 [===========>..................] - ETA: 2s - loss: 0.3956 - acc: 0.8229
4/7 [================>.............] - ETA: 1s - loss: 0.4223 - acc: 0.8047
5/7 [====================>.........] - ETA: 1s - loss: 0.4483 - acc: 0.7812
6/7 [========================>.....] - ETA: 0s - loss: 0.4325 - acc: 0.7917
7/7 [==============================] - 8s 1s/step - loss: 0.4095 - acc: 0.8089 - val_loss: 0.4722 - val_acc: 0.7700

Epoch 3/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2246 - acc: 0.9375
2/7 [=======>......................] - ETA: 3s - loss: 0.2167 - acc: 0.9375
3/7 [===========>..................] - ETA: 2s - loss: 0.2260 - acc: 0.9479
4/7 [================>.............] - ETA: 2s - loss: 0.2179 - acc: 0.9375
5/7 [====================>.........] - ETA: 1s - loss: 0.2356 - acc: 0.9313
6/7 [========================>.....] - ETA: 0s - loss: 0.2392 - acc: 0.9427
7/7 [==============================] - 8s 1s/step - loss: 0.2288 - acc: 0.9456 - val_loss: 0.4282 - val_acc: 0.7800

Epoch 4/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2183 - acc: 0.9688
2/7 [=======>......................] - ETA: 3s - loss: 0.1899 - acc: 0.9844
3/7 [===========>..................] - ETA: 2s - loss: 0.1887 - acc: 0.9792
4/7 [================>.............] - ETA: 1s - loss: 0.1995 - acc: 0.9531
5/7 [====================>.........] - ETA: 1s - loss: 0.1932 - acc: 0.9625
6/7 [========================>.....] - ETA: 0s - loss: 0.1819 - acc: 0.9688
7/7 [==============================] - 8s 1s/step - loss: 0.1743 - acc: 0.9747 - val_loss: 0.3778 - val_acc: 0.8400

Epoch 5/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0973 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0828 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0851 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0897 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0928 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0936 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.1337 - acc: 0.9838 - val_loss: 0.3916 - val_acc: 0.8100

Epoch 6/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0747 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0852 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0812 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0831 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0779 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0766 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.0813 - acc: 1.0000 - val_loss: 0.3637 - val_acc: 0.8550

Epoch 7/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2478 - acc: 0.8750
2/7 [=======>......................] - ETA: 2s - loss: 0.1966 - acc: 0.9375
3/7 [===========>..................] - ETA: 2s - loss: 0.1528 - acc: 0.9583
4/7 [================>.............] - ETA: 1s - loss: 0.1300 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.1193 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.1196 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.1084 - acc: 0.9838 - val_loss: 0.3546 - val_acc: 0.8600

Epoch 8/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0539 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0900 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0815 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0740 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0700 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0701 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.0695 - acc: 1.0000 - val_loss: 0.3269 - val_acc: 0.8600

Epoch 9/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0306 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0377 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0898 - acc: 0.9583
4/7 [================>.............] - ETA: 1s - loss: 0.0773 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0742 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0708 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0659 - acc: 0.9838 - val_loss: 0.3604 - val_acc: 0.8600

Epoch 10/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0354 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0381 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0354 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0828 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0791 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0794 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0704 - acc: 0.9838 - val_loss: 0.3615 - val_acc: 0.8600

DYNAMIC LEARNING_PHASE
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 0
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 1
[0.025861846953630446, 1.0]

Come possiamo vedere sopra, durante l'addestramento il modello apprende molto bene i dati e raggiunge una precisione quasi perfetta sul set di addestramento. Sempre alla fine di ogni iterazione, mentre si valuta il modello sullo stesso set di dati, si ottengono differenze significative in termini di perdita e accuratezza. Nota che non dovremmo ricevere questo; abbiamo sovradimensionato intenzionalmente il modello sul set di dati specifico e i set di dati di addestramento / convalida sono identici.

Al termine della formazione valutiamo il modello utilizzando 3 diverse configurazioni learning_phase: Dynamic, Static = 0 (modalità test) e Static = 1 (modalità training). Come possiamo vedere, le prime due configurazioni forniranno risultati identici in termini di perdita e accuratezza e il loro valore corrisponde alla precisione riportata del modello sulla validazione impostata nell'ultima iterazione. Tuttavia, una volta che si passa alla modalità di allenamento, si osserva un'enorme discrepanza (miglioramento). Perché è così? Come abbiamo detto in precedenza, i pesi della rete sono ottimizzati aspettandosi di ricevere dati scalati con la media / varianza dei dati di addestramento. Sfortunatamente, queste statistiche sono diverse da quelle memorizzate nei livelli BN. Poiché i layer BN sono stati congelati, queste statistiche non sono mai state aggiornate. Questa discrepanza tra i valori delle statistiche BN porta al deterioramento dell'accuratezza durante l'inferenza.

Vediamo cosa succede una volta applicato il patch:

Epoch 1/10
1/7 [===>..........................] - ETA: 26s - loss: 0.9992 - acc: 0.4375
2/7 [=======>......................] - ETA: 12s - loss: 1.0534 - acc: 0.4375
3/7 [===========>..................] - ETA: 7s - loss: 1.0592 - acc: 0.4479 
4/7 [================>.............] - ETA: 4s - loss: 0.9618 - acc: 0.5000
5/7 [====================>.........] - ETA: 2s - loss: 0.8933 - acc: 0.5250
6/7 [========================>.....] - ETA: 1s - loss: 0.8638 - acc: 0.5417
7/7 [==============================] - 13s 2s/step - loss: 0.8357 - acc: 0.5570 - val_loss: 0.2414 - val_acc: 0.9450

Epoch 2/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2331 - acc: 0.9688
2/7 [=======>......................] - ETA: 2s - loss: 0.3308 - acc: 0.8594
3/7 [===========>..................] - ETA: 2s - loss: 0.3986 - acc: 0.8125
4/7 [================>.............] - ETA: 1s - loss: 0.3721 - acc: 0.8281
5/7 [====================>.........] - ETA: 1s - loss: 0.3449 - acc: 0.8438
6/7 [========================>.....] - ETA: 0s - loss: 0.3168 - acc: 0.8646
7/7 [==============================] - 9s 1s/step - loss: 0.3165 - acc: 0.8633 - val_loss: 0.1167 - val_acc: 0.9950

Epoch 3/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2457 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.2592 - acc: 0.9688
3/7 [===========>..................] - ETA: 2s - loss: 0.2173 - acc: 0.9688
4/7 [================>.............] - ETA: 1s - loss: 0.2122 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.2003 - acc: 0.9688
6/7 [========================>.....] - ETA: 0s - loss: 0.1896 - acc: 0.9740
7/7 [==============================] - 9s 1s/step - loss: 0.1835 - acc: 0.9773 - val_loss: 0.0678 - val_acc: 1.0000

Epoch 4/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2051 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.1652 - acc: 0.9844
3/7 [===========>..................] - ETA: 2s - loss: 0.1423 - acc: 0.9896
4/7 [================>.............] - ETA: 1s - loss: 0.1289 - acc: 0.9922
5/7 [====================>.........] - ETA: 1s - loss: 0.1225 - acc: 0.9938
6/7 [========================>.....] - ETA: 0s - loss: 0.1149 - acc: 0.9948
7/7 [==============================] - 9s 1s/step - loss: 0.1060 - acc: 0.9955 - val_loss: 0.0455 - val_acc: 1.0000

Epoch 5/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0769 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0846 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0797 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0736 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0914 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0858 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0808 - acc: 1.0000 - val_loss: 0.0346 - val_acc: 1.0000

Epoch 6/10
1/7 [===>..........................] - ETA: 1s - loss: 0.1267 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.1039 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0893 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0780 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0758 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0789 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0738 - acc: 1.0000 - val_loss: 0.0248 - val_acc: 1.0000

Epoch 7/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0344 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0385 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0467 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0445 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0446 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0429 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0421 - acc: 1.0000 - val_loss: 0.0202 - val_acc: 1.0000

Epoch 8/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0319 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0300 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0320 - acc: 1.0000
4/7 [================>.............] - ETA: 2s - loss: 0.0307 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0303 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0291 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0358 - acc: 1.0000 - val_loss: 0.0167 - val_acc: 1.0000

Epoch 9/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0246 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0255 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0258 - acc: 1.0000
4/7 [================>.............] - ETA: 2s - loss: 0.0250 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0252 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0260 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0327 - acc: 1.0000 - val_loss: 0.0143 - val_acc: 1.0000

Epoch 10/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0251 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0228 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0217 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0249 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0244 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0239 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0290 - acc: 1.0000 - val_loss: 0.0127 - val_acc: 1.0000

DYNAMIC LEARNING_PHASE
[0.012697912137955427, 1.0]

STATIC LEARNING_PHASE = 0
[0.012697912137955427, 1.0]

STATIC LEARNING_PHASE = 1
[0.01744014158844948, 1.0]

Prima di tutto, osserviamo che la rete converge molto più velocemente e raggiunge una precisione perfetta. Vediamo anche che non c'è più discrepanza in termini di accuratezza quando passiamo tra diversi valori di fase_di apprendimento.

2.5 Come si comporta la patch su un set di dati reale?

Quindi come si comporta la patch su un esperimento più realistico? Usiamo ResNet50 pre-addestrato di Keras (originariamente adatto a imagenet), rimuoviamo lo strato di classificazione superiore e ottimizzalo con e senza la patch e confrontiamo i risultati. Per i dati, useremo CIFAR10 (lo standard train / test split fornito da Keras) e ridimensioneremo le immagini a 224 × 224 per renderle compatibili con le dimensioni di input di ResNet50.

Faremo 10 epoche per addestrare il livello di classificazione superiore utilizzando RSMprop e poi ne faremo altre 5 per mettere a punto tutto dopo il 139 ° livello utilizzando SGD (lr = 1e-4, momentum = 0.9). Senza la patch il nostro modello raggiunge una precisione dell'87.44%. Usando la patch, otteniamo una precisione del 92.36%, quasi 5 punti in più.

2.6 Dobbiamo applicare la stessa correzione ad altri livelli come Dropout?

La normalizzazione batch non è l'unico livello che opera in modo diverso tra le modalità treno e test. Anche il dropout e le sue varianti hanno lo stesso effetto. Dovremmo applicare la stessa politica a tutti questi livelli? Credo di no (anche se mi piacerebbe sentire i tuoi pensieri su questo). Il motivo è che Dropout viene utilizzato per evitare l'overfitting, quindi bloccarlo permanentemente in modalità di previsione durante l'allenamento annullerebbe il suo scopo. Cosa pensi?

Credo fermamente che questa discrepanza debba essere risolta a Keras. Ho visto effetti ancora più profondi (dal 100% al 50% di precisione) nelle applicazioni del mondo reale causati da questo problema. io prevede di inviare già inviato un PR a Keras con la correzione e si spera che venga accettata.

Se ti è piaciuto questo post del blog, prenditi un momento per condividerlo su Facebook o Twitter. 🙂

Timestamp:

Di più da Databox