N'utilisez pas Flatten() - Global Pooling pour les CNN avec TensorFlow et Keras PlatoBlockchain Data Intelligence. Recherche verticale. Aï.

N'utilisez pas Flatten() - Regroupement global pour les CNN avec TensorFlow et Keras

La plupart des praticiens, tout en découvrant pour la première fois les architectures de réseau de neurones convolutifs (CNN), apprennent qu'il est composé de trois segments de base :

  • Couches convolutionnelles
  • Mise en commun des couches
  • Couches entièrement connectées

La plupart des ressources ont quelques variation sur cette segmentation, y compris mon propre livre. Surtout en ligne - les couches entièrement connectées font référence à un couche d'aplatissement et (généralement) plusieurs couches denses.

C'était la norme et des architectures bien connues telles que VGGNets utilisaient cette approche et se terminaient par :

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Cependant, pour une raison quelconque, on oublie souvent que VGGNet était pratiquement la dernière architecture à utiliser cette approche, en raison du goulot d'étranglement informatique évident qu'elle crée. Dès que ResNets, publié juste un an après VGGNets (et il y a 7 ans), toutes les architectures grand public ont terminé leurs définitions de modèles avec :

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

L'aplatissement des CNN dure depuis 7 ans. 7 ans! Et pas assez de gens semblent parler de l'effet néfaste que cela a sur votre expérience d'apprentissage et sur les ressources informatiques que vous utilisez.

Global Average Pooling est préférable sur de nombreux comptes à l'aplatissement. Si vous prototypez un petit CNN, utilisez Global Pooling. Si vous enseignez à quelqu'un sur les CNN, utilisez Global Pooling. Si vous créez un MVP, utilisez Global Pooling. Utilisez des calques d'aplatissement pour d'autres cas d'utilisation où ils sont réellement nécessaires.

Étude de cas - Flattening vs Global Pooling

La mise en commun globale condense toutes les cartes d'entités en une seule, regroupant toutes les informations pertinentes dans une seule carte qui peut être facilement comprise par une seule couche de classification dense au lieu de plusieurs couches. Il est généralement appliqué en tant que mise en commun moyenne (GlobalAveragePooling2D) ou regroupement maximum (GlobalMaxPooling2D) et peut également fonctionner pour les entrées 1D et 3D.

Au lieu d'aplatir une carte d'entités telle que (7, 7, 32) en un vecteur de longueur 1536 et former une ou plusieurs couches pour discerner les motifs de ce long vecteur : nous pouvons le condenser en un (7, 7) vecteur et classer directement à partir de là. C'est si simple!

Notez que les couches de goulot d'étranglement pour des réseaux comme ResNets comptent des dizaines de milliers de fonctionnalités, pas seulement 1536. Lors de l'aplatissement, vous torturez votre réseau pour apprendre de vecteurs de forme étrange d'une manière très inefficace. Imaginez une image 2D découpée en tranches sur chaque rangée de pixels, puis concaténée en un vecteur plat. Les deux pixels qui étaient séparés verticalement de 0 pixels ne sont plus feature_map_width pixels horizontalement ! Bien que cela n'ait pas trop d'importance pour un algorithme de classification, qui favorise l'invariance spatiale, cela ne serait même pas conceptuellement bon pour d'autres applications de la vision par ordinateur.

Définissons un petit réseau de démonstration qui utilise une couche d'aplatissement avec quelques couches denses :

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

À quoi ressemble le résumé ?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 millions de paramètres pour un réseau de jouets - et regardez les paramètres exploser avec une entrée plus importante. 11.5 millions de paramètres. EfficientNets, l'un des réseaux les plus performants jamais conçus, fonctionne à environ 6 millions de paramètres et ne peut être comparé à ce modèle simple en termes de performances réelles et de capacité à apprendre des données.

Nous pourrions réduire ce nombre de manière significative en rendant le réseau plus profond, ce qui introduirait plus de pooling maximum (et potentiellement de convolution striée) pour réduire les cartes de caractéristiques avant qu'elles ne soient aplaties. Cependant, considérez que nous rendrions le réseau plus complexe afin de le rendre moins coûteux en calcul, tout cela pour une seule couche qui jette une clé dans les plans.

