Använd inte Flatten() - Global Pooling för CNN:er med TensorFlow och Keras PlatoBlockchain Data Intelligence. Vertikal sökning. Ai.

Använd inte Flatten() – Global Pooling för CNN med TensorFlow och Keras

De flesta utövare lär sig, samtidigt som de först lär sig om arkitekturer för Convolutional Neural Network (CNN) – att det består av tre grundläggande segment:

  • Konvolutionella lager
  • Samla lager
  • Fullt anslutna lager

De flesta resurser har några variation på denna segmentering, inklusive min egen bok. Speciellt online – helt anslutna lager hänvisar till en tillplattande lager och (vanligtvis) flera täta lager.

Detta brukade vara normen, och välkända arkitekturer som VGGNets använde detta tillvägagångssätt och skulle sluta i:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Även om det av någon anledning ofta glöms bort att VGGNet praktiskt taget var den sista arkitekturen som använde detta tillvägagångssätt, på grund av den uppenbara beräkningsflaskhalsen som den skapar. Så snart ResNets, publicerad bara året efter VGGNets (och för 7 år sedan), avslutade alla vanliga arkitekturer sina modelldefinitioner med:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

Tillplattning i CNN har funnits i 7 år. 7 år! Och inte tillräckligt många människor verkar prata om den skadliga effekten det har på både din inlärningsupplevelse och de beräkningsresurser du använder.

Global Average Pooling är att föredra på många konton framför utjämning. Om du skapar en prototyp till ett litet CNN – använd Global Pooling. Om du lär någon om CNN – använd Global Pooling. Om du gör en MVP – använd Global Pooling. Använd utplattande lager för andra användningsfall där de faktiskt behövs.

Fallstudie – Utjämning vs global pooling

Global Pooling kondenserar alla funktionskartor till en enda, och samlar all relevant information till en enda karta som lätt kan förstås av ett enda tätt klassificeringsskikt istället för flera lager. Det används vanligtvis som genomsnittlig pooling (GlobalAveragePooling2D) eller max pooling (GlobalMaxPooling2D) och kan även fungera för 1D- och 3D-ingång.

Istället för att platta till en funktionskarta som t.ex (7, 7, 32) till en vektor med längden 1536 och träna ett eller flera lager för att urskilja mönster från denna långa vektor: vi kan kondensera den till en (7, 7) vektor och klassificera direkt därifrån. Det är så enkelt!

Observera att flaskhalsskikt för nätverk som ResNets räknas i tiotusentals funktioner, inte bara 1536. När du plattar ut torterar du ditt nätverk för att lära dig från konstigt formade vektorer på ett mycket ineffektivt sätt. Föreställ dig en 2D-bild som skivas på varje pixelrad och sedan sammanfogas till en platt vektor. De två pixlarna som brukade vara 0 pixlar från varandra vertikalt är det inte feature_map_width pixlar bort horisontellt! Även om detta kanske inte spelar så stor roll för en klassificeringsalgoritm, som gynnar rumslig invarians, skulle detta inte ens vara begreppsmässigt bra för andra tillämpningar av datorseende.

Låt oss definiera ett litet demonstrativt nätverk som använder ett utplattande lager med ett par täta lager:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

Hur ser sammanfattningen ut?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 miljoner parametrar för ett leksaksnätverk – och se parametrarna explodera med större input. 11.5 miljoner parametrar. EfficientNets, ett av de bäst presterande nätverken som någonsin designats fungerar vid ~6M parametrar, och kan inte jämföras med denna enkla modell när det gäller faktisk prestanda och förmåga att lära av data.

Vi skulle kunna minska detta antal avsevärt genom att göra nätverket djupare, vilket skulle introducera mer maxpooling (och potentiellt stegrad faltning) för att minska funktionskartorna innan de plattas ut. Men tänk på att vi skulle göra nätverket mer komplext för att göra det billigare beräkningsmässigt dyrt, allt för ett enda lager som kastar en skiftnyckel i planerna.

