Ikke bruk Flatten() - Global Pooling for CNN-er med TensorFlow og Keras PlatoBlockchain Data Intelligence. Vertikalt søk. Ai.

Ikke bruk Flatten() – Global Pooling for CNN-er med TensorFlow og Keras

De fleste utøvere, mens de først lærer om Convolutional Neural Network (CNN) arkitekturer – lærer at det består av tre grunnleggende segmenter:

  • Konvolusjonslag
  • Samle lag
  • Fullt tilkoblede lag

De fleste ressurser har noen variant av denne segmenteringen, inkludert min egen bok. Spesielt online – fullt tilkoblede lag refererer til en flate lag og (vanligvis) flere tette lag.

Dette pleide å være normen, og kjente arkitekturer som VGGNets brukte denne tilnærmingen, og ville ende i:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Skjønt, av en eller annen grunn – blir det ofte glemt at VGGNet praktisk talt var den siste arkitekturen som brukte denne tilnærmingen, på grunn av den åpenbare beregningsflaskehalsen den skaper. Så snart ResNets, publisert bare året etter VGGNets (og for 7 år siden), avsluttet alle mainstream-arkitekturer sine modelldefinisjoner med:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

Utflating i CNN har holdt på i 7 år. 7 år! Og ikke nok folk ser ut til å snakke om den skadelige effekten det har på både læringsopplevelsen og beregningsressursene du bruker.

Global Average Pooling er å foretrekke på mange kontoer fremfor utflating. Hvis du lager en prototy til et lite CNN – bruk Global Pooling. Hvis du lærer noen om CNN – bruk Global Pooling. Hvis du lager en MVP – bruk Global Pooling. Bruk utjevnende lag for andre brukstilfeller der de faktisk er nødvendig.

Kasusstudie – Flattening vs Global Pooling

Global Pooling kondenserer alle funksjonskartene til ett enkelt, og samler all relevant informasjon til et enkelt kart som lett kan forstås av et enkelt tett klassifiseringslag i stedet for flere lag. Det brukes vanligvis som gjennomsnittlig sammenslåing (GlobalAveragePooling2D) eller maks pooling (GlobalMaxPooling2D) og kan også fungere for 1D- og 3D-inngang.

I stedet for å flate ut et funksjonskart som f.eks (7, 7, 32) inn i en vektor med lengde 1536 og trener ett eller flere lag for å skjelne mønstre fra denne lange vektoren: vi kan kondensere den til en (7, 7) vektor og klassifiser direkte derfra. Så enkelt er det!

Legg merke til at flaskehals-lag for nettverk som ResNets teller i titusenvis av funksjoner, ikke bare 1536. Når du flater ut, torturerer du nettverket ditt for å lære av merkelig formede vektorer på en svært ineffektiv måte. Se for deg at et 2D-bilde blir delt opp i hver pikselrad og deretter sammenkoblet til en flat vektor. De to pikslene som tidligere var 0 piksler fra hverandre vertikalt er ikke det feature_map_width piksler vekk horisontalt! Selv om dette kanskje ikke betyr så mye for en klassifiseringsalgoritme, som favoriserer romlig invarians, ville dette ikke engang vært konseptuelt bra for andre applikasjoner av datasyn.

La oss definere et lite demonstrativt nettverk som bruker et utflatingslag med et par tette lag:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

Hvordan ser sammendraget ut?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 millioner parametere for et leketøysnettverk – og se parameterne eksplodere med større input. 11.5 millioner parametere. EfficientNets, et av de beste nettverkene som noen gang er designet, fungerer ved ~6M parametere, og kan ikke sammenlignes med denne enkle modellen når det gjelder faktisk ytelse og kapasitet til å lære av data.

Vi kan redusere dette tallet betraktelig ved å gjøre nettverket dypere, noe som vil introdusere mer maksimal pooling (og potensielt skrittvis konvolusjon) for å redusere funksjonskartene før de blir flatet ut. Tenk imidlertid på at vi ville gjort nettverket mer komplekst for å gjøre det mindre beregningsmessig kostbart, alt av hensyn til et enkelt lag som kaster en skiftenøkkel i planene.

