A camada de normalização em lote de Keras está quebrada PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.

A camada Normalização de lote do Keras está quebrada

ATUALIZAÇÃO: Infelizmente meu Pull-Request para Keras que alterou o comportamento da camada de normalização em lote não foi aceito. Você pode ler os detalhes SUA PARTICIPAÇÃO FAZ A DIFERENÇA. Para aqueles de vocês que são corajosos o suficiente para mexer com implementações personalizadas, você pode encontrar o código em meu ramo. Posso mantê-lo e mesclá-lo com a última versão estável do Keras (2.1.6, 2.2.2 e 2.2.4) pelo tempo que eu usar, mas sem promessas.

A maioria das pessoas que trabalham com Deep Learning já usou ou ouviu falar Keras. Para aqueles que não o fizeram, é uma ótima biblioteca que abstrai os frameworks de Deep Learning subjacentes, como TensorFlow, Theano e CNTK e fornece um API de alto nível para treinamento de RNAs. É fácil de usar, permite prototipagem rápida e tem uma comunidade ativa amigável. Eu tenho usado muito e contribuído com o projeto periodicamente há algum tempo e definitivamente o recomendo para quem quer trabalhar com Deep Learning.

Embora Keras tenha facilitado minha vida, muitas vezes fui mordido pelo comportamento estranho da camada de Normalização em lote. Seu comportamento padrão mudou ao longo do tempo, no entanto, ainda causa problemas para muitos usuários e, como resultado, existem vários questões em aberto no Github. Nesta postagem do blog, tentarei construir um caso de por que a camada BatchNormalization de Keras não funciona bem com o Transfer Learning. Fornecerei o código que corrige o problema e darei exemplos com os resultados do remendo.

Nas subseções abaixo, eu forneço uma introdução sobre como o Transfer Learning é usado no Deep Learning, o que é a camada Batch Normalization, como o learnining_phase funciona e como Keras mudou o comportamento do BN ao longo do tempo. Se você já os conhece, pode pular com segurança diretamente para a seção 2.

1.1 Usar o Transfer Learning é crucial para o Deep Learning

Uma das razões pelas quais o Deep Learning foi criticado no passado é que ele requer muitos dados. Isto não é sempre verdade; Existem várias técnicas para lidar com essa limitação, uma das quais é o Transfer Learning.

Suponha que você esteja trabalhando em um aplicativo Computer Vision e deseja construir um classificador que diferencie gatos de cães. Na verdade, você não precisa de milhões de imagens de gato / cachorro para treinar o modelo. Em vez disso, você pode usar um classificador pré-treinado e ajustar as principais convoluções com menos dados. A ideia por trás disso é que, uma vez que o modelo pré-treinado foi ajustado às imagens, as convoluções inferiores podem reconhecer recursos como linhas, bordas e outros padrões úteis, o que significa que você pode usar seus pesos como bons valores de inicialização ou retreinar parcialmente a rede com seus dados .
A camada de normalização em lote de Keras está quebrada PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.
Keras vem com vários modelos pré-treinados e exemplos fáceis de usar sobre como ajustar os modelos. Você pode ler mais no documentação.

1.2 O que é a camada de normalização em lote?

A camada de normalização em lote foi introduzida em 2014 por Ioffe e Szegedy. Ele aborda o problema do gradiente de desaparecimento padronizando a saída da camada anterior, acelera o treinamento reduzindo o número de iterações necessárias e permite o treinamento de redes neurais mais profundas. Explicar exatamente como funciona está além do escopo desta postagem, mas eu recomendo fortemente que você leia o Papel original. Uma explicação simplificada é que ele redimensiona a entrada subtraindo sua média e dividindo com seu desvio padrão; ele também pode aprender a desfazer a transformação, se necessário.
A camada de normalização em lote de Keras está quebrada PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.

1.3 O que é learning_phase em Keras?

Algumas camadas funcionam de maneira diferente durante o treinamento e o modo de inferência. Os exemplos mais notáveis ​​são as camadas Normalização em lote e Eliminação. No caso de BN, durante o treinamento usamos a média e a variância do minilote para redimensionar a entrada. Por outro lado, durante a inferência, usamos a média móvel e a variância que foi estimada durante o treinamento.

