Nu utilizați Flatten() - Pooling global pentru CNN-uri cu TensorFlow și Keras PlatoBlockchain Data Intelligence. Căutare verticală. Ai.

Nu utilizați Flatten() – Pooling global pentru CNN-uri cu TensorFlow și Keras

Majoritatea practicienilor, în timp ce învață mai întâi despre arhitecturile rețelelor neuronale convoluționale (CNN) – învață că aceasta este compusă din trei segmente de bază:

  • Straturi convoluționale
  • Straturi de grupare
  • Straturi complet conectate

Cele mai multe resurse au unele variație asupra acestei segmentări, inclusiv propria mea carte. În special online – straturile complet conectate se referă la a strat de aplatizare și (de obicei) multiple straturi dense.

Aceasta era obișnuită norma, iar arhitecturile bine-cunoscute, cum ar fi VGGNets, foloseau această abordare și se terminau în:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Deși, din anumite motive, de multe ori se uită că VGGNet a fost practic ultima arhitectură care a folosit această abordare, din cauza blocajului de calcul evident pe care îl creează. De îndată ce ResNets, publicat la doar un an după VGGNets (și în urmă cu 7 ani), toate arhitecturile mainstream și-au încheiat definițiile modelului cu:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

Aplatizarea în CNN-uri durează de 7 ani. ani 7! Și nu destui oameni par să vorbească despre efectul dăunător pe care îl are atât asupra experienței tale de învățare, cât și asupra resurselor de calcul pe care le folosești.

Global Average Pooling este de preferat în multe conturi decât aplatizarea. Dacă creați prototipuri pentru un CNN mic - utilizați Global Pooling. Dacă înveți pe cineva despre CNN-uri, folosește Global Pooling. Dacă faci un MVP – folosește Global Pooling. Utilizați straturi de aplatizare pentru alte cazuri de utilizare în care sunt de fapt necesare.

Studiu de caz – Aplatizare vs Pooling global

Global Pooling condensează toate hărțile de caracteristici într-o singură, grupând toate informațiile relevante într-o singură hartă care poate fi ușor de înțeles de un singur strat de clasificare dens în loc de mai multe straturi. Se aplică de obicei ca o punere în comun medie (GlobalAveragePooling2D) sau pooling maxim (GlobalMaxPooling2D) și poate funcționa și pentru intrare 1D și 3D.

În loc să aplatizeze o hartă caracteristică, cum ar fi (7, 7, 32) într-un vector de lungime 1536 și antrenând unul sau mai multe straturi pentru a discerne modele din acest vector lung: îl putem condensa într-un (7, 7) vector și clasifica direct de acolo. Este atat de simplu!

Rețineți că straturi de blocaj pentru rețele precum ResNets numără în zeci de mii de caracteristici, nu doar 1536. Când aplatizați, vă torturați rețeaua pentru a învăța din vectori cu forme ciudate într-un mod foarte ineficient. Imaginează-ți o imagine 2D tăiată pe fiecare rând de pixeli și apoi concatenată într-un vector plat. Cei doi pixeli care se aflau la o distanță de 0 pixeli pe verticală nu sunt feature_map_width pixeli distanță pe orizontală! În timp ce acest lucru poate să nu conteze prea mult pentru un algoritm de clasificare, care favorizează invarianța spațială, acest lucru nu ar fi nici măcar bun din punct de vedere conceptual pentru alte aplicații ale vederii computerizate.

Să definim o mică rețea demonstrativă care utilizează un strat de aplatizare cu câteva straturi dense:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

Cum arată rezumatul?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 milioane de parametri pentru o rețea de jucării – și urmăriți cum explodează parametrii cu o intrare mai mare. 11.5 milioane de parametri. EfficientNets, una dintre cele mai performante rețele proiectate vreodată, funcționează la ~6 milioane de parametri și nu poate fi comparată cu acest model simplu în ceea ce privește performanța reală și capacitatea de a învăța din date.

Am putea reduce semnificativ acest număr prin adâncirea rețelei, ceea ce ar introduce mai multă grupare maximă (și potențial convoluție cu pas) pentru a reduce hărțile caracteristicilor înainte ca acestea să fie aplatizate. Cu toate acestea, luați în considerare că am face rețeaua mai complexă pentru a o face mai puțin costisitoare din punct de vedere computațional, totul de dragul unui singur strat care aruncă o cheie în planuri.

