Nie używaj Flatten() — globalne łączenie dla CNN z TensorFlow i Keras PlatoBlockchain Data Intelligence. Wyszukiwanie pionowe. AI.

Nie używaj Flatten() – globalne łączenie dla CNN z TensorFlow i Keras

Większość praktyków, po raz pierwszy zapoznając się z architekturami Convolutional Neural Network (CNN) – dowiaduje się, że składa się ona z trzech podstawowych segmentów:

  • Warstwy splotowe
  • Łączenie warstw
  • W pełni połączone warstwy

Większość zasobów ma kilka wariacja na temat tej segmentacji, w tym moja własna książka. Zwłaszcza online – w pełni połączone warstwy odnoszą się do a warstwa spłaszczająca i (zwykle) wielokrotne gęste warstwy.

Kiedyś było to normą, a znane architektury, takie jak VGGNets, stosowały to podejście i kończyły się na:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Chociaż z jakiegoś powodu – często zapomina się, że VGGNet była praktycznie ostatnią architekturą, w której zastosowano to podejście, ze względu na oczywiste wąskie gardło obliczeniowe, jakie tworzy. Jak tylko ResNets, opublikowany zaledwie rok po VGGNets (i 7 lat temu), wszystkie architektury głównego nurtu zakończyły swoje definicje modeli:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

Spłaszczanie w CNN utrzymuje się od 7 lat. 7 roku! I wydaje się, że zbyt mało ludzi mówi o szkodliwym wpływie, jaki ma to zarówno na twoje doświadczenie w nauce, jak i na zasoby obliczeniowe, z których korzystasz.

Globalne łączenie średnich jest preferowane na wielu kontach niż spłaszczanie. Jeśli tworzysz prototypy małej CNN – skorzystaj z Global Pooling. Jeśli uczysz kogoś o CNN – skorzystaj z Global Pooling. Jeśli robisz MVP – użyj globalnego łączenia. Używaj warstw spłaszczających w innych przypadkach użycia, w których są one rzeczywiście potrzebne.

Studium przypadku – spłaszczanie kontra globalne łączenie

Globalne łączenie łączy wszystkie mapy obiektów w jedną, łącząc wszystkie istotne informacje w jedną mapę, którą można łatwo zrozumieć za pomocą pojedynczej, gęstej warstwy klasyfikacji zamiast wielu warstw. Zwykle jest stosowany jako średnie łączenie (GlobalAveragePooling2D) lub maksymalna pula (GlobalMaxPooling2D) i może pracować również dla danych wejściowych 1D i 3D.

Zamiast spłaszczać mapę funkcji, taką jak (7, 7, 32) w wektor o długości 1536 i trenując jedną lub wiele warstw, aby odróżnić wzorce z tego długiego wektora: możemy skondensować go w (7, 7) wektor i klasyfikować bezpośrednio z tego miejsca. To takie proste!

Zwróć uwagę, że warstwy wąskich gardeł dla sieci takich jak ResNets liczą się w dziesiątkach tysięcy funkcji, a nie zaledwie 1536. Podczas spłaszczania torturujesz swoją sieć, aby uczyć się z wektorów o dziwnych kształtach w bardzo nieefektywny sposób. Wyobraź sobie obraz 2D pocięty na każdym rzędzie pikseli, a następnie połączony w płaski wektor. Dwa piksele, które były oddalone od siebie o 0 pikseli w pionie, nie są feature_map_width piksele w poziomie! Chociaż może to nie mieć większego znaczenia dla algorytmu klasyfikacji, który faworyzuje niezmienność przestrzenną – nie byłoby to nawet koncepcyjnie dobre dla innych zastosowań wizji komputerowej.

Zdefiniujmy małą sieć demonstracyjną, która używa warstwy spłaszczania z kilkoma gęstymi warstwami:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

Jak wygląda podsumowanie?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 mln parametrów dla sieci zabawek – i obserwuj, jak parametry eksplodują przy większym nakładzie. 11.5 mln parametrów. EfficientNets, jedna z najbardziej wydajnych sieci, jakie kiedykolwiek zaprojektowano, działa z parametrami ~6 mln i nie można jej porównać z tym prostym modelem pod względem rzeczywistej wydajności i zdolności uczenia się na podstawie danych.

Moglibyśmy znacznie zmniejszyć tę liczbę, pogłębiając sieć, co wprowadziłoby więcej maksymalnego łączenia (i potencjalnie stopniowego splotu), aby zmniejszyć mapy funkcji, zanim zostaną spłaszczone. Należy jednak wziąć pod uwagę, że uczynilibyśmy sieć bardziej złożoną, aby uczynić ją mniej kosztowną obliczeniowo, a wszystko to ze względu na pojedynczą warstwę, która zakłóca plany.