Keras sabe em qual modo executar porque tem um mecanismo integrado chamado fase_de_prendizagem. A fase de aprendizagem controla se a rede está em modo de trem ou teste. Se não for definido manualmente pelo usuário, durante fit () a rede funciona com learning_phase = 1 (modo de trem). Durante a produção de previsões (por exemplo, quando chamamos os métodos Predict () & Evalu () ou na etapa de validação do fit ()), a rede funciona com learning_phase = 0 (modo de teste). Mesmo que não seja recomendado, o usuário também é capaz de alterar estaticamente o learning_phase para um valor específico, mas isso precisa acontecer antes que qualquer modelo ou tensor seja adicionado ao gráfico. Se learning_phase for definido estaticamente, Keras será bloqueado para qualquer modo que o usuário selecionou.

1.4 Como Keras implementou a normalização em lote ao longo do tempo?

Keras mudou o comportamento da normalização em lote várias vezes, mas a atualização significativa mais recente aconteceu no Keras 2.1.3. Antes da v2.1.3, quando a camada BN estava congelada (treinável = Falso), ela atualizava suas estatísticas de lote, algo que causava dores de cabeça épicas aos usuários.

Esta não era apenas uma política estranha, estava realmente errada. Imagine que existe uma camada BN entre as convoluções; se a camada estiver congelada, nenhuma alteração deve acontecer com ela. Se atualizarmos parcialmente seus pesos e as próximas camadas também estiverem congeladas, elas nunca terão a chance de se ajustar às atualizações das estatísticas do minilote, levando a um erro maior. Felizmente, a partir da versão 2.1.3, quando uma camada BN é congelada, ela não atualiza mais suas estatísticas. Mas isso é o suficiente? Não se você estiver usando o Transfer Learning.

Abaixo, descrevo exatamente qual é o problema e esboco a implementação técnica para resolvê-lo. Também forneço alguns exemplos para mostrar os efeitos na precisão do modelo antes e depois do remendo é aplicado.

2.1 Descrição técnica do problema

O problema com a implementação atual do Keras é que quando uma camada BN é congelada, ela continua a usar as estatísticas do minilote durante o treinamento. Eu acredito que uma abordagem melhor quando o BN está congelado é usar a média móvel e a variância que ele aprendeu durante o treinamento. Porque? Pelas mesmas razões pelas quais as estatísticas do minilote não devem ser atualizadas quando a camada está congelada: isso pode levar a resultados ruins porque as próximas camadas não são treinadas corretamente.

Suponha que você esteja construindo um modelo de visão computacional, mas não tenha dados suficientes, então você decide usar um dos CNNs pré-treinados de Keras e ajustá-lo. Infelizmente, ao fazer isso, você não obtém garantias de que a média e a variância de seu novo conjunto de dados dentro das camadas BN serão semelhantes às do conjunto de dados original. Lembre-se que no momento, durante o treinamento, sua rede sempre usará as estatísticas do minilote, esteja a camada BN congelada ou não; também durante a inferência, você usará as estatísticas aprendidas anteriormente das camadas BN congeladas. Como resultado, se você ajustar as camadas superiores, seus pesos serão ajustados para a média / variância do novo conjunto de dados. No entanto, durante a inferência, eles receberão dados que são escalados diferentemente porque a média / variância do original conjunto de dados será usado.
A camada de normalização em lote de Keras está quebrada PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.
Acima, apresento uma arquitetura simplista (e irreal) para fins de demonstração. Vamos supor que ajustamos o modelo da Convolução k + 1 até o topo da rede (lado direito) e mantemos congelados na parte inferior (lado esquerdo). Durante o treinamento, todas as camadas BN de 1 a k usarão a média / variância de seus dados de treinamento. Isso terá efeitos negativos nos ReLUs congelados se a média e a variância em cada BN não forem próximas às aprendidas durante o pré-treinamento. Também fará com que o resto da rede (de CONV k + 1 e posterior) seja treinado com entradas que têm escalas diferentes em comparação com o que receberá durante a inferência. Durante o treinamento, sua rede pode se adaptar a essas mudanças, no entanto, no momento em que você muda para o modo de previsão, Keras usará diferentes estatísticas de padronização, algo que agilizará a distribuição das entradas das próximas camadas levando a resultados ruins.

2.2 Como você pode detectar se é afetado?