Aprofundarea straturilor ar trebui să însemne extragerea unor relații mai semnificative, neliniare între punctele de date, fără a reduce dimensiunea de intrare pentru a satisface un strat de aplatizare.

Iată o rețea cu pooling global:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Rezumat?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Mult mai bine! Dacă mergem mai profund cu acest model, numărul parametrilor va crește și s-ar putea să reușim să captăm modele mai complicate de date cu noile straturi. Dacă se face naiv, vor apărea aceleași probleme care au legat VGGNets.

Mergând mai departe – Proiect manual de la capăt la capăt

Natura ta curios te face să vrei să mergi mai departe? Vă recomandăm să verificați Proiect ghidat: „Rețele neuronale convoluționale – Dincolo de arhitecturile de bază”.

Consultați ghidul nostru practic și practic pentru a învăța Git, cu cele mai bune practici, standarde acceptate de industrie și fisa de cheat incluse. Opriți căutarea pe Google a comenzilor Git și de fapt învăţa aceasta!

Vă voi duce într-o mică călătorie în timp – mergând din 1998 până în 2022, evidențiind arhitecturile definitorii dezvoltate de-a lungul anilor, ceea ce le-a făcut unice, care sunt dezavantajele lor și le voi implementa pe cele notabile de la zero. Nu este nimic mai bun decât să ai niște murdărie pe mâini când vine vorba de acestea.

Puteți conduce o mașină fără să știți dacă motorul are 4 sau 8 cilindri și care este poziționarea supapelor în interiorul motorului. Totuși – dacă doriți să proiectați și să apreciați un motor (model de computer vision), veți dori să mergeți puțin mai adânc. Chiar dacă nu doriți să pierdeți timp proiectând arhitecturi și doriți să construiți produse în schimb, ceea ce majoritatea doresc să facă – veți găsi informații importante în această lecție. Veți afla de ce folosirea arhitecturilor învechite, cum ar fi VGGNet, vă va afecta produsul și performanța și de ce ar trebui să le omiteți dacă construiți ceva modern și veți afla la ce arhitecturi puteți merge pentru a rezolva probleme practice și ce argumentele pro și contra sunt pentru fiecare.

Dacă doriți să aplicați viziunea computerizată în domeniul dvs., folosind resursele din această lecție - veți putea găsi cele mai noi modele, veți putea înțelege cum funcționează și după ce criterii le puteți compara și veți lua o decizie asupra cărora utilizare.

Tu nu trebuie de la Google pentru arhitecturi și implementările acestora – de obicei sunt explicate foarte clar în lucrări, iar cadrele precum Keras fac aceste implementări mai ușoare ca niciodată. Principala concluzie a acestui proiect ghidat este să vă învețe cum să găsiți, să citiți, să implementați și să înțelegeți arhitecturi și lucrări. Nicio resursă din lume nu va putea ține pasul cu toate cele mai noi evoluții. Am inclus cele mai noi lucrări aici – dar în câteva luni vor apărea altele noi, iar asta este inevitabil. Știind unde să găsiți implementări credibile, să le comparați cu lucrările și să le modificați, vă poate oferi avantajul competitiv necesar pentru multe produse de viziune computerizată pe care doriți să le construiți.

Concluzie

În acest scurt ghid, am aruncat o privire asupra unei alternative la aplatizare în designul arhitecturii CNN. Deși scurt – ghidul abordează o problemă comună atunci când proiectați prototipuri sau MVP-uri și vă sfătuiește să utilizați o alternativă mai bună la aplatizare.

Orice inginer experimentat în viziune pe computer va cunoaște și va aplica acest principiu, iar practica este considerată de la sine înțeles. Din păcate, nu pare să fie transmis în mod corespunzător noilor practicanți care tocmai intră pe teren și poate crea obiceiuri lipicioase de care durează ceva timp pentru a scăpa.

Dacă intri în Computer Vision – fă-ți o favoare și nu folosi straturi de aplatizare înainte de capetele de clasificare în călătoria ta de învățare.

Timestamp-ul:

Mai mult de la Stackabuse