Lớp Chuẩn hóa hàng loạt của Keras bị hỏng Thông minh dữ liệu PlatoBlockchain. Tìm kiếm dọc. Ái.

Lớp Batch Chuẩn hóa của Keras bị hỏng

CẬP NHẬT: Thật không may, Yêu cầu kéo tới Keras của tôi đã thay đổi hành vi của lớp Chuẩn hóa hàng loạt không được chấp nhận. Bạn có thể đọc chi tiết tại đây. Đối với những người đủ can đảm để gặp rắc rối với việc triển khai tùy chỉnh, bạn có thể tìm thấy mã trong chi nhánh của tôi. Tôi có thể duy trì nó và hợp nhất nó với phiên bản ổn định mới nhất của Keras (2.1.6, 2.2.22.2.4) miễn là tôi sử dụng nó nhưng không có lời hứa nào.

Hầu hết những người làm việc trong lĩnh vực Deep Learning đều đã sử dụng hoặc nghe nói đến Máy ảnh. Đối với những người chưa biết, đây là một thư viện tuyệt vời tóm tắt các khung Deep Learning cơ bản như TensorFlow, Theano và CNTK, đồng thời cung cấp một API cấp cao để huấn luyện ANN. Nó rất dễ sử dụng, cho phép tạo mẫu nhanh và có một cộng đồng năng động thân thiện. Tôi đã sử dụng nó rất nhiều và đóng góp định kỳ cho dự án trong một thời gian khá dài và tôi chắc chắn sẽ giới thiệu nó cho bất kỳ ai muốn làm việc về Deep Learning.

Mặc dù Keras đã làm cho cuộc sống của tôi dễ dàng hơn nhưng khá nhiều lần tôi vẫn bị ảnh hưởng bởi hành vi kỳ quặc của lớp Chuẩn hóa hàng loạt. Hành vi mặc định của nó đã thay đổi theo thời gian, tuy nhiên nó vẫn gây ra sự cố cho nhiều người dùng và kết quả là có một số lỗi liên quan. vấn đề mở trên Github. Trong bài đăng trên blog này, tôi sẽ cố gắng giải thích lý do tại sao lớp BatchNormalization của Keras không hoạt động tốt với Transfer Learning, tôi sẽ cung cấp mã để khắc phục sự cố và tôi sẽ đưa ra ví dụ về kết quả của bản vá.

Trong các phần phụ bên dưới, tôi giới thiệu về cách sử dụng Transfer Learning trong Deep Learning, lớp Chuẩn hóa hàng loạt là gì, cách learning_phase hoạt động và cách Keras thay đổi hành vi BN theo thời gian. Nếu bạn đã biết những điều này, bạn có thể chuyển thẳng sang phần 2 một cách an toàn.

1.1 Sử dụng Transfer Learning là rất quan trọng đối với Deep Learning

Một trong những lý do khiến Deep Learning bị chỉ trích trước đây là vì nó đòi hỏi quá nhiều dữ liệu. Điêu nay không phải luc nao cung đung; có một số kỹ thuật để giải quyết hạn chế này, một trong số đó là Học chuyển giao.

Giả sử rằng bạn đang làm việc trên một ứng dụng Thị giác Máy tính và bạn muốn xây dựng một bộ phân loại để phân biệt Mèo với Chó. Bạn thực sự không cần hàng triệu hình ảnh mèo/chó để huấn luyện mô hình. Thay vào đó, bạn có thể sử dụng trình phân loại được đào tạo trước và tinh chỉnh các phần tích chập trên cùng với ít dữ liệu hơn. Ý tưởng đằng sau nó là vì mô hình được đào tạo trước phù hợp với hình ảnh nên các phần tích chập phía dưới có thể nhận ra các đặc điểm như đường, cạnh và các mẫu hữu ích khác, nghĩa là bạn có thể sử dụng trọng số của nó làm giá trị khởi tạo tốt hoặc đào tạo lại một phần mạng bằng dữ liệu của mình .
Lớp Chuẩn hóa hàng loạt của Keras bị hỏng Thông minh dữ liệu PlatoBlockchain. Tìm kiếm dọc. Ái.
Keras đi kèm với một số mô hình được đào tạo trước và các ví dụ dễ sử dụng về cách tinh chỉnh mô hình. Bạn có thể đọc thêm trên tài liệu hướng dẫn.