Uma forma de detectá-lo é definir estaticamente a fase de aprendizagem do Keras para 1 (modo de treinamento) e para 0 (modo de teste) e avaliar seu modelo em cada caso. Se houver uma diferença significativa na precisão no mesmo conjunto de dados, você está sendo afetado pelo problema. Vale ressaltar que, devido à forma como o mecanismo learning_phase é implementado no Keras, normalmente não é aconselhável mexer com ele. As alterações em learning_phase não terão efeito nos modelos que já foram compilados e usados; como você pode ver nos exemplos das próximas subseções, a melhor maneira de fazer isso é começar com uma sessão limpa e alterar o learning_phase antes que qualquer tensor seja definido no gráfico.

Outra forma de detectar o problema ao trabalhar com classificadores binários é verificar a precisão e o AUC. Se a precisão for próxima de 50%, mas a AUC estiver próxima de 1 (e você também observar diferenças entre o modo de trem / teste no mesmo conjunto de dados), pode ser que as probabilidades estejam fora da escala devido às estatísticas BN. Da mesma forma, para a regressão, você pode usar o MSE e a correlação de Spearman para detectá-la.

2.3 Como podemos consertar?

Eu acredito que o problema pode ser corrigido se as camadas BN congeladas forem realmente apenas isso: permanentemente bloqueadas no modo de teste. Em termos de implementação, o sinalizador treinável precisa fazer parte do gráfico computacional e o comportamento do BN precisa depender não apenas da fase_de_prendizagem, mas também do valor da propriedade treinável. Você pode encontrar os detalhes da minha implementação em Github.

Ao aplicar a correção acima, quando uma camada BN for congelada, ela não usará mais as estatísticas do minilote, mas sim as aprendidas durante o treinamento. Como resultado, não haverá discrepância entre os modos de treinamento e teste, o que leva a uma maior precisão. Obviamente, quando a camada BN não estiver congelada, ela continuará usando as estatísticas do minilote durante o treinamento.

2.4 Avaliando os efeitos do patch

Embora eu tenha escrito a implementação acima recentemente, a ideia por trás dela foi amplamente testada em problemas do mundo real usando várias soluções alternativas que têm o mesmo efeito. Por exemplo, a discrepância entre os modos de treinamento e teste pode ser evitada dividindo a rede em duas partes (congelada e descongelada) e realizando o treinamento em cache (passando os dados pelo modelo congelado uma vez e, em seguida, usando-os para treinar a rede descongelada). No entanto, como o “acredite em mim, eu já fiz isso antes” normalmente não tem peso, a seguir forneço alguns exemplos que mostram os efeitos da nova implementação na prática.

Aqui estão alguns pontos importantes sobre o experimento:

  1. Vou usar uma pequena quantidade de dados para ajustar intencionalmente o modelo e vou treinar e validar o modelo no mesmo conjunto de dados. Ao fazer isso, espero precisão quase perfeita e desempenho idêntico no conjunto de dados de trem / validação.
  2. Se durante a validação obtiver uma precisão significativamente menor no mesmo conjunto de dados, terei uma indicação clara de que a política BN atual afeta negativamente o desempenho do modelo durante a inferência.
  3. Qualquer pré-processamento ocorrerá fora dos Geradores. Isso é feito para contornar um bug que foi introduzido na v2.1.5 (atualmente corrigido na próxima v2.1.6 e na versão master mais recente).
  4. Forçaremos Keras a usar diferentes fases de aprendizagem durante a avaliação. Se detectarmos diferenças entre a precisão relatada, saberemos que somos afetados pela política atual da BN.

O código do experimento é mostrado abaixo:

import numpy as np
from keras.datasets import cifar10
from scipy.misc import imresize

from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.models import Model, load_model
from keras.layers import Dense, Flatten
from keras import backend as K


seed = 42
epochs = 10
records_per_class = 100

# We take only 2 classes from CIFAR10 and a very small sample to intentionally overfit the model.
# We will also use the same data for train/test and expect that Keras will give the same accuracy.
(x, y), _ = cifar10.load_data()

def filter_resize(category):
   # We do the preprocessing here instead in the Generator to get around a bug on Keras 2.1.5.
   return [preprocess_input(imresize(img, (224,224)).astype('float')) for img in x[y.flatten()==category][:records_per_class]]

x = np.stack(filter_resize(3)+filter_resize(5))
records_per_class = x.shape[0] // 2
y = np.array([[1,0]]*records_per_class + [[0,1]]*records_per_class)


# We will use a pre-trained model and finetune the top layers.
np.random.seed(seed)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)

