Non utilizzare Flatten(): pooling globale per CNN con TensorFlow e Keras PlatoBlockchain Data Intelligence. Ricerca verticale. Ai.

Non utilizzare Flatten() – Pooling globale per CNN con TensorFlow e Keras

La maggior parte dei professionisti, mentre apprende per la prima volta le architetture della rete neurale convoluzionale (CNN), apprende che si compone di tre segmenti di base:

  • Strati convoluzionali
  • Livelli di pooling
  • Livelli completamente connessi

La maggior parte delle risorse ha alcuni variazione su questa segmentazione, incluso il mio libro. Soprattutto online: i livelli completamente connessi si riferiscono a a strato appiattito e (di solito) multipli strati densi.

Questa era la norma e architetture ben note come VGGNets utilizzavano questo approccio e finivano in:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Tuttavia, per qualche ragione, spesso si dimentica che VGGNet è stata praticamente l'ultima architettura a utilizzare questo approccio, a causa dell'ovvio collo di bottiglia computazionale che crea. Non appena ResNets, pubblicato solo l'anno dopo VGGNets (e 7 anni fa), tutte le architetture tradizionali hanno terminato le loro definizioni di modello con:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

L'appiattimento nelle CNN dura da 7 anni. 7 anni! E non abbastanza persone sembrano parlare dell'effetto dannoso che ha sia sulla tua esperienza di apprendimento che sulle risorse computazionali che stai utilizzando.

Il pooling medio globale è preferibile su molti account rispetto all'appiattimento. Se stai creando un prototipo di una piccola CNN, usa Global Pooling. Se stai insegnando a qualcuno le CNN, usa il Global Pooling. Se stai facendo un MVP, usa il Global Pooling. Usa i livelli di appiattimento per altri casi d'uso in cui sono effettivamente necessari.

Caso di studio: appiattimento vs pooling globale

Global Pooling condensa tutte le mappe delle caratteristiche in una singola, raggruppando tutte le informazioni rilevanti in un'unica mappa che può essere facilmente compresa da un singolo strato di classificazione denso anziché da più livelli. In genere viene applicato come pooling medio (GlobalAveragePooling2D) o max pooling (GlobalMaxPooling2D) e può funzionare anche per input 1D e 3D.

Invece di appiattire una mappa delle caratteristiche come (7, 7, 32) in un vettore di lunghezza 1536 e allenando uno o più livelli per discernere i modelli da questo lungo vettore: possiamo condensarlo in un (7, 7) vettore e classificare direttamente da lì. È così semplice!

Nota che i livelli di collo di bottiglia per reti come ResNets contano in decine di migliaia di funzionalità, non solo 1536. Durante l'appiattimento, stai torturando la tua rete per imparare da vettori dalla forma strana in modo molto inefficiente. Immagina un'immagine 2D tagliata su ogni riga di pixel e quindi concatenata in un vettore piatto. I due pixel che prima erano a 0 pixel di distanza verticalmente non lo sono feature_map_width pixel di distanza in orizzontale! Anche se questo potrebbe non avere molta importanza per un algoritmo di classificazione, che favorisce l'invarianza spaziale, questo non sarebbe nemmeno concettualmente buono per altre applicazioni di visione artificiale.

Definiamo una piccola rete dimostrativa che utilizza uno strato di appiattimento con un paio di strati densi:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

Che aspetto ha il riassunto?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 milioni di parametri per una rete di giocattoli e guarda i parametri esplodere con un input più ampio. 11.5 milioni di parametri. EfficientNets, una delle reti con le migliori prestazioni mai progettate, funziona con circa 6 milioni di parametri e non può essere paragonata a questo semplice modello in termini di prestazioni effettive e capacità di apprendere dai dati.

Potremmo ridurre questo numero in modo significativo rendendo la rete più profonda, il che introdurrebbe più pooling massimo (e potenzialmente una convoluzione graduale) per ridurre le mappe delle funzionalità prima che vengano appiattite. Tuttavia, considera che renderemmo la rete più complessa per renderla meno dispendiosa dal punto di vista computazionale, tutto per il bene di un singolo livello che sta gettando una chiave inglese nei piani.