Aller plus loin avec les couches devrait consister à extraire des relations non linéaires plus significatives entre les points de données, et non à réduire la taille d'entrée pour répondre à une couche d'aplatissement.

Voici un réseau avec mutualisation mondiale :

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Résumé?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Bien mieux ! Si nous approfondissons ce modèle, le nombre de paramètres augmentera et nous pourrons peut-être capturer des modèles de données plus complexes avec les nouvelles couches. Si cela est fait naïvement, les mêmes problèmes qui ont lié VGGNets se poseront.

Aller plus loin - Projet de bout en bout tenu dans la main

Votre nature curieuse vous donne envie d'aller plus loin ? Nous vous recommandons de consulter notre Projet guidé: "Réseaux de neurones convolutifs - Au-delà des architectures de base".

Consultez notre guide pratique et pratique pour apprendre Git, avec les meilleures pratiques, les normes acceptées par l'industrie et la feuille de triche incluse. Arrêtez de googler les commandes Git et en fait apprendre il!

Je vais vous emmener dans un voyage dans le temps - de 1998 à 2022, en mettant en évidence les architectures déterminantes développées au fil des ans, ce qui les rend uniques, quels sont leurs inconvénients et en mettant en œuvre les plus notables à partir de zéro. Il n'y a rien de mieux que d'avoir de la saleté sur les mains quand il s'agit de ces derniers.

Vous pouvez conduire une voiture sans savoir si le moteur a 4 ou 8 cylindres et quel est l'emplacement des soupapes dans le moteur. Cependant, si vous souhaitez concevoir et apprécier un moteur (modèle de vision par ordinateur), vous voudrez aller un peu plus loin. Même si vous ne souhaitez pas passer du temps à concevoir des architectures et que vous souhaitez plutôt créer des produits, ce que la plupart souhaitent faire, vous trouverez des informations importantes dans cette leçon. Vous apprendrez pourquoi l'utilisation d'architectures obsolètes comme VGGNet nuira à votre produit et à vos performances, et pourquoi vous devriez les ignorer si vous construisez quelque chose de moderne, et vous apprendrez à quelles architectures vous pouvez vous adresser pour résoudre des problèmes pratiques et quoi les avantages et les inconvénients sont pour chacun.

Si vous cherchez à appliquer la vision par ordinateur à votre domaine, en utilisant les ressources de cette leçon - vous serez en mesure de trouver les modèles les plus récents, de comprendre comment ils fonctionnent et selon quels critères vous pouvez les comparer et prendre une décision sur laquelle utilisation.

Vous ne voulez pas doivent Google pour les architectures et leurs implémentations - elles sont généralement très clairement expliquées dans les articles, et des frameworks comme Keras rendent ces implémentations plus faciles que jamais. L'essentiel à retenir de ce projet guidé est de vous apprendre à trouver, lire, mettre en œuvre et comprendre les architectures et les documents. Aucune ressource au monde ne pourra suivre tous les développements les plus récents. J'ai inclus les articles les plus récents ici - mais dans quelques mois, de nouveaux apparaîtront, et c'est inévitable. Savoir où trouver des implémentations crédibles, les comparer aux documents et les ajuster peut vous donner l'avantage concurrentiel requis pour de nombreux produits de vision par ordinateur que vous souhaitez créer.

Conclusion

Dans ce petit guide, nous avons examiné une alternative à l'aplatissement dans la conception de l'architecture CNN. Bien que court, le guide aborde un problème courant lors de la conception de prototypes ou de MVP, et vous conseille d'utiliser une meilleure alternative à l'aplatissement.

Tout ingénieur chevronné en vision par ordinateur connaîtra et appliquera ce principe, et la pratique est tenue pour acquise. Malheureusement, cela ne semble pas être correctement transmis aux nouveaux praticiens qui viennent d'entrer sur le terrain et peut créer des habitudes collantes dont il faut du temps pour se débarrasser.

Si vous vous lancez dans la vision par ordinateur, rendez-vous service et n'utilisez pas de calques d'aplatissement avant la classification en tête de votre parcours d'apprentissage.

Horodatage:

Plus de Stackabuse