Verwenden Sie nicht Flatten() – Globales Pooling für CNNs mit TensorFlow und Keras PlatoBlockchain Data Intelligence. Vertikale Suche. Ai.

Verwenden Sie nicht Flatten() – Globales Pooling für CNNs mit TensorFlow und Keras

Die meisten Praktiker lernen beim ersten Kennenlernen von Convolutional Neural Network (CNN)-Architekturen, dass sie aus drei grundlegenden Segmenten bestehen:

  • Faltungsschichten
  • Pooling von Schichten
  • Vollständig verbundene Schichten

Die meisten Ressourcen haben einige Variation dieser Segmentierung, einschließlich meines eigenen Buches. Besonders online – vollständig verbundene Schichten beziehen sich auf a glättende Schicht und (normalerweise) mehrere dichte Schichten.

Dies war früher die Norm, und bekannte Architekturen wie VGGNets verwendeten diesen Ansatz und endeten in:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Aus irgendeinem Grund wird jedoch oft vergessen, dass VGGNet praktisch die letzte Architektur war, die diesen Ansatz verwendet hat, aufgrund des offensichtlichen Rechenengpasses, den er verursacht. Sobald ResNets, das nur ein Jahr nach VGGNets (und vor 7 Jahren) veröffentlicht wurde, beendeten alle Mainstream-Architekturen ihre Modelldefinitionen mit:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

Flattening in CNNs gibt es seit 7 Jahren. 7 Jahre! Und nicht genug Leute scheinen über die schädlichen Auswirkungen zu sprechen, die dies sowohl auf Ihre Lernerfahrung als auch auf die von Ihnen verwendeten Rechenressourcen hat.

Global Average Pooling ist bei vielen Konten dem Flattening vorzuziehen. Wenn Sie ein kleines CNN prototyen, verwenden Sie Global Pooling. Wenn Sie jemandem etwas über CNNs beibringen – verwenden Sie Global Pooling. Wenn Sie ein MVP erstellen, verwenden Sie Global Pooling. Verwenden Sie Flattening-Layer für andere Anwendungsfälle, in denen sie tatsächlich benötigt werden.

Fallstudie – Abflachung vs. globales Pooling

Global Pooling verdichtet alle Merkmalskarten zu einer einzigen Karte und bündelt alle relevanten Informationen in einer einzigen Karte, die von einer einzigen dichten Klassifizierungsebene anstelle von mehreren Ebenen leicht verstanden werden kann. Es wird normalerweise als durchschnittliches Pooling angewendet (GlobalAveragePooling2D) oder Max-Pooling (GlobalMaxPooling2D) und kann auch für 1D- und 3D-Eingabe verwendet werden.

Anstatt eine Feature-Map wie z (7, 7, 32) in einen Vektor der Länge 1536 und Trainieren einer oder mehrerer Schichten, um Muster aus diesem langen Vektor zu erkennen: Wir können ihn zu a verdichten (7, 7) Vektor und klassifizieren Sie direkt von dort aus. So einfach ist das!

Beachten Sie, dass Engpassschichten für Netzwerke wie ResNets Zehntausende von Funktionen umfassen, nicht nur 1536. Beim Abflachen quälen Sie Ihr Netzwerk, um auf sehr ineffiziente Weise von seltsam geformten Vektoren zu lernen. Stellen Sie sich ein 2D-Bild vor, das in jede Pixelreihe geschnitten und dann zu einem flachen Vektor verkettet wird. Die beiden Pixel, die früher vertikal 0 Pixel voneinander entfernt waren, sind es nicht feature_map_width Pixel horizontal entfernt! Während dies für einen Klassifikationsalgorithmus, der räumliche Invarianz bevorzugt, möglicherweise nicht allzu wichtig ist, wäre dies für andere Anwendungen des Computersehens nicht einmal konzeptionell gut.

Lassen Sie uns ein kleines demonstratives Netzwerk definieren, das eine Abflachungsebene mit einigen dichten Ebenen verwendet:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

Wie sieht die Zusammenfassung aus?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 Millionen Parameter für ein Spielzeugnetzwerk – und beobachten Sie, wie die Parameter bei größerer Eingabe explodieren. 11.5 Millionen Parameter. EfficientNets, eines der leistungsstärksten Netzwerke, das jemals entwickelt wurde, arbeitet mit ~6 Millionen Parametern und kann nicht mit diesem einfachen Modell in Bezug auf die tatsächliche Leistung und die Fähigkeit, aus Daten zu lernen, verglichen werden.

Wir könnten diese Zahl erheblich reduzieren, indem wir das Netzwerk tiefer machen, was mehr Max-Pooling (und möglicherweise Strided Convolution) einführen würde, um die Feature-Maps zu reduzieren, bevor sie abgeflacht werden. Bedenken Sie jedoch, dass wir das Netzwerk komplexer machen würden, um es weniger rechenintensiv zu machen, alles um einer einzigen Schicht willen, die einen Schraubenschlüssel in die Pläne wirft.