1.2 Lớp chuẩn hóa hàng loạt là gì?

Lớp Chuẩn hóa hàng loạt được Ioffe và Szegedy giới thiệu vào năm 2014. Nó giải quyết vấn đề độ dốc biến mất bằng cách chuẩn hóa đầu ra của lớp trước, nó tăng tốc quá trình đào tạo bằng cách giảm số lần lặp cần thiết và cho phép đào tạo các mạng lưới thần kinh sâu hơn. Việc giải thích chính xác cách thức hoạt động của nó nằm ngoài phạm vi của bài viết này nhưng tôi thực sự khuyến khích bạn đọc phần này. bản gốc. Một lời giải thích đơn giản hóa quá mức là nó định lại tỷ lệ đầu vào bằng cách trừ giá trị trung bình của nó và bằng cách chia cho độ lệch chuẩn của nó; nó cũng có thể học cách hoàn tác việc chuyển đổi nếu cần thiết.
Lớp Chuẩn hóa hàng loạt của Keras bị hỏng Thông minh dữ liệu PlatoBlockchain. Tìm kiếm dọc. Ái.

1.3 Learning_phase trong Keras là gì?

Một số lớp hoạt động khác nhau trong chế độ huấn luyện và suy luận. Các ví dụ đáng chú ý nhất là các lớp Chuẩn hóa hàng loạt và Lớp bỏ học. Trong trường hợp BN, trong quá trình đào tạo, chúng tôi sử dụng giá trị trung bình và phương sai của lô nhỏ để thay đổi tỷ lệ đầu vào. Mặt khác, trong quá trình suy luận, chúng tôi sử dụng đường trung bình động và phương sai được ước tính trong quá trình huấn luyện.

Keras biết nên chạy ở chế độ nào vì nó có cơ chế tích hợp gọi là giai đoạn học tập. Giai đoạn học kiểm soát xem mạng đang ở chế độ huấn luyện hay thử nghiệm. Nếu người dùng không đặt thủ công thì trong fit() mạng sẽ chạy với learning_phase=1 (chế độ đào tạo). Trong khi đưa ra dự đoán (ví dụ: khi chúng ta gọi các phương thức dự đoán() & đánh giá() hoặc ở bước xác thực của fit()), mạng sẽ chạy với learning_phase=0 (chế độ kiểm tra). Mặc dù không được khuyến nghị nhưng người dùng cũng có thể thay đổi tĩnh learning_phase thành một giá trị cụ thể nhưng điều này cần phải xảy ra trước khi bất kỳ mô hình hoặc tensor nào được thêm vào biểu đồ. Nếu learning_phase được đặt tĩnh, Keras sẽ bị khóa ở bất kỳ chế độ nào mà người dùng đã chọn.

1.4 Keras đã triển khai Chuẩn hóa hàng loạt theo thời gian như thế nào?

Keras đã nhiều lần thay đổi hành vi Chuẩn hóa hàng loạt nhưng bản cập nhật quan trọng gần đây nhất xảy ra trong Keras 2.1.3. Trước phiên bản 2.1.3, khi lớp BN bị đóng băng (có thể huấn luyện = Sai), nó liên tục cập nhật số liệu thống kê hàng loạt, điều này khiến người dùng phải đau đầu.

Đây không chỉ là một chính sách kỳ lạ mà nó còn thực sự sai lầm. Hãy tưởng tượng rằng một lớp BN tồn tại giữa các tổ hợp; nếu lớp bị đóng băng thì sẽ không có thay đổi nào xảy ra với nó. Nếu chúng tôi cập nhật một phần trọng số của nó và các lớp tiếp theo cũng bị đóng băng, chúng sẽ không bao giờ có cơ hội điều chỉnh các bản cập nhật của số liệu thống kê lô nhỏ dẫn đến sai số cao hơn. Rất may, bắt đầu từ phiên bản 2.1.3, khi một lớp BN bị đóng băng, nó không còn cập nhật số liệu thống kê nữa. Nhưng thế đã đủ chưa? Không nếu bạn đang sử dụng Transfer Learning.

