Anleitung zum Schreiben benutzerdefinierter TensorFlow/Keras-Callbacks

Einleitung

Angenommen, Sie möchten, dass Ihr Keras-Modell während des Trainings, der Bewertung oder der Vorhersage ein bestimmtes Verhalten aufweist. Beispielsweise möchten Sie Ihr Modell möglicherweise in jeder Trainingsepoche speichern. Eine Möglichkeit, dies zu tun, ist die Verwendung von Callbacks.

Im Allgemeinen sind Rückrufe Funktionen, die aufgerufen werden, wenn ein Ereignis eintritt, und als Argumente an andere Funktionen übergeben werden. Im Fall von Keras sind sie ein Werkzeug, um das Verhalten Ihres Modells anzupassen – sei es während des Trainings, der Bewertung oder der Inferenz. Einige Anwendungen sind Protokollierung, Modellpersistenz, vorzeitiges Stoppen oder Ändern der Lernrate. Dies geschieht, indem eine Liste von Rückrufen als Argumente für übergeben wird keras.Model.fit(),keras.Model.evaluate() or keras.Model.predict().

Einige häufige Anwendungsfälle für Rückrufe sind das Ändern der Lernrate, Protokollierung, Überwachung und vorzeitiges Beenden des Trainings. Keras hat eine Reihe von integrierten Rückrufen, detailliert
in der Dokumentation
.

Einige spezifischere Anwendungen erfordern jedoch möglicherweise einen benutzerdefinierten Rückruf. Zum Beispiel, Implementieren des Aufwärmens der Lernrate mit einem Cosine Decay nach einer Halteperiode ist derzeit nicht integriert, wird jedoch häufig verwendet und als Planer übernommen.

Callback-Klasse und ihre Methoden

Keras hat eine bestimmte Callback-Klasse, keras.callbacks.Callback, mit Methoden, die während des Trainings, des Testens und der Inferenz auf globaler, Batch- oder Epochenebene aufgerufen werden können. Um zu Erstellen Sie benutzerdefinierte Rückrufe, müssen wir eine Unterklasse erstellen und diese Methoden überschreiben.

Das keras.callbacks.Callback Die Klasse hat drei Arten von Methoden:

  • globale Methoden: aufgerufen am Anfang oder am Ende von fit(), evaluate() und predict().
  • Methoden auf Stapelebene: werden zu Beginn oder am Ende der Verarbeitung eines Stapels aufgerufen.
  • Methoden auf Epochenebene: werden zu Beginn oder am Ende eines Trainingsstapels aufgerufen.

Hinweis: Jede Methode hat Zugriff auf ein dict namens logs. Die Schlüssel und Werte von logs sind kontextabhängig – sie hängen von dem Ereignis ab, das die Methode aufruft. Darüber hinaus haben wir Zugriff auf das Modell in jeder Methode über die self.model Attribut.

Sehen wir uns drei benutzerdefinierte Callback-Beispiele an – eines für das Training, eines für die Bewertung und eines für die Vorhersage. Jeder druckt in jeder Phase aus, was unser Modell tut und auf welche Protokolle wir Zugriff haben. Dies ist hilfreich, um zu verstehen, was mit benutzerdefinierten Rückrufen in jeder Phase möglich ist.

Beginnen wir mit der Definition eines Spielzeugmodells:

import tensorflow as tf
from tensorflow import keras
import numpy as np

model = keras.Sequential()
model.add(keras.layers.Dense(10, input_dim = 1, activation='relu'))
model.add(keras.layers.Dense(10, activation='relu'))
model.add(keras.layers.Dense(1))
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
    loss = "mean_squared_error",
    metrics = ["mean_absolute_error"]
)

x = np.random.uniform(low = 0, high = 10, size = 1000)
y = x**2
x_train, x_test = (x[:900],x[900:])
y_train, y_test = (y[:900],y[900:])

Rückruf für benutzerdefiniertes Training

Unser erster Rückruf soll während des Trainings angerufen werden. Unterklassen wir die Callback Klasse:

class TrainingCallback(keras.callbacks.Callback):
    def __init__(self):
        self.tabulation = {"train":"", 'batch': " "*8, 'epoch':" "*4}
    def on_train_begin(self, logs=None):
        tab = self.tabulation['train']
        print(f"{tab}Training!")
        print(f"{tab}available logs: {logs}")

    def on_train_batch_begin(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}Batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_train_batch_end(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}End of Batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_epoch_begin(self, epoch, logs=None):
        tab = self.tabulation['epoch']
        print(f"{tab}Epoch {epoch} of training")
        print(f"{tab}available logs: {logs}")

    def on_epoch_end(self, epoch, logs=None):
        tab = self.tabulation['epoch']
        print(f"{tab}End of Epoch {epoch} of training")
        print(f"{tab}available logs: {logs}")

    def on_train_end(self, logs=None):
        tab = self.tabulation['train']
        print(f"{tab}Finishing training!")
        print(f"{tab}available logs: {logs}")

Wenn eine dieser Methoden nicht überschrieben wird, wird das Standardverhalten wie zuvor fortgesetzt. In unserem Beispiel drucken wir einfach die verfügbaren Protokolle und die Ebene, auf der der Rückruf angewendet wird, mit der richtigen Einrückung aus.

Schauen wir uns die Ausgänge an:

model.fit(
    x_train,
    y_train,
    batch_size=500,
    epochs=2,
    verbose=0,
    callbacks=[TrainingCallback()],
)
Training!
available logs: {}
    Epoch 0 of training
    available logs: {}
        Batch 0
        available logs: {}
        End of Batch 0
        available logs: {'loss': 2172.373291015625, 'mean_absolute_error': 34.79669952392578}
        Batch 1
        available logs: {}
        End of Batch 1
        available logs: {'loss': 2030.1309814453125, 'mean_absolute_error': 33.30256271362305}
    End of Epoch 0 of training
    available logs: {'loss': 2030.1309814453125, 'mean_absolute_error': 33.30256271362305}
    Epoch 1 of training
    available logs: {}
        Batch 0
        available logs: {}
        End of Batch 0
        available logs: {'loss': 1746.2772216796875, 'mean_absolute_error': 30.268001556396484}
        Batch 1
        available logs: {}
        End of Batch 1
        available logs: {'loss': 1467.36376953125, 'mean_absolute_error': 27.10252571105957}
    End of Epoch 1 of training
    available logs: {'loss': 1467.36376953125, 'mean_absolute_error': 27.10252571105957}
Finishing training!
available logs: {'loss': 1467.36376953125, 'mean_absolute_error': 27.10252571105957}


Beachten Sie, dass wir bei jedem Schritt verfolgen können, was das Modell tut und auf welche Metriken wir Zugriff haben. Am Ende jeder Charge und Epoche haben wir Zugriff auf die In-Sample-Loss-Funktion und die Metriken unseres Modells.

Rückruf für benutzerdefinierte Bewertung

Rufen wir jetzt die an Model.evaluate() Methode. Wir können sehen, dass wir am Ende eines Stapels Zugriff auf die Verlustfunktion und die Metriken zu diesem Zeitpunkt haben, und am Ende der Auswertung haben wir Zugriff auf den Gesamtverlust und die Metriken:

class TestingCallback(keras.callbacks.Callback):
    def __init__(self):
          self.tabulation = {"test":"", 'batch': " "*8}
      
    def on_test_begin(self, logs=None):
        tab = self.tabulation['test']
        print(f'{tab}Evaluating!')
        print(f'{tab}available logs: {logs}')

    def on_test_end(self, logs=None):
        tab = self.tabulation['test']
        print(f'{tab}Finishing evaluation!')
        print(f'{tab}available logs: {logs}')

    def on_test_batch_begin(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}Batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_test_batch_end(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}End of batch {batch}")
        print(f"{tab}available logs: {logs}")
res = model.evaluate(
    x_test, y_test, batch_size=100, verbose=0, callbacks=[TestingCallback()]
)
Evaluating!
available logs: {}
        Batch 0
        available logs: {}
        End of batch 0
        available logs: {'loss': 382.2723083496094, 'mean_absolute_error': 14.069927215576172}
Finishing evaluation!
available logs: {'loss': 382.2723083496094, 'mean_absolute_error': 14.069927215576172}

Benutzerdefinierter Vorhersage-Callback

