Gebruik Flatten() niet - Global Pooling voor CNN's met TensorFlow en Keras PlatoBlockchain Data Intelligence. Verticaal zoeken. Ai.

Gebruik Flatten() niet - Global Pooling voor CNN's met TensorFlow en Keras

De meeste beoefenaars leren, terwijl ze eerst leren over Convolutional Neural Network (CNN) -architecturen, dat het uit drie basissegmenten bestaat:

  • Convolutionele lagen
  • Lagen poolen
  • Volledig verbonden lagen

De meeste bronnen hebben sommige variatie op deze segmentatie, inclusief mijn eigen boek. Vooral online – volledig verbonden lagen verwijzen naar a afvlakkingslaag en (meestal) meerdere dichte lagen.

Dit was vroeger de norm, en bekende architecturen zoals VGGNets gebruikten deze benadering en eindigden in:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Hoewel, om de een of andere reden, wordt vaak vergeten dat VGGNet praktisch de laatste architectuur was die deze aanpak gebruikte, vanwege de duidelijke computationele bottleneck die het creëert. Zodra ResNets, net het jaar na VGGNets (en 7 jaar geleden) werd gepubliceerd, beëindigden alle reguliere architecturen hun modeldefinities met:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

Flattening in CNN's houdt al 7 jaar aan. 7 jaar! En er lijken niet genoeg mensen te praten over het schadelijke effect dat het heeft op zowel je leerervaring als de computerbronnen die je gebruikt.

Global Average Pooling verdient bij veel accounts de voorkeur boven afvlakking. Als je een kleine CNN maakt, gebruik dan Global Pooling. Als je iemand over CNN's leert, gebruik dan Global Pooling. Als je een MVP maakt, gebruik dan Global Pooling. Gebruik afvlakkingslagen voor andere toepassingen waar ze echt nodig zijn.

Casestudy - Afvlakken versus wereldwijd poolen

Global Pooling brengt alle feature maps samen in één enkele, waarbij alle relevante informatie wordt gebundeld in een enkele kaart die gemakkelijk kan worden begrepen door een enkele dichte classificatielaag in plaats van meerdere lagen. Het wordt meestal toegepast als gemiddelde pooling (GlobalAveragePooling2D) of maximale pooling (GlobalMaxPooling2D) en kan ook werken voor 1D- en 3D-invoer.

In plaats van een feature map af te vlakken, zoals: (7, 7, 32) in een vector met lengte 1536 en een of meerdere lagen trainen om patronen van deze lange vector te onderscheiden: we kunnen het condenseren tot een (7, 7) vector en classificeer direct vanaf daar. Het is zo simpel!

Merk op dat knelpunten voor netwerken zoals ResNets in tienduizenden functies tellen, niet slechts 1536. Bij het afvlakken martel je je netwerk om op een zeer inefficiënte manier te leren van vreemd gevormde vectoren. Stel je voor dat een 2D-afbeelding op elke pixelrij wordt gesneden en vervolgens wordt samengevoegd tot een platte vector. De twee pixels die voorheen verticaal 0 pixels uit elkaar lagen, zijn niet feature_map_width pixels horizontaal weg! Hoewel dit misschien niet zo veel uitmaakt voor een classificatie-algoritme, dat ruimtelijke invariantie bevordert, zou dit conceptueel niet eens goed zijn voor andere toepassingen van computervisie.

Laten we een klein demonstratief netwerk definiëren dat een afvlakkingslaag gebruikt met een aantal dichte lagen:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

Hoe ziet de samenvatting eruit?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 miljoen parameters voor een speelgoednetwerk - en zie hoe de parameters exploderen met grotere input. 11.5M-parameters:. EfficientNets, een van de best presterende netwerken ooit ontworpen, werkt met ~6M parameters, en kan niet worden vergeleken met dit eenvoudige model in termen van werkelijke prestaties en capaciteit om van gegevens te leren.

We zouden dit aantal aanzienlijk kunnen verminderen door het netwerk dieper te maken, wat meer maximale pooling (en mogelijk getrapte convolutie) zou introduceren om de functiekaarten te verminderen voordat ze worden afgeplat. Bedenk echter dat we het netwerk complexer zouden maken om het minder rekenkundig duur te maken, allemaal omwille van een enkele laag die de plannen in de war gooit.

