Руководство по написанию пользовательских обратных вызовов TensorFlow/Keras

Введение

Предположим, вы хотите, чтобы ваша модель Keras имела определенное поведение во время обучения, оценки или прогнозирования. Например, вы можете захотеть сохранить свою модель в каждую эпоху обучения. Один из способов сделать это — использовать обратные вызовы.

В общем, обратные вызовы — это функции, которые вызываются, когда происходит какое-то событие, и передаются в качестве аргументов другим функциям. В случае с Keras это инструмент для настройки поведения вашей модели — будь то во время обучения, оценки или вывода. Некоторые приложения ведут логирование, постоянство модели, раннюю остановку или изменение скорости обучения. Это делается путем передачи списка обратных вызовов в качестве аргументов для keras.Model.fit(),keras.Model.evaluate() or keras.Model.predict().

Некоторыми распространенными вариантами использования обратных вызовов являются изменение скорости обучения, ведение журнала, мониторинг и досрочное прекращение обучения. Keras имеет ряд встроенных обратных вызовов, подробно
в документации
.

Однако для некоторых более специфических приложений может потребоваться настраиваемый обратный вызов. Например, реализация прогрева скорости обучения с косинусным затуханием после периода удержания в настоящее время не встроен, но широко используется в качестве планировщика.

Класс обратного вызова и его методы

У Keras есть специальный класс обратного вызова, keras.callbacks.Callback, с методами, которые можно вызывать во время обучения, тестирования и логического вывода на глобальном, пакетном или эпохальном уровне. Чтобы создавать пользовательские обратные вызовы, нам нужно создать подкласс и переопределить эти методы.

Ассоциация keras.callbacks.Callback класс имеет три вида методов:

  • глобальные методы: вызываются в начале или в конце fit(), evaluate() и predict().
  • методы пакетного уровня: вызываются в начале или в конце обработки пакета.
  • методы уровня эпохи: вызываются в начале или в конце обучающего пакета.

Примечание: Каждый метод имеет доступ к словарю с именем logs. Ключи и значения logs являются контекстными — они зависят от события, которое вызывает метод. Более того, у нас есть доступ к модели внутри каждого метода через self.model атрибутов.

Давайте рассмотрим три примера пользовательских обратных вызовов — один для обучения, один для оценки и один для прогнозирования. Каждый будет печатать на каждом этапе, что делает наша модель и к каким журналам у нас есть доступ. Это полезно для понимания того, что можно делать с пользовательскими обратными вызовами на каждом этапе.

Начнем с определения модели игрушки:

import tensorflow as tf
from tensorflow import keras
import numpy as np

model = keras.Sequential()
model.add(keras.layers.Dense(10, input_dim = 1, activation='relu'))
model.add(keras.layers.Dense(10, activation='relu'))
model.add(keras.layers.Dense(1))
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
    loss = "mean_squared_error",
    metrics = ["mean_absolute_error"]
)

x = np.random.uniform(low = 0, high = 10, size = 1000)
y = x**2
x_train, x_test = (x[:900],x[900:])
y_train, y_test = (y[:900],y[900:])

Пользовательский обратный вызов обучения

Наш первый обратный вызов должен быть вызван во время обучения. Давайте подкласс Callback учебный класс:

class TrainingCallback(keras.callbacks.Callback):
    def __init__(self):
        self.tabulation = {"train":"", 'batch': " "*8, 'epoch':" "*4}
    def on_train_begin(self, logs=None):
        tab = self.tabulation['train']
        print(f"{tab}Training!")
        print(f"{tab}available logs: {logs}")

    def on_train_batch_begin(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}Batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_train_batch_end(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}End of Batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_epoch_begin(self, epoch, logs=None):
        tab = self.tabulation['epoch']
        print(f"{tab}Epoch {epoch} of training")
        print(f"{tab}available logs: {logs}")

    def on_epoch_end(self, epoch, logs=None):
        tab = self.tabulation['epoch']
        print(f"{tab}End of Epoch {epoch} of training")
        print(f"{tab}available logs: {logs}")

    def on_train_end(self, logs=None):
        tab = self.tabulation['train']
        print(f"{tab}Finishing training!")
        print(f"{tab}available logs: {logs}")

Если какой-либо из этих методов не переопределен, поведение по умолчанию останется прежним. В нашем примере — мы просто распечатываем доступные журналы и уровень, на котором применяется обратный вызов, с правильным отступом.

Давайте посмотрим на выходы:

model.fit(
    x_train,
    y_train,
    batch_size=500,
    epochs=2,
    verbose=0,
    callbacks=[TrainingCallback()],
)
Training!
available logs: {}
    Epoch 0 of training
    available logs: {}
        Batch 0
        available logs: {}
        End of Batch 0
        available logs: {'loss': 2172.373291015625, 'mean_absolute_error': 34.79669952392578}
        Batch 1
        available logs: {}
        End of Batch 1
        available logs: {'loss': 2030.1309814453125, 'mean_absolute_error': 33.30256271362305}
    End of Epoch 0 of training
    available logs: {'loss': 2030.1309814453125, 'mean_absolute_error': 33.30256271362305}
    Epoch 1 of training
    available logs: {}
        Batch 0
        available logs: {}
        End of Batch 0
        available logs: {'loss': 1746.2772216796875, 'mean_absolute_error': 30.268001556396484}
        Batch 1
        available logs: {}
        End of Batch 1
        available logs: {'loss': 1467.36376953125, 'mean_absolute_error': 27.10252571105957}
    End of Epoch 1 of training
    available logs: {'loss': 1467.36376953125, 'mean_absolute_error': 27.10252571105957}
Finishing training!
available logs: {'loss': 1467.36376953125, 'mean_absolute_error': 27.10252571105957}


Обратите внимание, что на каждом этапе мы можем следить за тем, что делает модель и к каким показателям у нас есть доступ. В конце каждой партии и эпохи у нас есть доступ к функции потерь в выборке и показателям нашей модели.

Пользовательский обратный вызов оценки

Теперь давайте позвоним Model.evaluate() метод. Мы видим, что в конце партии у нас есть доступ к функции потерь и метрикам на данный момент, а в конце оценки у нас есть доступ к общим потерям и метрикам:

class TestingCallback(keras.callbacks.Callback):
    def __init__(self):
          self.tabulation = {"test":"", 'batch': " "*8}
      
    def on_test_begin(self, logs=None):
        tab = self.tabulation['test']
        print(f'{tab}Evaluating!')
        print(f'{tab}available logs: {logs}')

    def on_test_end(self, logs=None):
        tab = self.tabulation['test']
        print(f'{tab}Finishing evaluation!')
        print(f'{tab}available logs: {logs}')

    def on_test_batch_begin(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}Batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_test_batch_end(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}End of batch {batch}")
        print(f"{tab}available logs: {logs}")
res = model.evaluate(
    x_test, y_test, batch_size=100, verbose=0, callbacks=[TestingCallback()]
)
Evaluating!
available logs: {}
        Batch 0
        available logs: {}
        End of batch 0
        available logs: {'loss': 382.2723083496094, 'mean_absolute_error': 14.069927215576172}
Finishing evaluation!
available logs: {'loss': 382.2723083496094, 'mean_absolute_error': 14.069927215576172}

Пользовательский обратный вызов прогноза

Наконец, позвоним Model.predict() метод. Обратите внимание, что в конце каждого пакета у нас есть доступ к прогнозируемым результатам нашей модели:

class PredictionCallback(keras.callbacks.Callback):
    def __init__(self):
        self.tabulation = {"prediction":"", 'batch': " "*8}

    def on_predict_begin(self, logs=None):
        tab = self.tabulation['prediction']
        print(f"{tab}Predicting!")
        print(f"{tab}available logs: {logs}")

    def on_predict_end(self, logs=None):
        tab = self.tabulation['prediction']
        print(f"{tab}End of Prediction!")
        print(f"{tab}available logs: {logs}")

    def on_predict_batch_begin(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_predict_batch_end(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}End of batch {batch}")
        print(f"{tab}available logs:n {logs}")
res = model.predict(x_test[:10],
                    verbose = 0, 
                    callbacks=[PredictionCallback()])

Ознакомьтесь с нашим практическим руководством по изучению Git с рекомендациями, принятыми в отрасли стандартами и прилагаемой памяткой. Перестаньте гуглить команды Git и на самом деле изучить это!

Predicting!
available logs: {}
        batch 0
        available logs: {}
        End of batch 0
        available logs:
 {'outputs': array([[ 7.743822],
       [27.748264],
       [33.082104],
       [26.530678],
       [27.939169],
       [18.414223],
       [42.610645],
       [36.69335 ],
       [13.096557],
       [37.120853]], dtype=float32)}
End of Prediction!
available logs: {}

С их помощью вы можете настроить поведение, настроить мониторинг или иным образом изменить процессы обучения, оценки или вывода. Альтернативой подклассу является использование LambdaCallback.

Использование LambaCallback

Одним из встроенных обратных вызовов в Keras является LambdaCallback учебный класс. Этот обратный вызов принимает функцию, которая определяет, как она себя ведет и что делает! В некотором смысле, это позволяет вам использовать любую произвольную функцию в качестве обратного вызова, что позволяет создавать собственные обратные вызовы.

Класс имеет необязательные параметры:
on_epoch_begin

  • on_epoch_end
  • on_batch_begin
  • on_batch_end
  • on_train_begin
  • on_train_end

Каждый параметр принимает функция который вызывается в соответствующем событии модели. В качестве примера давайте сделаем обратный вызов, чтобы отправить электронное письмо, когда модель закончит обучение:

import smtplib
from email.message import EmailMessage

def send_email(logs): 
    msg = EmailMessage()
    content = f"""The model has finished training."""
    for key, value in logs.items():
      content = content + f"n{key}:{value:.2f}"
    msg.set_content(content)
    msg['Subject'] = f'Training report'
    msg['From'] = '[email protected]'
    msg['To'] = 'receiver-email'

    s = smtplib.SMTP('smtp.gmail.com', 587)
    s.starttls()
    s.login("[email protected]", "your-gmail-app-password")
    s.send_message(msg)
    s.quit()

lambda_send_email = lambda logs : send_email(logs)

email_callback = keras.callbacks.LambdaCallback(on_train_end = lambda_send_email)

model.fit(
    x_train,
    y_train,
    batch_size=100,
    epochs=1,
    verbose=0,
    callbacks=[email_callback],
)

Чтобы сделать наш пользовательский обратный вызов, используя LambdaCallback, нам просто нужно реализовать функцию, которую мы хотим вызывать, обернуть ее как lambda функцию и передать ее
LambdaCallback класс как параметр.

Обратный вызов для визуализации обучения модели

В этом разделе мы приведем пример пользовательского обратного вызова, который создает анимацию улучшения производительности нашей модели во время обучения. Для этого мы сохраняем значения журналов в конце каждого пакета. Затем, в конце цикла обучения, мы создаем анимацию, используя matplotlib.

Для улучшения визуализации потери и показатели будут представлены в логарифмическом масштабе:

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from IPython import display

class TrainingAnimationCallback(keras.callbacks.Callback):
    def __init__(self, duration = 40, fps = 1000/25):
        self.duration = duration
        self.fps = fps
        self.logs_history = []

    def set_plot(self):   
        self.figure = plt.figure()
        
        plt.xticks(
            range(0,self.params['steps']*self.params['epochs'], self.params['steps']),
            range(0,self.params['epochs']))
        plt.xlabel('Epoch')
        plt.ylabel('Loss & Metrics ($Log_{10}$ scale)')

        self.plot = {}
        for metric in self.model.metrics_names:
          self.plot[metric], = plt.plot([],[], label = metric)
          
        max_y = [max(log.values()) for log in self.logs_history]
        
        self.title = plt.title(f'batches:0')
        plt.xlim(0,len(self.logs_history)) 
        plt.ylim(0,max(max_y))

           
        plt.legend(loc='upper right')
  
    def animation_function(self,frame):
        batch = frame % self.params['steps']
        self.title.set_text(f'batch:{batch}')
        x = list(range(frame))
        
        for metric in self.model.metrics_names:
            y = [log[metric] for log in self.logs_history[:frame]]
            self.plot[metric].set_data(x,y)
        
    def on_train_batch_end(self, batch, logs=None):
        logarithm_transform = lambda item: (item[0], np.log(item[1]))
        logs = dict(map(logarithm_transform,logs.items()))
        self.logs_history.append(logs)
       
    def on_train_end(self, logs=None):
        self.set_plot()
        num_frames = int(self.duration*self.fps)
        num_batches = self.params['steps']*self.params['epochs']
        selected_batches = range(0, num_batches , num_batches//num_frames )
        interval = 1000*(1/self.fps)
        anim_created = FuncAnimation(self.figure, 
                                     self.animation_function,
                                     frames=selected_batches,
                                     interval=interval)
        video = anim_created.to_html5_video()
        
        html = display.HTML(video)
        display.display(html)
        plt.close()

Мы будем использовать ту же модель, что и раньше, но с большим количеством обучающих выборок:

import tensorflow as tf
from tensorflow import keras
import numpy as np

model = keras.Sequential()
model.add(keras.layers.Dense(10, input_dim = 1, activation='relu'))
model.add(keras.layers.Dense(10, activation='relu'))
model.add(keras.layers.Dense(1))
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
    loss = "mean_squared_error",
    metrics = ["mean_absolute_error"]
)

def create_sample(sample_size, train_test_proportion = 0.9):
    x = np.random.uniform(low = 0, high = 10, size = sample_size)
    y = x**2
    train_test_split = int(sample_size*train_test_proportion)
    x_train, x_test = (x[:train_test_split],x[train_test_split:])
    y_train, y_test = (y[:train_test_split],y[train_test_split:])
    return (x_train,x_test,y_train,y_test)

x_train,x_test,y_train,y_test = create_sample(35200)


model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=2,
    verbose=0,
    callbacks=[TrainingAnimationCallback()],
)

Нашим результатом является анимация метрик и функции потерь по мере их изменения в процессе обучения:

Ваш браузер не поддерживает HTML-видео.

Заключение

В этом руководстве мы рассмотрели реализацию пользовательских обратных вызовов в Keras.
Существует два варианта реализации пользовательских обратных вызовов — путем создания подкласса keras.callbacks.Callback класса или с помощью keras.callbacks.LambdaCallback класса.

Мы видели один практический пример использования LambdaCallbackдля отправки электронного письма в конце цикла обучения и один пример подкласса Callback класс, создающий анимацию цикла обучения.

Хотя у Keras есть много встроенных обратных вызовов, знание того, как реализовать пользовательский обратный вызов, может быть полезно для более конкретных приложений.

Отметка времени:

Больше от Стекабьюс