Att gå djupare med lager bör vara att extrahera mer meningsfulla, icke-linjära relationer mellan datapunkter, inte att minska indatastorleken för att tillgodose ett tillplattande lager.

Här är ett nätverk med global pooling:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Sammanfattning?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Mycket bättre! Om vi ​​går djupare med den här modellen kommer parameterantalet att öka, och vi kanske kan fånga mer intrikata mönster av data med de nya lagren. Men om det görs naivt kommer samma problem som binder VGGNets att uppstå.

Gå vidare – Handhållet End-to-End-projekt

Din nyfikna natur gör att du vill gå längre? Vi rekommenderar att du kollar in vår Guidade projekt: "Konvolutionella neurala nätverk – bortom grundläggande arkitekturer".

Kolla in vår praktiska, praktiska guide för att lära dig Git, med bästa praxis, branschaccepterade standarder och medföljande fuskblad. Sluta googla Git-kommandon och faktiskt lära Det!

Jag tar dig med på lite tidsresor – från 1998 till 2022, och lyfter fram de definierande arkitekturerna som utvecklats under åren, vad som gjorde dem unika, vilka nackdelar de har och implementera de anmärkningsvärda från grunden. Det finns inget bättre än att ha lite smuts på händerna när det kommer till dessa.

Du kan köra bil utan att veta om motorn har 4 eller 8 cylindrar och hur ventilerna är placerade i motorn. Men – om du vill designa och uppskatta en motor (modell för datorseende) vill du gå lite djupare. Även om du inte vill lägga tid på att designa arkitekturer och istället vill bygga produkter, vilket är vad de flesta vill göra – så hittar du viktig information i den här lektionen. Du kommer att få lära dig varför användning av föråldrade arkitekturer som VGGNet kommer att skada din produkt och prestanda, och varför du bör hoppa över dem om du bygger något modernt, och du kommer att lära dig vilka arkitekturer du kan gå till för att lösa praktiska problem och vad för- och nackdelarna är för var och en.

Om du funderar på att tillämpa datorseende på ditt område, med hjälp av resurserna från den här lektionen – kommer du att kunna hitta de senaste modellerna, förstå hur de fungerar och utifrån vilka kriterier du kan jämföra dem och fatta ett beslut om vilka använda sig av.

Om er inte måste Google för arkitekturer och deras implementeringar – de är vanligtvis mycket tydligt förklarade i tidningarna, och ramverk som Keras gör dessa implementeringar enklare än någonsin. Det viktigaste med det här guidade projektet är att lära dig hur du hittar, läser, implementerar och förstår arkitekturer och uppsatser. Ingen resurs i världen kommer att kunna hänga med i alla de senaste utvecklingarna. Jag har tagit med de senaste tidningarna här – men om några månader kommer nya att dyka upp, och det är oundvikligt. Att veta var man kan hitta trovärdiga implementeringar, jämföra dem med tidningar och justera dem kan ge dig den konkurrensfördel som krävs för många datorseendeprodukter du kanske vill bygga.

Slutsats

I den här korta guiden har vi tagit en titt på ett alternativ till tillplattning i CNN-arkitekturdesign. Om än kort – guiden tar upp ett vanligt problem när man designar prototyper eller MVP:er, och råder dig att använda ett bättre alternativ till plattning.

Alla erfarna datavisionsingenjörer kommer att känna till och tillämpa denna princip, och praktiken tas för given. Tyvärr verkar det inte vara ordentligt vidarebefordrat till nya utövare som precis kommer in på fältet, och kan skapa klibbiga vanor som tar ett tag att bli av med.

Om du börjar med Computer Vision – gör dig själv en tjänst och använd inte utplattande lager före klassificeringshuvuden i din inlärningsresa.

Tidsstämpel:

Mer från Stackabuse