for layer in model.layers[:140]:
   layer.trainable = False

for layer in model.layers[140:]:
   layer.trainable = True

model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(ImageDataGenerator().flow(x, y, seed=42), epochs=epochs, validation_data=ImageDataGenerator().flow(x, y, seed=42))

# Store the model on disk
model.save('tmp.h5')


# In every test we will clear the session and reload the model to force Learning_Phase values to change.
print('DYNAMIC LEARNING_PHASE')
K.clear_session()
model = load_model('tmp.h5')
# This accuracy should match exactly the one of the validation set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))


print('STATIC LEARNING_PHASE = 0')
K.clear_session()
K.set_learning_phase(0)
model = load_model('tmp.h5')
# Again the accuracy should match the above.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))


print('STATIC LEARNING_PHASE = 1')
K.clear_session()
K.set_learning_phase(1)
model = load_model('tmp.h5')
# The accuracy will be close to the one of the training set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))

Vamos verificar os resultados no Keras v2.1.5:

Epoch 1/10
1/7 [===>..........................] - ETA: 25s - loss: 0.8751 - acc: 0.5312
2/7 [=======>......................] - ETA: 11s - loss: 0.8594 - acc: 0.4531
3/7 [===========>..................] - ETA: 7s - loss: 0.8398 - acc: 0.4688 
4/7 [================>.............] - ETA: 4s - loss: 0.8467 - acc: 0.4844
5/7 [====================>.........] - ETA: 2s - loss: 0.7904 - acc: 0.5437
6/7 [========================>.....] - ETA: 1s - loss: 0.7593 - acc: 0.5625
7/7 [==============================] - 12s 2s/step - loss: 0.7536 - acc: 0.5744 - val_loss: 0.6526 - val_acc: 0.6650

Epoch 2/10
1/7 [===>..........................] - ETA: 4s - loss: 0.3881 - acc: 0.8125
2/7 [=======>......................] - ETA: 3s - loss: 0.3945 - acc: 0.7812
3/7 [===========>..................] - ETA: 2s - loss: 0.3956 - acc: 0.8229
4/7 [================>.............] - ETA: 1s - loss: 0.4223 - acc: 0.8047
5/7 [====================>.........] - ETA: 1s - loss: 0.4483 - acc: 0.7812
6/7 [========================>.....] - ETA: 0s - loss: 0.4325 - acc: 0.7917
7/7 [==============================] - 8s 1s/step - loss: 0.4095 - acc: 0.8089 - val_loss: 0.4722 - val_acc: 0.7700

Epoch 3/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2246 - acc: 0.9375
2/7 [=======>......................] - ETA: 3s - loss: 0.2167 - acc: 0.9375
3/7 [===========>..................] - ETA: 2s - loss: 0.2260 - acc: 0.9479
4/7 [================>.............] - ETA: 2s - loss: 0.2179 - acc: 0.9375
5/7 [====================>.........] - ETA: 1s - loss: 0.2356 - acc: 0.9313
6/7 [========================>.....] - ETA: 0s - loss: 0.2392 - acc: 0.9427
7/7 [==============================] - 8s 1s/step - loss: 0.2288 - acc: 0.9456 - val_loss: 0.4282 - val_acc: 0.7800

Epoch 4/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2183 - acc: 0.9688
2/7 [=======>......................] - ETA: 3s - loss: 0.1899 - acc: 0.9844
3/7 [===========>..................] - ETA: 2s - loss: 0.1887 - acc: 0.9792
4/7 [================>.............] - ETA: 1s - loss: 0.1995 - acc: 0.9531
5/7 [====================>.........] - ETA: 1s - loss: 0.1932 - acc: 0.9625
6/7 [========================>.....] - ETA: 0s - loss: 0.1819 - acc: 0.9688
7/7 [==============================] - 8s 1s/step - loss: 0.1743 - acc: 0.9747 - val_loss: 0.3778 - val_acc: 0.8400

Epoch 5/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0973 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0828 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0851 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0897 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0928 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0936 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.1337 - acc: 0.9838 - val_loss: 0.3916 - val_acc: 0.8100

Epoch 6/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0747 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0852 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0812 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0831 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0779 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0766 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.0813 - acc: 1.0000 - val_loss: 0.3637 - val_acc: 0.8550