Dưới đây tôi mô tả chính xác vấn đề là gì và tôi phác thảo cách triển khai kỹ thuật để giải quyết vấn đề đó. Tôi cũng cung cấp một vài ví dụ để cho thấy những ảnh hưởng đến độ chính xác của mô hình trước và sau bản vá được áp dụng.

2.1 Mô tả kỹ thuật của vấn đề

Vấn đề với việc triển khai Keras hiện tại là khi lớp BN bị đóng băng, nó vẫn tiếp tục sử dụng số liệu thống kê lô nhỏ trong quá trình đào tạo. Tôi tin rằng cách tiếp cận tốt hơn khi BN bị cố định là sử dụng giá trị trung bình động và phương sai mà nó đã học được trong quá trình huấn luyện. Tại sao? Vì những lý do tương tự tại sao không nên cập nhật số liệu thống kê theo lô nhỏ khi lớp bị đóng băng: nó có thể dẫn đến kết quả kém do các lớp tiếp theo không được đào tạo đúng cách.

Giả sử bạn đang xây dựng mô hình Thị giác máy tính nhưng không có đủ dữ liệu, vì vậy bạn quyết định sử dụng một trong các CNN được đào tạo trước của Keras và tinh chỉnh nó. Thật không may, khi làm như vậy, bạn không có gì đảm bảo rằng giá trị trung bình và phương sai của tập dữ liệu mới bên trong các lớp BN sẽ giống với giá trị trung bình và phương sai của tập dữ liệu gốc. Hãy nhớ rằng hiện tại, trong quá trình đào tạo, mạng của bạn sẽ luôn sử dụng số liệu thống kê theo lô nhỏ cho dù lớp BN có bị đóng băng hay không; Ngoài ra, trong quá trình suy luận, bạn sẽ sử dụng số liệu thống kê đã học trước đó về các lớp BN cố định. Kết quả là, nếu bạn tinh chỉnh các lớp trên cùng, trọng số của chúng sẽ được điều chỉnh theo giá trị trung bình/phương sai của mới tập dữ liệu. Tuy nhiên, trong quá trình suy luận, họ sẽ nhận được dữ liệu được chia tỷ lệ khác nhau bởi vì giá trị trung bình/phương sai của nguyên tập dữ liệu sẽ được sử dụng.
Lớp Chuẩn hóa hàng loạt của Keras bị hỏng Thông minh dữ liệu PlatoBlockchain. Tìm kiếm dọc. Ái.
Ở trên tôi cung cấp một kiến ​​trúc đơn giản (và không thực tế) cho mục đích trình diễn. Giả sử rằng chúng tôi tinh chỉnh mô hình từ Convolution k+1 trở lên cho đến phần trên cùng của mạng (phía bên phải) và chúng tôi tiếp tục cố định phần dưới cùng (phía bên trái). Trong quá trình huấn luyện, tất cả các lớp BN từ 1 đến k sẽ sử dụng giá trị trung bình/phương sai của dữ liệu huấn luyện của bạn. Điều này sẽ có tác động tiêu cực đến các ReLU bị đóng băng nếu giá trị trung bình và phương sai trên mỗi BN không gần với giá trị đã học trong quá trình đào tạo trước. Nó cũng sẽ khiến phần còn lại của mạng (từ CONV k+1 trở lên) được đào tạo với các đầu vào có tỷ lệ khác nhau so với những gì sẽ nhận được trong quá trình suy luận. Trong quá trình đào tạo, mạng của bạn có thể thích ứng với những thay đổi này, tuy nhiên, khi bạn chuyển sang chế độ dự đoán, Keras sẽ sử dụng các số liệu thống kê tiêu chuẩn hóa khác nhau, điều gì đó sẽ nhanh chóng phân phối đầu vào của các lớp tiếp theo dẫn đến kết quả kém.

2.2 Làm thế nào bạn có thể phát hiện nếu bạn bị ảnh hưởng?

