Не використовуйте Flatten() – глобальне об’єднання для CNN із TensorFlow і Keras PlatoBlockchain Data Intelligence. Вертикальний пошук. Ai.

Не використовуйте Flatten() – глобальне об’єднання для CNN із TensorFlow і Keras

Більшість практиків, вперше вивчаючи архітектуру згорткової нейронної мережі (CNN), дізнаються, що вона складається з трьох основних сегментів:

  • Згорткові шари
  • Об'єднання шарів
  • Повністю підключені шари

Більшість ресурсів є деякі варіації цієї сегментації, включаючи мою власну книгу. Особливо онлайн – повністю пов’язані рівні стосуються a сплющуючий шар і (зазвичай) кілька щільні шари.

Раніше це було нормою, і добре відомі архітектури, такі як VGGNets, використовували цей підхід і закінчувалися такими:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Хоча з якоїсь причини часто забувають, що VGGNet була практично останньою архітектурою, яка використовувала цей підхід, через очевидні обчислювальні вузькі місця, які він створює. Щойно ResNets, опублікований всього через рік після VGGNets (і 7 років тому), усі основні архітектури завершили визначення моделей такими словами:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

Площення в CNN триває вже 7 років. 7 роки! І, здається, мало людей говорять про згубний вплив, який він має як на ваш досвід навчання, так і на обчислювальні ресурси, які ви використовуєте.

Глобальне середнє об’єднання є кращим для багатьох облікових записів, ніж зведення. Якщо ви створюєте прототип невеликого CNN – використовуйте глобальне об’єднання. Якщо ви навчаєте когось CNN – використовуйте Global Pooling. Якщо ви робите MVP – використовуйте Global Pooling. Використовуйте зведені шари для інших випадків використання, де вони дійсно потрібні.

Практичний приклад – зведення проти глобального об’єднання

Глобальне об’єднання об’єднує всі карти функцій в одну, об’єднуючи всю релевантну інформацію в одну карту, яку можна легко зрозуміти на одному щільному шарі класифікації замість кількох шарів. Зазвичай він використовується як середнє об’єднання (GlobalAveragePooling2D) або максимальне об’єднання (GlobalMaxPooling2D) і може також працювати для введення 1D і 3D.

Замість зведення карти функцій, наприклад (7, 7, 32) у вектор довжиною 1536 і навчаючи один або декілька шарів розпізнавати візерунки з цього довгого вектора: ми можемо конденсувати його в (7, 7) вектор і класифікуйте безпосередньо звідти. Це так просто!

Зауважте, що рівні вузьких місць для таких мереж, як ResNets, налічують десятки тисяч функцій, а не просто 1536. Під час зведення ви мучите свою мережу, щоб навчатися на векторах дивної форми в дуже неефективний спосіб. Уявіть собі, що 2D-зображення нарізане на кожен піксельний рядок, а потім об’єднано в плоский вектор. Два пікселі, які раніше були на відстані 0 пікселів один від одного по вертикалі, такими не є feature_map_width пікселів по горизонталі! Хоча це може не мати великого значення для алгоритму класифікації, який надає перевагу просторовій інваріантності, це навіть концептуально не буде добре для інших застосувань комп’ютерного зору.

Давайте визначимо невелику демонстраційну мережу, яка використовує шар зведення з парою щільних шарів:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

Як виглядає резюме?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 млн параметрів для іграшкової мережі – і дивіться, як параметри вибухають із збільшенням вхідних даних. Параметри 11.5М. EfficientNets, одна з найефективніших мереж, коли-небудь створених, працює з ~6 млн параметрів, і її неможливо порівняти з цією простою моделлю з точки зору фактичної продуктивності та здатності вивчати дані.

Ми могли б значно зменшити це число, зробивши мережу глибшою, що запровадило б більше максимального об’єднання (і потенційно поступової згортки), щоб зменшити карти функцій до того, як вони будуть зведені. Однак подумайте, що ми зробимо мережу складнішою, щоб зробити її менш дорогою з точки зору обчислень, і все це заради єдиного рівня, який кидає гайковий ключ у плани.