Dieper gaan met lagen zou moeten zijn om betekenisvollere, niet-lineaire relaties tussen gegevenspunten te extraheren, en niet om de invoergrootte te verkleinen om tegemoet te komen aan een afvlakkende laag.

Hier is een netwerk met wereldwijde pooling:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Samenvatting?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Veel beter! Als we dieper gaan met dit model, zal het aantal parameters toenemen en kunnen we mogelijk meer ingewikkelde gegevenspatronen vastleggen met de nieuwe lagen. Als het echter naïef wordt gedaan, zullen dezelfde problemen ontstaan ​​​​die VGGNets bonden.

Verder gaan – Handmatig end-to-end project

Je leergierige karakter maakt dat je verder wilt gaan? We raden aan om onze Begeleid project: "Convolutionele neurale netwerken - verder dan basisarchitecturen".

Bekijk onze praktische, praktische gids voor het leren van Git, met best-practices, door de industrie geaccepteerde normen en bijgevoegd spiekbriefje. Stop met Googlen op Git-commando's en eigenlijk leren het!

Ik neem je mee op een tijdreis - van 1998 tot 2022, waarbij ik de bepalende architecturen belicht die door de jaren heen zijn ontwikkeld, wat ze uniek maakte, wat hun nadelen zijn, en de opmerkelijke vanaf het begin implementeren. Er is niets beter dan wat vuil op je handen te hebben als het om deze gaat.

U kunt autorijden zonder te weten of de motor 4 of 8 cilinders heeft en wat de plaatsing van de kleppen in de motor is. Maar als je een engine (computer vision-model) wilt ontwerpen en waarderen, moet je wat dieper gaan. Zelfs als u geen tijd wilt besteden aan het ontwerpen van architecturen en in plaats daarvan producten wilt bouwen, wat de meeste mensen willen doen, vindt u belangrijke informatie in deze les. Je leert waarom het gebruik van verouderde architecturen zoals VGGNet je product en prestaties schaadt, en waarom je ze moet overslaan als je iets moderns bouwt, en je leert bij welke architecturen je terecht kunt voor het oplossen van praktische problemen en wat de voor- en nadelen zijn voor elk.

Als u computervisie op uw vakgebied wilt toepassen, kunt u met behulp van de bronnen uit deze les de nieuwste modellen vinden, begrijpen hoe ze werken en op basis van welke criteria u ze kunt vergelijken en een beslissing kunt nemen over welke gebruiken.

You niet moeten Googlen voor architecturen en hun implementaties - ze worden meestal heel duidelijk uitgelegd in de papieren, en frameworks zoals Keras maken deze implementaties gemakkelijker dan ooit. Het belangrijkste van dit begeleide project is om u te leren hoe u architecturen en papers kunt vinden, lezen, implementeren en begrijpen. Geen enkele bron ter wereld zal in staat zijn om alle nieuwste ontwikkelingen bij te houden. Ik heb hier de nieuwste kranten bijgevoegd – maar over een paar maanden zullen er nieuwe verschijnen, en dat is onvermijdelijk. Als u weet waar u geloofwaardige implementaties kunt vinden, ze kunt vergelijken met papieren en ze kunt aanpassen, kunt u de concurrentievoorsprong krijgen die nodig is voor veel computervisieproducten die u misschien wilt bouwen.

Conclusie

In deze korte handleiding hebben we een alternatief voor flattening in CNN-architectuurontwerp bekeken. Hoewel kort, behandelt de gids een veelvoorkomend probleem bij het ontwerpen van prototypes of MVP's, en adviseert u een beter alternatief voor afvlakken te gebruiken.

Elke doorgewinterde Computer Vision Engineer zal dit principe kennen en toepassen, en de praktijk wordt als vanzelfsprekend beschouwd. Helaas lijkt het niet goed te worden doorgegeven aan nieuwe beoefenaars die net het veld betreden en kleverige gewoonten kunnen creëren waarvan het een tijdje duurt om er vanaf te komen.

Als je met Computer Vision begint, doe jezelf dan een plezier en gebruik geen afvlakkende lagen voor classificatiehoofden in je leertraject.

Tijdstempel:

Meer van Stapelmisbruik