Epoch 7/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2478 - acc: 0.8750
2/7 [=======>......................] - ETA: 2s - loss: 0.1966 - acc: 0.9375
3/7 [===========>..................] - ETA: 2s - loss: 0.1528 - acc: 0.9583
4/7 [================>.............] - ETA: 1s - loss: 0.1300 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.1193 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.1196 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.1084 - acc: 0.9838 - val_loss: 0.3546 - val_acc: 0.8600

Epoch 8/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0539 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0900 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0815 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0740 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0700 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0701 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.0695 - acc: 1.0000 - val_loss: 0.3269 - val_acc: 0.8600

Epoch 9/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0306 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0377 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0898 - acc: 0.9583
4/7 [================>.............] - ETA: 1s - loss: 0.0773 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0742 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0708 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0659 - acc: 0.9838 - val_loss: 0.3604 - val_acc: 0.8600

Epoch 10/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0354 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0381 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0354 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0828 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0791 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0794 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0704 - acc: 0.9838 - val_loss: 0.3615 - val_acc: 0.8600

DYNAMIC LEARNING_PHASE
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 0
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 1
[0.025861846953630446, 1.0]

Como podemos ver acima, durante o treinamento, o modelo aprende muito bem os dados e atinge uma precisão quase perfeita no conjunto de treinamento. Ainda no final de cada iteração, ao avaliar o modelo no mesmo conjunto de dados, obtemos diferenças significativas de perda e precisão. Observe que não devemos receber isso; ajustamos intencionalmente o modelo no conjunto de dados específico e os conjuntos de dados de treinamento / validação são idênticos.

Após a conclusão do treinamento, avaliamos o modelo usando 3 configurações diferentes de learning_phase: Dynamic, Static = 0 (modo de teste) e Static = 1 (modo de treinamento). Como podemos ver, as duas primeiras configurações fornecerão resultados idênticos em termos de perda e precisão e seu valor corresponde à precisão relatada do modelo no conjunto de validação na última iteração. No entanto, uma vez que mudamos para o modo de treinamento, observamos uma grande discrepância (melhoria). Por que isso? Como dissemos anteriormente, os pesos da rede são ajustados esperando receber dados escalonados com a média / variância dos dados de treinamento. Infelizmente, essas estatísticas são diferentes das armazenadas nas camadas BN. Como as camadas BN foram congeladas, essas estatísticas nunca foram atualizadas. Essa discrepância entre os valores das estatísticas BN leva à deterioração da precisão durante a inferência.

Vamos ver o que acontece quando aplicarmos o remendo:

Epoch 1/10
1/7 [===>..........................] - ETA: 26s - loss: 0.9992 - acc: 0.4375
2/7 [=======>......................] - ETA: 12s - loss: 1.0534 - acc: 0.4375
3/7 [===========>..................] - ETA: 7s - loss: 1.0592 - acc: 0.4479 
4/7 [================>.............] - ETA: 4s - loss: 0.9618 - acc: 0.5000
5/7 [====================>.........] - ETA: 2s - loss: 0.8933 - acc: 0.5250
6/7 [========================>.....] - ETA: 1s - loss: 0.8638 - acc: 0.5417
7/7 [==============================] - 13s 2s/step - loss: 0.8357 - acc: 0.5570 - val_loss: 0.2414 - val_acc: 0.9450

Epoch 2/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2331 - acc: 0.9688
2/7 [=======>......................] - ETA: 2s - loss: 0.3308 - acc: 0.8594
3/7 [===========>..................] - ETA: 2s - loss: 0.3986 - acc: 0.8125
4/7 [================>.............] - ETA: 1s - loss: 0.3721 - acc: 0.8281
5/7 [====================>.........] - ETA: 1s - loss: 0.3449 - acc: 0.8438
6/7 [========================>.....] - ETA: 0s - loss: 0.3168 - acc: 0.8646
7/7 [==============================] - 9s 1s/step - loss: 0.3165 - acc: 0.8633 - val_loss: 0.1167 - val_acc: 0.9950

Epoch 3/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2457 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.2592 - acc: 0.9688
3/7 [===========>..................] - ETA: 2s - loss: 0.2173 - acc: 0.9688
4/7 [================>.............] - ETA: 1s - loss: 0.2122 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.2003 - acc: 0.9688
6/7 [========================>.....] - ETA: 0s - loss: 0.1896 - acc: 0.9740
7/7 [==============================] - 9s 1s/step - loss: 0.1835 - acc: 0.9773 - val_loss: 0.0678 - val_acc: 1.0000