Andare più in profondità con i livelli dovrebbe significare estrarre relazioni più significative e non lineari tra i punti dati, non riducendo le dimensioni dell'input per soddisfare un livello di appiattimento.

Ecco una rete con pooling globale:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Sommario?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Molto meglio! Se andiamo più in profondità con questo modello, il conteggio dei parametri aumenterà e potremmo essere in grado di acquisire modelli di dati più complessi con i nuovi livelli. Se fatto ingenuamente, però, sorgeranno gli stessi problemi che hanno legato i VGGNet.

Andare oltre – Progetto end-to-end manuale

La tua natura curiosa ti fa venire voglia di andare oltre? Ti consigliamo di dare un'occhiata al nostro Progetto guidato: "Reti neurali convoluzionali: oltre le architetture di base".

Dai un'occhiata alla nostra guida pratica e pratica per l'apprendimento di Git, con le migliori pratiche, gli standard accettati dal settore e il cheat sheet incluso. Smetti di cercare su Google i comandi Git e in realtà imparare esso!

Ti porterò in un piccolo viaggio nel tempo, andando dal 1998 al 2022, mettendo in evidenza le architetture definite nel corso degli anni, cosa le ha rese uniche, quali sono i loro svantaggi e implementando quelle notevoli da zero. Non c'è niente di meglio che avere un po' di sporco sulle mani quando si tratta di questi.

Puoi guidare un'auto senza sapere se il motore ha 4 o 8 cilindri e quale sia la posizione delle valvole all'interno del motore. Tuttavia, se vuoi progettare e apprezzare un motore (modello di visione artificiale), ti consigliamo di approfondire un po'. Anche se non vuoi perdere tempo a progettare architetture e vuoi invece costruire prodotti, che è ciò che la maggior parte vuole fare, troverai informazioni importanti in questa lezione. Imparerai perché l'utilizzo di architetture obsolete come VGGNet danneggerà il tuo prodotto e le tue prestazioni e perché dovresti saltarle se stai costruendo qualcosa di moderno, e imparerai a quali architetture puoi rivolgerti per risolvere problemi pratici e cosa i pro e i contro sono per ciascuno.

Se stai cercando di applicare la visione artificiale al tuo campo, utilizzando le risorse di questa lezione, sarai in grado di trovare i modelli più recenti, capire come funzionano e in base a quali criteri puoi confrontarli e prendere una decisione su quale uso.

Tu non devono a Google per le architetture e le loro implementazioni: in genere sono spiegate molto chiaramente nei documenti e framework come Keras rendono queste implementazioni più facili che mai. Il punto chiave di questo progetto guidato è insegnarti come trovare, leggere, implementare e comprendere architetture e documenti. Nessuna risorsa al mondo sarà in grado di stare al passo con tutti gli ultimi sviluppi. Ho incluso qui i documenti più recenti, ma tra pochi mesi ne appariranno di nuovi, ed è inevitabile. Sapere dove trovare implementazioni credibili, confrontarle con i documenti e modificarle può darti il ​​vantaggio competitivo richiesto per molti prodotti di visione artificiale che potresti voler costruire.

Conclusione

In questa breve guida, abbiamo dato un'occhiata a un'alternativa all'appiattimento nella progettazione dell'architettura CNN. Anche se breve, la guida affronta un problema comune durante la progettazione di prototipi o MVP e ti consiglia di utilizzare un'alternativa migliore all'appiattimento.

Qualsiasi ingegnere esperto di visione artificiale conoscerà e applicherà questo principio e la pratica è data per scontata. Sfortunatamente, non sembra essere adeguatamente trasmesso ai nuovi praticanti che stanno appena entrando nel campo e possono creare abitudini appiccicose che richiedono un po' di tempo per sbarazzarsi di.

Se stai entrando in Computer Vision, fatti un favore e non usare livelli di appiattimento prima di classificare le teste nel tuo viaggio di apprendimento.

Timestamp:

Di più da Impilamento