Một cách để phát hiện nó là đặt tĩnh giai đoạn học tập của Keras thành 1 (chế độ đào tạo) và 0 (chế độ kiểm tra) và đánh giá mô hình của bạn trong từng trường hợp. Nếu có sự khác biệt đáng kể về độ chính xác trên cùng một tập dữ liệu thì bạn đang bị ảnh hưởng bởi sự cố. Cần chỉ ra rằng, do cách triển khai cơ chế learning_phase trong Keras, bạn thường không nên làm phiền nó. Những thay đổi trên learning_phase sẽ không ảnh hưởng đến các mô hình đã được biên dịch và sử dụng; như bạn có thể thấy trong các ví dụ ở các phần phụ tiếp theo, cách tốt nhất để thực hiện việc này là bắt đầu với một phiên sạch và thay đổi learning_phase trước khi bất kỳ tenxơ nào được xác định trong biểu đồ.

Một cách khác để phát hiện sự cố khi làm việc với bộ phân loại nhị phân là kiểm tra độ chính xác và AUC. Nếu độ chính xác gần 50% nhưng AUC gần bằng 1 (và bạn cũng quan sát thấy sự khác biệt giữa chế độ huấn luyện/kiểm tra trên cùng một tập dữ liệu), thì có thể xác suất nằm ngoài thang đo do thống kê BN. Tương tự, để hồi quy, bạn có thể sử dụng mối tương quan của MSE và Spearman để phát hiện nó.

2.3 Chúng ta có thể khắc phục nó như thế nào?

Tôi tin rằng vấn đề có thể được khắc phục nếu các lớp BN bị đóng băng thực sự chỉ như vậy: bị khóa vĩnh viễn ở chế độ thử nghiệm. Về mặt triển khai, cờ có thể huấn luyện cần phải là một phần của biểu đồ tính toán và hoạt động của BN không chỉ phụ thuộc vào learning_phase mà còn phụ thuộc vào giá trị của thuộc tính có thể huấn luyện. Bạn có thể tìm thấy chi tiết việc triển khai của tôi trên Github.

Bằng cách áp dụng cách khắc phục ở trên, khi lớp BN bị đóng băng, nó sẽ không còn sử dụng số liệu thống kê theo lô nhỏ nữa mà thay vào đó sử dụng số liệu thống kê đã học được trong quá trình đào tạo. Do đó, sẽ không có sự khác biệt giữa chế độ luyện tập và kiểm tra, dẫn đến độ chính xác tăng lên. Rõ ràng khi lớp BN không bị đóng băng, nó sẽ tiếp tục sử dụng số liệu thống kê lô nhỏ trong quá trình đào tạo.

2.4 Đánh giá tác dụng của bản vá

Mặc dù gần đây tôi đã viết cách triển khai ở trên nhưng ý tưởng đằng sau nó đã được thử nghiệm rất nhiều trên các vấn đề trong thế giới thực bằng cách sử dụng nhiều cách giải quyết khác nhau có cùng tác dụng. Ví dụ: có thể tránh được sự khác biệt giữa chế độ huấn luyện và thử nghiệm bằng cách chia mạng thành hai phần (đóng băng và không đóng băng) và thực hiện đào tạo trong bộ nhớ đệm (chuyển dữ liệu qua mô hình đã được cố định một lần rồi sử dụng chúng để huấn luyện mạng không bị đóng băng). Tuy nhiên, vì câu nói “hãy tin tôi, tôi đã từng làm điều này trước đây” thường không có trọng lượng nên dưới đây tôi sẽ cung cấp một số ví dụ cho thấy tác động của việc triển khai mới trong thực tế.

Dưới đây là một số điểm quan trọng về thí nghiệm:

  1. Tôi sẽ sử dụng một lượng nhỏ dữ liệu để cố tình điều chỉnh quá mức mô hình và tôi sẽ đào tạo & xác thực mô hình trên cùng một tập dữ liệu. Bằng cách đó, tôi mong đợi độ chính xác gần như hoàn hảo và hiệu suất giống hệt nhau trên tập dữ liệu đào tạo/xác thực.
  2. Nếu trong quá trình xác thực, tôi nhận được độ chính xác thấp hơn đáng kể trên cùng một tập dữ liệu, thì tôi sẽ có dấu hiệu rõ ràng rằng chính sách BN hiện tại ảnh hưởng tiêu cực đến hiệu suất của mô hình trong quá trình suy luận.
  3. Mọi quá trình tiền xử lý sẽ diễn ra bên ngoài Generators. Điều này được thực hiện để khắc phục lỗi xuất hiện trong v2.1.5 (hiện đã được sửa trên phiên bản v2.1.6 sắp tới và bản chính mới nhất).
  4. Chúng tôi sẽ buộc Keras sử dụng các giai đoạn học tập khác nhau trong quá trình đánh giá. Nếu chúng tôi phát hiện ra sự khác biệt giữa độ chính xác được báo cáo, chúng tôi sẽ biết rằng chúng tôi bị ảnh hưởng bởi chính sách BN hiện tại.