Заглиблення в шари повинно полягати в тому, щоб отримати більш значущі, нелінійні зв’язки між точками даних, не зменшуючи розмір вхідних даних для задоволення шару зведення.

Ось мережа з глобальним об’єднанням:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Резюме?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Набагато краще! Якщо ми заглибимося в цю модель, кількість параметрів збільшиться, і ми, можливо, зможемо фіксувати складніші моделі даних за допомогою нових шарів. Однак якщо це зробити наївно, виникнуть ті самі проблеми, що пов’язані з VGGNets.

Йти далі – наскрізний ручний проект

Ваша допитлива природа змушує вас йти далі? Ми рекомендуємо перевірити наш Керований проект: «Згорткові нейронні мережі – за межами базової архітектури».

Ознайомтеся з нашим практичним практичним посібником із вивчення Git з передовими методами, прийнятими в галузі стандартами та включеною шпаргалкою. Припиніть гуглити команди Git і фактично вчитися це!

Я відправлю вас у невелику подорож у часі – з 1998 по 2022 рік, виділяючи визначальні архітектури, розроблені протягом багатьох років, що робить їх унікальними, які їхні недоліки, а також реалізую помітні з нуля. Немає нічого кращого, ніж мати трохи бруду на своїх руках, коли справа доходить до цього.

Ви можете керувати автомобілем, не знаючи, чи має двигун 4 чи 8 циліндрів і яке розташування клапанів у двигуні. Однак якщо ви хочете розробити і оцінити механізм (модель комп’ютерного бачення), ви захочете піти трохи глибше. Навіть якщо ви не хочете витрачати час на проектування архітектури, а натомість хочете створювати продукти, чого більшість хоче робити, ви знайдете важливу інформацію в цьому уроці. Ви дізнаєтеся, чому використання застарілих архітектур, таких як VGGNet, зашкодить вашому продукту та продуктивності, і чому ви повинні пропустити їх, якщо створюєте щось сучасне, а також ви дізнаєтеся, до яких архітектур можна звернутися для вирішення практичних проблем і які плюси і мінуси є для кожного.

Якщо ви прагнете застосувати комп’ютерний зір у своїй галузі, використовуючи ресурси з цього уроку, ви зможете знайти найновіші моделі, зрозуміти, як вони працюють, за якими критеріями їх можна порівняти та прийняти рішення щодо використовувати.

Ти НЕ потрібно шукати в Google архітектури та їх реалізації – вони зазвичай дуже чітко пояснюються в документах, а такі фреймворки, як Keras, роблять ці реалізації легшими, ніж будь-коли. Основний висновок цього керованого проекту — навчити вас знаходити, читати, впроваджувати та розуміти архітектури та документи. Жоден ресурс у світі не зможе встигнути за всіма новинками. Я включив сюди найновіші статті, але через кілька місяців з’являться нові, і це неминуче. Знання, де знайти надійні реалізації, порівняти їх із документами та налаштувати їх, може дати вам конкурентну перевагу, необхідну для багатьох продуктів комп’ютерного зору, які ви, можливо, захочете створити.

Висновок

У цьому короткому посібнику ми розглянули альтернативу зведенню в дизайні архітектури CNN. Хоч і коротко, посібник розглядає поширену проблему під час розробки прототипів або MVP і радить вам використовувати кращу альтернативу зведенню.

Будь-який досвідчений інженер комп’ютерного бачення знатиме та застосовуватиме цей принцип, і практика сприйматиметься як належне. На жаль, здається, що новачкам-практикам, які тільки виходять на ринок, це не передається належним чином, і вони можуть створити в’язкі звички, яких потрібен час, щоб позбутися.

Якщо ви починаєте вивчати комп’ютерне бачення, зробіть собі послугу і не використовуйте шари зведення перед класифікаційними головками під час навчання.

Часова мітка:

Більше від Stackabuse