Epoch 4/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2051 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.1652 - acc: 0.9844
3/7 [===========>..................] - ETA: 2s - loss: 0.1423 - acc: 0.9896
4/7 [================>.............] - ETA: 1s - loss: 0.1289 - acc: 0.9922
5/7 [====================>.........] - ETA: 1s - loss: 0.1225 - acc: 0.9938
6/7 [========================>.....] - ETA: 0s - loss: 0.1149 - acc: 0.9948
7/7 [==============================] - 9s 1s/step - loss: 0.1060 - acc: 0.9955 - val_loss: 0.0455 - val_acc: 1.0000

Epoch 5/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0769 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0846 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0797 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0736 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0914 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0858 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0808 - acc: 1.0000 - val_loss: 0.0346 - val_acc: 1.0000

Epoch 6/10
1/7 [===>..........................] - ETA: 1s - loss: 0.1267 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.1039 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0893 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0780 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0758 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0789 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0738 - acc: 1.0000 - val_loss: 0.0248 - val_acc: 1.0000

Epoch 7/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0344 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0385 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0467 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0445 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0446 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0429 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0421 - acc: 1.0000 - val_loss: 0.0202 - val_acc: 1.0000

Epoch 8/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0319 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0300 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0320 - acc: 1.0000
4/7 [================>.............] - ETA: 2s - loss: 0.0307 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0303 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0291 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0358 - acc: 1.0000 - val_loss: 0.0167 - val_acc: 1.0000

Epoch 9/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0246 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0255 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0258 - acc: 1.0000
4/7 [================>.............] - ETA: 2s - loss: 0.0250 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0252 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0260 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0327 - acc: 1.0000 - val_loss: 0.0143 - val_acc: 1.0000

Epoch 10/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0251 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0228 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0217 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0249 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0244 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0239 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0290 - acc: 1.0000 - val_loss: 0.0127 - val_acc: 1.0000

DYNAMIC LEARNING_PHASE
[0.012697912137955427, 1.0]

STATIC LEARNING_PHASE = 0
[0.012697912137955427, 1.0]

STATIC LEARNING_PHASE = 1
[0.01744014158844948, 1.0]

Em primeiro lugar, observamos que a rede converge significativamente mais rápido e atinge uma precisão perfeita. Também vemos que não há mais discrepância em termos de precisão quando alternamos entre os diferentes valores de learning_phase.

2.5 Como o patch funciona em um conjunto de dados real?

Então, como o patch funciona em um experimento mais realista? Vamos usar o ResNet50 pré-treinado de Keras (originalmente ajustado ao imagenet), remover a camada de classificação superior e ajustá-la com e sem o patch e comparar os resultados. Para os dados, usaremos CIFAR10 (a divisão de teste / trem padrão fornecida por Keras) e redimensionaremos as imagens para 224 × 224 para torná-las compatíveis com o tamanho de entrada do ResNet50.

Faremos 10 épocas para treinar a camada de classificação superior usando RSMprop e depois faremos mais 5 para ajustar tudo após a camada 139 usando SGD (lr = 1e-4, momentum = 0.9). Sem o patch, nosso modelo atinge uma precisão de 87.44%. Usando o patch, obtemos uma precisão de 92.36%, quase 5 pontos a mais.

2.6 Devemos aplicar a mesma correção a outras camadas, como Dropout?

A normalização em lote não é a única camada que opera de maneira diferente entre os modos de trem e teste. O abandono e suas variantes também têm o mesmo efeito. Devemos aplicar a mesma política a todas essas camadas? Eu não acredito (embora eu adorasse ouvir seus pensamentos sobre isso). O motivo é que Dropout é usado para evitar overfitting, portanto, bloqueá-lo permanentemente no modo de previsão durante o treinamento anularia seu propósito. O que você acha?

Acredito fortemente que essa discrepância deve ser resolvida em Keras. Já vi efeitos ainda mais profundos (de 100% até 50% de precisão) em aplicativos do mundo real causados ​​por esse problema. eu plano para enviar já enviou um PR para Keras com a correção e esperançosamente será aceita.

Se você gostou desta postagem do blog, reserve um momento para compartilhá-la no Facebook ou Twitter. 🙂

Carimbo de hora:

Mais de Caixa de dados