Mã cho thử nghiệm được hiển thị dưới đây:

import numpy as np
from keras.datasets import cifar10
from scipy.misc import imresize

from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.models import Model, load_model
from keras.layers import Dense, Flatten
from keras import backend as K


seed = 42
epochs = 10
records_per_class = 100

# We take only 2 classes from CIFAR10 and a very small sample to intentionally overfit the model.
# We will also use the same data for train/test and expect that Keras will give the same accuracy.
(x, y), _ = cifar10.load_data()

def filter_resize(category):
   # We do the preprocessing here instead in the Generator to get around a bug on Keras 2.1.5.
   return [preprocess_input(imresize(img, (224,224)).astype('float')) for img in x[y.flatten()==category][:records_per_class]]

x = np.stack(filter_resize(3)+filter_resize(5))
records_per_class = x.shape[0] // 2
y = np.array([[1,0]]*records_per_class + [[0,1]]*records_per_class)


# We will use a pre-trained model and finetune the top layers.
np.random.seed(seed)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)

for layer in model.layers[:140]:
   layer.trainable = False

for layer in model.layers[140:]:
   layer.trainable = True

model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(ImageDataGenerator().flow(x, y, seed=42), epochs=epochs, validation_data=ImageDataGenerator().flow(x, y, seed=42))

# Store the model on disk
model.save('tmp.h5')


# In every test we will clear the session and reload the model to force Learning_Phase values to change.
print('DYNAMIC LEARNING_PHASE')
K.clear_session()
model = load_model('tmp.h5')
# This accuracy should match exactly the one of the validation set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))


print('STATIC LEARNING_PHASE = 0')
K.clear_session()
K.set_learning_phase(0)
model = load_model('tmp.h5')
# Again the accuracy should match the above.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))


print('STATIC LEARNING_PHASE = 1')
K.clear_session()
K.set_learning_phase(1)
model = load_model('tmp.h5')
# The accuracy will be close to the one of the training set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))

Hãy kiểm tra kết quả trên Keras v2.1.5:

Epoch 1/10
1/7 [===>..........................] - ETA: 25s - loss: 0.8751 - acc: 0.5312
2/7 [=======>......................] - ETA: 11s - loss: 0.8594 - acc: 0.4531
3/7 [===========>..................] - ETA: 7s - loss: 0.8398 - acc: 0.4688 
4/7 [================>.............] - ETA: 4s - loss: 0.8467 - acc: 0.4844
5/7 [====================>.........] - ETA: 2s - loss: 0.7904 - acc: 0.5437
6/7 [========================>.....] - ETA: 1s - loss: 0.7593 - acc: 0.5625
7/7 [==============================] - 12s 2s/step - loss: 0.7536 - acc: 0.5744 - val_loss: 0.6526 - val_acc: 0.6650

Epoch 2/10
1/7 [===>..........................] - ETA: 4s - loss: 0.3881 - acc: 0.8125
2/7 [=======>......................] - ETA: 3s - loss: 0.3945 - acc: 0.7812
3/7 [===========>..................] - ETA: 2s - loss: 0.3956 - acc: 0.8229
4/7 [================>.............] - ETA: 1s - loss: 0.4223 - acc: 0.8047
5/7 [====================>.........] - ETA: 1s - loss: 0.4483 - acc: 0.7812
6/7 [========================>.....] - ETA: 0s - loss: 0.4325 - acc: 0.7917
7/7 [==============================] - 8s 1s/step - loss: 0.4095 - acc: 0.8089 - val_loss: 0.4722 - val_acc: 0.7700