Mit Layern tiefer zu gehen, sollte aussagekräftigere, nichtlineare Beziehungen zwischen Datenpunkten extrahieren und nicht die Eingabegröße reduzieren, um einem abflachenden Layer gerecht zu werden.

Hier ist ein Netzwerk mit globalem Pooling:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Zusammenfassung?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Viel besser! Wenn wir mit diesem Modell tiefer gehen, erhöht sich die Anzahl der Parameter, und wir können möglicherweise kompliziertere Datenmuster mit den neuen Schichten erfassen. Wenn es jedoch naiv gemacht wird, werden die gleichen Probleme auftreten, die VGGNets gebunden haben.

Weiter gehen – Handgehaltenes End-to-End-Projekt

Ihre neugierige Natur macht Lust auf mehr? Wir empfehlen Ihnen, sich unsere anzuschauen Geführtes Projekt: „Convolutional Neural Networks – Beyond Basic Architectures“.

Sehen Sie sich unseren praxisnahen, praktischen Leitfaden zum Erlernen von Git an, mit Best Practices, branchenweit akzeptierten Standards und einem mitgelieferten Spickzettel. Hören Sie auf, Git-Befehle zu googeln und tatsächlich in Verbindung, um es!

Ich nehme Sie mit auf eine kleine Zeitreise – von 1998 bis 2022, wobei ich die im Laufe der Jahre entwickelten Architekturen hervorhebe, was sie einzigartig gemacht hat, was ihre Nachteile sind, und die bemerkenswerten von Grund auf neu implementiert habe. Es gibt nichts Besseres, als etwas Dreck an den Händen zu haben, wenn es um diese geht.

Sie können ein Auto fahren, ohne zu wissen, ob der Motor 4 oder 8 Zylinder hat und wie die Ventile im Motor angeordnet sind. Wenn Sie jedoch eine Engine (Computer-Vision-Modell) entwerfen und bewerten möchten, sollten Sie etwas tiefer gehen. Auch wenn Sie keine Zeit mit dem Entwerfen von Architekturen verbringen und stattdessen Produkte bauen möchten, was die meisten tun möchten, finden Sie in dieser Lektion wichtige Informationen. Sie erfahren, warum die Verwendung veralteter Architekturen wie VGGNet Ihrem Produkt und Ihrer Leistung schadet und warum Sie sie überspringen sollten, wenn Sie etwas Modernes bauen, und Sie erfahren, welche Architekturen Sie verwenden können, um praktische Probleme zu lösen, und welche die Vor- und Nachteile sind für jeden.

Wenn Sie Computer Vision auf Ihrem Gebiet anwenden möchten, können Sie mithilfe der Ressourcen aus dieser Lektion die neuesten Modelle finden, verstehen, wie sie funktionieren und nach welchen Kriterien Sie sie vergleichen und eine Entscheidung treffen können verwenden.

Du nicht müssen Sie nach Architekturen und deren Implementierungen bei Google suchen – sie werden in der Regel sehr klar in den Papieren erklärt, und Frameworks wie Keras machen diese Implementierungen einfacher als je zuvor. Das Wichtigste aus diesem geführten Projekt ist, Ihnen beizubringen, wie Sie Architekturen und Papiere finden, lesen, implementieren und verstehen. Kein Rohstoff der Welt wird mit den neuesten Entwicklungen Schritt halten können. Ich habe die neuesten Zeitungen hier eingefügt – aber in ein paar Monaten werden neue auftauchen, und das ist unvermeidlich. Zu wissen, wo Sie glaubwürdige Implementierungen finden, sie mit Papieren vergleichen und optimieren können, kann Ihnen den Wettbewerbsvorteil verschaffen, der für viele Computer-Vision-Produkte erforderlich ist, die Sie vielleicht bauen möchten.

Zusammenfassung

In diesem kurzen Leitfaden haben wir uns eine Alternative zum Flattening im CNN-Architekturdesign angesehen. Wenn auch kurz – der Leitfaden spricht ein häufiges Problem beim Entwerfen von Prototypen oder MVPs an und rät Ihnen, eine bessere Alternative zum Flattening zu verwenden.

Jeder erfahrene Computer Vision Engineer wird dieses Prinzip kennen und anwenden, und die Praxis wird als selbstverständlich angesehen. Unglücklicherweise scheint es nicht richtig an neue Praktizierende weitergegeben zu werden, die gerade erst das Feld betreten, und kann klebrige Gewohnheiten schaffen, die eine Weile brauchen, um sie loszuwerden.

Wenn Sie in Computer Vision einsteigen, tun Sie sich selbst einen Gefallen und verwenden Sie auf Ihrer Lernreise keine Abflachungsebenen vor Klassifizierungsköpfen.

Zeitstempel:

Mehr von Stapelmissbrauch