Schließlich rufen wir die an Model.predict() Methode. Beachten Sie, dass wir am Ende jedes Stapels Zugriff auf die vorhergesagten Ausgaben unseres Modells haben:

class PredictionCallback(keras.callbacks.Callback):
    def __init__(self):
        self.tabulation = {"prediction":"", 'batch': " "*8}

    def on_predict_begin(self, logs=None):
        tab = self.tabulation['prediction']
        print(f"{tab}Predicting!")
        print(f"{tab}available logs: {logs}")

    def on_predict_end(self, logs=None):
        tab = self.tabulation['prediction']
        print(f"{tab}End of Prediction!")
        print(f"{tab}available logs: {logs}")

    def on_predict_batch_begin(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_predict_batch_end(self, batch, logs=None):
        tab = self.tabulation['batch']
        print(f"{tab}End of batch {batch}")
        print(f"{tab}available logs:n {logs}")
res = model.predict(x_test[:10],
                    verbose = 0, 
                    callbacks=[PredictionCallback()])

Sehen Sie sich unseren praxisnahen, praktischen Leitfaden zum Erlernen von Git an, mit Best Practices, branchenweit akzeptierten Standards und einem mitgelieferten Spickzettel. Hören Sie auf, Git-Befehle zu googeln und tatsächlich in Verbindung, um es!

Predicting!
available logs: {}
        batch 0
        available logs: {}
        End of batch 0
        available logs:
 {'outputs': array([[ 7.743822],
       [27.748264],
       [33.082104],
       [26.530678],
       [27.939169],
       [18.414223],
       [42.610645],
       [36.69335 ],
       [13.096557],
       [37.120853]], dtype=float32)}
End of Prediction!
available logs: {}

Mit diesen können Sie das Verhalten anpassen, Überwachung einrichten oder die Trainings-, Bewertungs- oder Inferenzprozesse auf andere Weise ändern. Eine Alternative zur Untervergabe ist die Verwendung der LambdaCallback.

Verwenden von LambaCallback

Einer der eingebauten Rückrufe in Keras ist der LambdaCallback Klasse. Dieser Callback akzeptiert eine Funktion, die definiert, wie er sich verhält und was er tut! In gewisser Weise ermöglicht es Ihnen, jede beliebige Funktion als Callback zu verwenden, wodurch Sie benutzerdefinierte Callbacks erstellen können.

Die Klasse hat die optionalen Parameter:
-on_epoch_begin

  • on_epoch_end
  • on_batch_begin
  • on_batch_end
  • on_train_begin
  • on_train_end

Jeder Parameter akzeptiert eine Funktion die im jeweiligen Modellereignis aufgerufen wird. Lassen Sie uns als Beispiel einen Rückruf tätigen, um eine E-Mail zu senden, wenn das Modell das Training beendet:

import smtplib
from email.message import EmailMessage

def send_email(logs): 
    msg = EmailMessage()
    content = f"""The model has finished training."""
    for key, value in logs.items():
      content = content + f"n{key}:{value:.2f}"
    msg.set_content(content)
    msg['Subject'] = f'Training report'
    msg['From'] = '[email protected]'
    msg['To'] = 'receiver-email'

    s = smtplib.SMTP('smtp.gmail.com', 587)
    s.starttls()
    s.login("[email protected]", "your-gmail-app-password")
    s.send_message(msg)
    s.quit()

lambda_send_email = lambda logs : send_email(logs)

email_callback = keras.callbacks.LambdaCallback(on_train_end = lambda_send_email)

model.fit(
    x_train,
    y_train,
    batch_size=100,
    epochs=1,
    verbose=0,
    callbacks=[email_callback],
)

Um unseren benutzerdefinierten Rückruf zu tätigen, verwenden Sie LambdaCallback, müssen wir nur die Funktion implementieren, die aufgerufen werden soll, und sie als a lambda Funktion und übergebe sie an die
LambdaCallback Klasse als Parameter.

Ein Callback zum Visualisieren des Modelltrainings

In diesem Abschnitt geben wir ein Beispiel für einen benutzerdefinierten Rückruf, der eine Animation der Leistungsverbesserung unseres Modells während des Trainings erstellt. Dazu speichern wir die Werte der Protokolle am Ende jeder Charge. Dann, am Ende der Trainingsschleife, erstellen wir eine Animation mit matplotlib.

Um die Visualisierung zu verbessern, werden der Verlust und die Metriken im logarithmischen Maßstab dargestellt:

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from IPython import display

class TrainingAnimationCallback(keras.callbacks.Callback):
    def __init__(self, duration = 40, fps = 1000/25):
        self.duration = duration
        self.fps = fps
        self.logs_history = []

    def set_plot(self):   
        self.figure = plt.figure()
        
        plt.xticks(
            range(0,self.params['steps']*self.params['epochs'], self.params['steps']),
            range(0,self.params['epochs']))
        plt.xlabel('Epoch')
        plt.ylabel('Loss & Metrics ($Log_{10}$ scale)')

        self.plot = {}
        for metric in self.model.metrics_names:
          self.plot[metric], = plt.plot([],[], label = metric)
          
        max_y = [max(log.values()) for log in self.logs_history]
        
        self.title = plt.title(f'batches:0')
        plt.xlim(0,len(self.logs_history)) 
        plt.ylim(0,max(max_y))

           
        plt.legend(loc='upper right')
  
    def animation_function(self,frame):
        batch = frame % self.params['steps']
        self.title.set_text(f'batch:{batch}')
        x = list(range(frame))
        
        for metric in self.model.metrics_names:
            y = [log[metric] for log in self.logs_history[:frame]]
            self.plot[metric].set_data(x,y)
        
    def on_train_batch_end(self, batch, logs=None):
        logarithm_transform = lambda item: (item[0], np.log(item[1]))
        logs = dict(map(logarithm_transform,logs.items()))
        self.logs_history.append(logs)
       
    def on_train_end(self, logs=None):
        self.set_plot()
        num_frames = int(self.duration*self.fps)
        num_batches = self.params['steps']*self.params['epochs']
        selected_batches = range(0, num_batches , num_batches//num_frames )
        interval = 1000*(1/self.fps)
        anim_created = FuncAnimation(self.figure, 
                                     self.animation_function,
                                     frames=selected_batches,
                                     interval=interval)
        video = anim_created.to_html5_video()
        
        html = display.HTML(video)
        display.display(html)
        plt.close()

Wir verwenden dasselbe Modell wie zuvor, jedoch mit mehr Trainingsbeispielen:

import tensorflow as tf
from tensorflow import keras
import numpy as np

model = keras.Sequential()
model.add(keras.layers.Dense(10, input_dim = 1, activation='relu'))
model.add(keras.layers.Dense(10, activation='relu'))
model.add(keras.layers.Dense(1))
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
    loss = "mean_squared_error",
    metrics = ["mean_absolute_error"]
)

def create_sample(sample_size, train_test_proportion = 0.9):
    x = np.random.uniform(low = 0, high = 10, size = sample_size)
    y = x**2
    train_test_split = int(sample_size*train_test_proportion)
    x_train, x_test = (x[:train_test_split],x[train_test_split:])
    y_train, y_test = (y[:train_test_split],y[train_test_split:])
    return (x_train,x_test,y_train,y_test)

x_train,x_test,y_train,y_test = create_sample(35200)


model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=2,
    verbose=0,
    callbacks=[TrainingAnimationCallback()],
)

Unsere Ausgabe ist eine Animation der Metriken und der Verlustfunktion, wie sie sich während des Trainingsprozesses ändern:

Ihr Browser unterstützt kein HTML-Video.

Zusammenfassung

In diesem Leitfaden haben wir uns die Implementierung benutzerdefinierter Callbacks in Keras angesehen.
Es gibt zwei Optionen zum Implementieren benutzerdefinierter Rückrufe – durch Unterklassen der keras.callbacks.Callback Klasse oder mit der keras.callbacks.LambdaCallback Klasse.

Wir haben ein praktisches Beispiel mit gesehen LambdaCallbackfür das Senden einer E-Mail am Ende der Trainingsschleife und ein Beispiel für die Unterklasse der Callback Klasse, die eine Animation der Trainingsschleife erstellt.

Obwohl Keras viele eingebaute Rückrufe hat, kann es für spezifischere Anwendungen nützlich sein, zu wissen, wie ein benutzerdefinierter Rückruf implementiert wird.

Zeitstempel:

Mehr von Stapelmissbrauch