Epoch 3/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2246 - acc: 0.9375
2/7 [=======>......................] - ETA: 3s - loss: 0.2167 - acc: 0.9375
3/7 [===========>..................] - ETA: 2s - loss: 0.2260 - acc: 0.9479
4/7 [================>.............] - ETA: 2s - loss: 0.2179 - acc: 0.9375
5/7 [====================>.........] - ETA: 1s - loss: 0.2356 - acc: 0.9313
6/7 [========================>.....] - ETA: 0s - loss: 0.2392 - acc: 0.9427
7/7 [==============================] - 8s 1s/step - loss: 0.2288 - acc: 0.9456 - val_loss: 0.4282 - val_acc: 0.7800

Epoch 4/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2183 - acc: 0.9688
2/7 [=======>......................] - ETA: 3s - loss: 0.1899 - acc: 0.9844
3/7 [===========>..................] - ETA: 2s - loss: 0.1887 - acc: 0.9792
4/7 [================>.............] - ETA: 1s - loss: 0.1995 - acc: 0.9531
5/7 [====================>.........] - ETA: 1s - loss: 0.1932 - acc: 0.9625
6/7 [========================>.....] - ETA: 0s - loss: 0.1819 - acc: 0.9688
7/7 [==============================] - 8s 1s/step - loss: 0.1743 - acc: 0.9747 - val_loss: 0.3778 - val_acc: 0.8400

Epoch 5/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0973 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0828 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0851 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0897 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0928 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0936 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.1337 - acc: 0.9838 - val_loss: 0.3916 - val_acc: 0.8100

Epoch 6/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0747 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0852 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0812 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0831 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0779 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0766 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.0813 - acc: 1.0000 - val_loss: 0.3637 - val_acc: 0.8550

Epoch 7/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2478 - acc: 0.8750
2/7 [=======>......................] - ETA: 2s - loss: 0.1966 - acc: 0.9375
3/7 [===========>..................] - ETA: 2s - loss: 0.1528 - acc: 0.9583
4/7 [================>.............] - ETA: 1s - loss: 0.1300 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.1193 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.1196 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.1084 - acc: 0.9838 - val_loss: 0.3546 - val_acc: 0.8600

Epoch 8/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0539 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0900 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0815 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0740 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0700 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0701 - acc: 1.0000
7/7 [==============================] - 8s 1s/step - loss: 0.0695 - acc: 1.0000 - val_loss: 0.3269 - val_acc: 0.8600

Epoch 9/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0306 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0377 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0898 - acc: 0.9583
4/7 [================>.............] - ETA: 1s - loss: 0.0773 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0742 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0708 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0659 - acc: 0.9838 - val_loss: 0.3604 - val_acc: 0.8600

Epoch 10/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0354 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0381 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0354 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0828 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0791 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0794 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0704 - acc: 0.9838 - val_loss: 0.3615 - val_acc: 0.8600

DYNAMIC LEARNING_PHASE
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 0
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 1
[0.025861846953630446, 1.0]

Như chúng ta có thể thấy ở trên, trong quá trình đào tạo, mô hình học rất tốt dữ liệu và đạt được độ chính xác gần như hoàn hảo trên tập huấn luyện. Vẫn ở cuối mỗi lần lặp, khi đánh giá mô hình trên cùng một tập dữ liệu, chúng tôi nhận được sự khác biệt đáng kể về độ mất mát và độ chính xác. Lưu ý rằng chúng ta sẽ không nhận được điều này; chúng tôi đã cố ý trang bị quá mức mô hình trên tập dữ liệu cụ thể và các tập dữ liệu huấn luyện/xác thực giống hệt nhau.