Å gå dypere med lag bør være å trekke ut mer meningsfulle, ikke-lineære relasjoner mellom datapunkter, og ikke redusere inngangsstørrelsen for å imøtekomme et utflatende lag.

Her er et nettverk med global pooling:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Sammendrag?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Mye bedre! Hvis vi går dypere med denne modellen, vil parameterantallet øke, og vi kan kanskje fange opp mer intrikate mønstre av data med de nye lagene. Hvis det gjøres naivt, vil de samme problemene som binder VGGNets oppstå.

Gå videre – Håndholdt ende-til-ende-prosjekt

Din nysgjerrige natur gjør at du ønsker å gå lenger? Vi anbefaler å sjekke ut vår Veiledet prosjekt: "Konvolusjonelle nevrale nettverk - Beyond Basic Architectures".

Sjekk ut vår praktiske, praktiske guide for å lære Git, med beste praksis, bransjeaksepterte standarder og inkludert jukseark. Slutt å google Git-kommandoer og faktisk lære den!

Jeg tar deg med på en liten tidsreise – fra 1998 til 2022, og fremhever de definerende arkitekturene som er utviklet gjennom årene, hva som gjorde dem unike, hva deres ulemper er, og implementer de bemerkelsesverdige fra bunnen av. Det er ikke noe bedre enn å ha litt smuss på hendene når det kommer til disse.

Du kan kjøre bil uten å vite om motoren har 4 eller 8 sylindre og hva plasseringen av ventilene i motoren er. Men - hvis du vil designe og sette pris på en motor (datasynsmodell), vil du gå litt dypere. Selv om du ikke vil bruke tid på å designe arkitekturer og ønsker å bygge produkter i stedet, som er det de fleste ønsker å gjøre – vil du finne viktig informasjon i denne leksjonen. Du vil få vite hvorfor bruk av utdaterte arkitekturer som VGGNet vil skade produktet og ytelsen, og hvorfor du bør hoppe over dem hvis du bygger noe moderne, og du vil lære hvilke arkitekturer du kan gå til for å løse praktiske problemer og hva fordeler og ulemper er for hver.

Hvis du ønsker å bruke datasyn på feltet ditt, ved å bruke ressursene fra denne leksjonen – vil du kunne finne de nyeste modellene, forstå hvordan de fungerer og etter hvilke kriterier du kan sammenligne dem og ta en beslutning om hvilke bruk.

Du ikke må Google for arkitekturer og deres implementeringer – de er vanligvis veldig tydelig forklart i avisene, og rammeverk som Keras gjør disse implementeringene enklere enn noen gang. Nøkkelen til dette guidede prosjektet er å lære deg hvordan du finner, leser, implementerer og forstår arkitekturer og artikler. Ingen ressurs i verden vil kunne holde tritt med alle de nyeste utviklingene. Jeg har tatt med de nyeste avisene her – men om noen måneder dukker det opp nye, og det er uunngåelig. Å vite hvor du kan finne troverdige implementeringer, sammenligne dem med papirer og justere dem kan gi deg konkurransefortrinnet som kreves for mange datasynsprodukter du kanskje vil bygge.

konklusjonen

I denne korte guiden har vi tatt en titt på et alternativ til utflating i CNN-arkitekturdesign. Om enn kort – veiledningen tar for seg et vanlig problem når du designer prototyper eller MVP-er, og råder deg til å bruke et bedre alternativ til utflating.

Enhver erfaren datasynsingeniør vil kjenne til og anvende dette prinsippet, og praksisen tas for gitt. Dessverre ser det ikke ut til at det blir riktig videreformidlet til nye utøvere som nettopp er på vei inn i feltet, og kan skape klissete vaner som det tar litt tid å bli kvitt.

Hvis du begynner på Computer Vision – gjør deg selv en tjeneste og ikke bruk utflatende lag før klassifiseringshoder i læringsreisen din.

Tidstempel:

Mer fra Stackabuse