Wchodzenie głębiej w warstwy powinno polegać na wyodrębnieniu bardziej znaczących, nieliniowych relacji między punktami danych, a nie zmniejszaniu rozmiaru wejściowego w celu zaspokojenia warstwy spłaszczonej.

Oto sieć z globalnym łączeniem:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Streszczenie?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Dużo lepiej! Jeśli wejdziemy głębiej w ten model, liczba parametrów wzrośnie i możemy być w stanie uchwycić bardziej skomplikowane wzorce danych za pomocą nowych warstw. Jeśli jednak zostanie to zrobione naiwnie, pojawią się te same problemy, które wiążą VGGNets.

Idąc dalej – ręczny projekt end-to-end

Twoja dociekliwość sprawia, że ​​chcesz iść dalej? Zalecamy sprawdzenie naszego Projekt z przewodnikiem: „Splotowe sieci neuronowe – poza podstawowymi architekturami”.

Zapoznaj się z naszym praktycznym, praktycznym przewodnikiem dotyczącym nauki Git, zawierającym najlepsze praktyki, standardy przyjęte w branży i dołączoną ściągawkę. Zatrzymaj polecenia Google Git, a właściwie uczyć się to!

Zabiorę Cię w małą podróż w czasie – od 1998 do 2022 roku, podkreślając definiujące architektury opracowane przez lata, co czyni je wyjątkowymi, jakie są ich wady i wdrażając te godne uwagi od podstaw. Nie ma nic lepszego niż brud na rękach, jeśli chodzi o te.

Możesz prowadzić samochód, nie wiedząc, czy silnik ma 4 czy 8 cylindrów i jakie jest rozmieszczenie zaworów w silniku. Jednak – jeśli chcesz zaprojektować i docenić silnik (model wizyjny komputerowy), będziesz chciał wejść nieco głębiej. Nawet jeśli nie chcesz spędzać czasu na projektowaniu architektur, a zamiast tego chcesz budować produkty, co jest tym, co najbardziej chcesz zrobić – w tej lekcji znajdziesz ważne informacje. Dowiesz się, dlaczego używanie przestarzałych architektur, takich jak VGGNet, zaszkodzi Twojemu produktowi i wydajności oraz dlaczego powinieneś je pominąć, jeśli tworzysz coś nowoczesnego, a także dowiesz się, do których architektur możesz się udać, aby rozwiązać praktyczne problemy i jakie plusy i minusy są dla każdego.

Jeśli chcesz zastosować wizję komputerową w swojej dziedzinie, korzystając z zasobów z tej lekcji – będziesz w stanie znaleźć najnowsze modele, zrozumieć, jak działają i według jakich kryteriów możesz je porównać i podjąć decyzję, które posługiwać się.

You nie trzeba Google dla architektur i ich implementacji – zazwyczaj są one bardzo jasno wyjaśnione w dokumentach, a frameworki takie jak Keras sprawiają, że te implementacje są łatwiejsze niż kiedykolwiek. Kluczowym wnioskiem tego projektu z przewodnikiem jest nauczenie Cię, jak znajdować, czytać, wdrażać i rozumieć architektury i dokumenty. Żaden zasób na świecie nie będzie w stanie nadążyć za wszystkimi najnowszymi osiągnięciami. Zamieściłem tutaj najnowsze artykuły – ale za kilka miesięcy pojawią się nowe, a to nieuniknione. Wiedza o tym, gdzie znaleźć wiarygodne implementacje, porównać je z dokumentami i dostosować je, może zapewnić przewagę konkurencyjną wymaganą w przypadku wielu produktów wizyjnych, które możesz chcieć zbudować.

Wnioski

W tym krótkim przewodniku przyjrzeliśmy się alternatywie dla spłaszczania w projektowaniu architektury CNN. Choć krótki – przewodnik porusza typowy problem podczas projektowania prototypów lub MVP i radzi, aby użyć lepszej alternatywy dla spłaszczania.

Każdy doświadczony inżynier wizji komputerowej będzie znał i stosował tę zasadę, a praktyka jest uważana za pewnik. Niestety, wydaje się, że nie jest to właściwie przekazywane nowym praktykującym, którzy dopiero wchodzą w pole i mogą tworzyć lepkie nawyki, których pozbycie się zajmuje trochę czasu.

Jeśli zaczynasz przygodę z komputerowym widzeniem – zrób sobie przysługę i nie używaj warstw spłaszczających przed klasyfikacjami w swojej podróży edukacyjnej.

Znak czasu:

Więcej z Nadużycie stosu