Sau khi quá trình đào tạo hoàn tất, chúng tôi đánh giá mô hình bằng 3 cấu hình learning_phase khác nhau: Động, Tĩnh = 0 (chế độ kiểm tra) và Tĩnh = 1 (chế độ huấn luyện). Như chúng ta có thể thấy, hai cấu hình đầu tiên sẽ cung cấp kết quả giống hệt nhau về độ mất mát và độ chính xác, đồng thời giá trị của chúng khớp với độ chính xác được báo cáo của mô hình trên bộ xác thực ở lần lặp cuối cùng. Tuy nhiên, khi chuyển sang chế độ luyện tập, chúng tôi nhận thấy có sự khác biệt lớn (cải thiện). Tại sao vậy? Như chúng tôi đã nói trước đó, trọng số của mạng được điều chỉnh với mong muốn nhận được dữ liệu được chia tỷ lệ theo giá trị trung bình/phương sai của dữ liệu huấn luyện. Thật không may, những số liệu thống kê đó khác với số liệu được lưu trữ trong các lớp BN. Vì các lớp BN bị đóng băng nên những số liệu thống kê này không bao giờ được cập nhật. Sự khác biệt giữa các giá trị của số liệu thống kê BN dẫn đến sự suy giảm độ chính xác trong quá trình suy luận.

Hãy xem điều gì xảy ra khi chúng ta áp dụng bản vá:

Epoch 1/10
1/7 [===>..........................] - ETA: 26s - loss: 0.9992 - acc: 0.4375
2/7 [=======>......................] - ETA: 12s - loss: 1.0534 - acc: 0.4375
3/7 [===========>..................] - ETA: 7s - loss: 1.0592 - acc: 0.4479 
4/7 [================>.............] - ETA: 4s - loss: 0.9618 - acc: 0.5000
5/7 [====================>.........] - ETA: 2s - loss: 0.8933 - acc: 0.5250
6/7 [========================>.....] - ETA: 1s - loss: 0.8638 - acc: 0.5417
7/7 [==============================] - 13s 2s/step - loss: 0.8357 - acc: 0.5570 - val_loss: 0.2414 - val_acc: 0.9450

Epoch 2/10
1/7 [===>..........................] - ETA: 4s - loss: 0.2331 - acc: 0.9688
2/7 [=======>......................] - ETA: 2s - loss: 0.3308 - acc: 0.8594
3/7 [===========>..................] - ETA: 2s - loss: 0.3986 - acc: 0.8125
4/7 [================>.............] - ETA: 1s - loss: 0.3721 - acc: 0.8281
5/7 [====================>.........] - ETA: 1s - loss: 0.3449 - acc: 0.8438
6/7 [========================>.....] - ETA: 0s - loss: 0.3168 - acc: 0.8646
7/7 [==============================] - 9s 1s/step - loss: 0.3165 - acc: 0.8633 - val_loss: 0.1167 - val_acc: 0.9950

Epoch 3/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2457 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.2592 - acc: 0.9688
3/7 [===========>..................] - ETA: 2s - loss: 0.2173 - acc: 0.9688
4/7 [================>.............] - ETA: 1s - loss: 0.2122 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.2003 - acc: 0.9688
6/7 [========================>.....] - ETA: 0s - loss: 0.1896 - acc: 0.9740
7/7 [==============================] - 9s 1s/step - loss: 0.1835 - acc: 0.9773 - val_loss: 0.0678 - val_acc: 1.0000

Epoch 4/10
1/7 [===>..........................] - ETA: 1s - loss: 0.2051 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.1652 - acc: 0.9844
3/7 [===========>..................] - ETA: 2s - loss: 0.1423 - acc: 0.9896
4/7 [================>.............] - ETA: 1s - loss: 0.1289 - acc: 0.9922
5/7 [====================>.........] - ETA: 1s - loss: 0.1225 - acc: 0.9938
6/7 [========================>.....] - ETA: 0s - loss: 0.1149 - acc: 0.9948
7/7 [==============================] - 9s 1s/step - loss: 0.1060 - acc: 0.9955 - val_loss: 0.0455 - val_acc: 1.0000

Epoch 5/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0769 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0846 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0797 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0736 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0914 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0858 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0808 - acc: 1.0000 - val_loss: 0.0346 - val_acc: 1.0000

Epoch 6/10
1/7 [===>..........................] - ETA: 1s - loss: 0.1267 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.1039 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0893 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0780 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0758 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0789 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0738 - acc: 1.0000 - val_loss: 0.0248 - val_acc: 1.0000

Epoch 7/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0344 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0385 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0467 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0445 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0446 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0429 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0421 - acc: 1.0000 - val_loss: 0.0202 - val_acc: 1.0000

Epoch 8/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0319 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0300 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0320 - acc: 1.0000
4/7 [================>.............] - ETA: 2s - loss: 0.0307 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0303 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0291 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0358 - acc: 1.0000 - val_loss: 0.0167 - val_acc: 1.0000

Epoch 9/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0246 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0255 - acc: 1.0000
3/7 [===========>..................] - ETA: 3s - loss: 0.0258 - acc: 1.0000
4/7 [================>.............] - ETA: 2s - loss: 0.0250 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0252 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0260 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0327 - acc: 1.0000 - val_loss: 0.0143 - val_acc: 1.0000

Epoch 10/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0251 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0228 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0217 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0249 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0244 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0239 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0290 - acc: 1.0000 - val_loss: 0.0127 - val_acc: 1.0000

DYNAMIC LEARNING_PHASE
[0.012697912137955427, 1.0]

STATIC LEARNING_PHASE = 0
[0.012697912137955427, 1.0]

STATIC LEARNING_PHASE = 1
[0.01744014158844948, 1.0]

Trước hết, chúng tôi quan sát thấy mạng hội tụ nhanh hơn đáng kể và đạt được độ chính xác hoàn hảo. Chúng tôi cũng thấy rằng không còn sự khác biệt về độ chính xác khi chúng tôi chuyển đổi giữa các giá trị learning_phase khác nhau.

2.5 Bản vá hoạt động như thế nào trên tập dữ liệu thực?

Vậy bản vá hoạt động như thế nào trong một thử nghiệm thực tế hơn? Hãy sử dụng ResNet50 được đào tạo trước của Keras (ban đầu phù hợp với imagenet), loại bỏ lớp phân loại trên cùng và tinh chỉnh nó khi có và không có bản vá rồi so sánh kết quả. Đối với dữ liệu, chúng tôi sẽ sử dụng CIFAR10 (phần phân chia đào tạo/kiểm tra tiêu chuẩn do Keras cung cấp) và chúng tôi sẽ đổi kích thước hình ảnh thành 224×224 để chúng tương thích với kích thước đầu vào của ResNet50.

Chúng tôi sẽ thực hiện 10 kỷ nguyên để huấn luyện lớp phân loại hàng đầu bằng RSMprop và sau đó chúng tôi sẽ thực hiện thêm 5 kỷ nguyên nữa để tinh chỉnh mọi thứ sau lớp thứ 139 bằng cách sử dụng SGD(lr=1e-4, đà=0.9). Nếu không có bản vá, mô hình của chúng tôi đạt được độ chính xác 87.44%. Sử dụng bản vá, chúng tôi đạt được độ chính xác 92.36%, cao hơn gần 5 điểm.

2.6 Chúng ta có nên áp dụng cách sửa lỗi tương tự cho các lớp khác như Dropout không?

Chuẩn hóa hàng loạt không phải là lớp duy nhất hoạt động khác nhau giữa chế độ đào tạo và thử nghiệm. Dropout và các biến thể của nó cũng có tác dụng tương tự. Chúng ta có nên áp dụng chính sách tương tự cho tất cả các lớp này không? Tôi tin là không (mặc dù tôi rất muốn nghe suy nghĩ của bạn về điều này). Lý do là Dropout được sử dụng để tránh trang bị quá mức, do đó việc khóa vĩnh viễn nó ở chế độ dự đoán trong quá trình huấn luyện sẽ làm mất đi mục đích của nó. Bạn nghĩ sao?

Tôi thực sự tin rằng sự khác biệt này phải được giải quyết ở Keras. Tôi thậm chí còn thấy những tác động sâu sắc hơn (độ chính xác từ 100% xuống 50%) trong các ứng dụng trong thế giới thực do vấn đề này gây ra. TÔI dự định gửi đã gửi một PR gửi Keras về bản sửa lỗi và hy vọng nó sẽ được chấp nhận.

Nếu bạn thích bài đăng trên blog này, vui lòng dành chút thời gian để chia sẻ nó trên Facebook hoặc Twitter. 🙂

Dấu thời gian:

Thêm từ Hộp dữ liệu