No utilice Flatten(): agrupación global para CNN con TensorFlow y Keras PlatoBlockchain Data Intelligence. Búsqueda vertical. Ai.

No use Flatten (): agrupación global para CNN con TensorFlow y Keras

La mayoría de los profesionales, mientras aprenden por primera vez sobre las arquitecturas de redes neuronales convolucionales (CNN), aprenden que se compone de tres segmentos básicos:

  • Capas convolucionales
  • Capas de agrupación
  • Capas totalmente conectadas

La mayoría de los recursos tienen algo variación de esta segmentación, incluido mi propio libro. Especialmente en línea: las capas totalmente conectadas se refieren a un capa de aplanamiento y (generalmente) múltiples capas densas.

Esta solía ser la norma, y ​​arquitecturas conocidas como VGGNets usaban este enfoque y terminaban en:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Sin embargo, por alguna razón, a menudo se olvida que VGGNet fue prácticamente la última arquitectura en utilizar este enfoque, debido al obvio cuello de botella computacional que crea. Tan pronto como ResNets, publicado justo un año después de VGGNets (y hace 7 años), todas las arquitecturas principales terminaron sus definiciones de modelo con:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

El aplanamiento en las CNN se ha mantenido durante 7 años. 7 años! Y no parece que haya suficientes personas hablando sobre el efecto dañino que tiene tanto en su experiencia de aprendizaje como en los recursos computacionales que está utilizando.

La agrupación promedio global es preferible en muchas cuentas al aplanamiento. Si está creando un prototipo de una CNN pequeña, use Global Pooling. Si le está enseñando a alguien sobre las CNN, use Global Pooling. Si está haciendo un MVP, use Global Pooling. Use capas acopladas para otros casos de uso donde realmente se necesiten.

Estudio de caso: aplanamiento frente a agrupación global

Global Pooling condensa todos los mapas de características en uno solo, agrupando toda la información relevante en un solo mapa que puede ser fácilmente entendido por una sola capa de clasificación densa en lugar de múltiples capas. Por lo general, se aplica como agrupación promedio (GlobalAveragePooling2D) o agrupación máxima (GlobalMaxPooling2D) y también puede funcionar para entradas 1D y 3D.

En lugar de aplanar un mapa de características como (7, 7, 32) en un vector de longitud 1536 y entrenando una o varias capas para discernir patrones de este vector largo: podemos condensarlo en un (7, 7) vector y clasificar directamente desde allí. ¡Es así de simple!

Tenga en cuenta que las capas de cuello de botella para redes como ResNet cuentan con decenas de miles de características, no solo con 1536. Al aplanar, está torturando su red para aprender de vectores con formas extrañas de una manera muy ineficiente. Imagine una imagen 2D cortada en cada fila de píxeles y luego concatenada en un vector plano. Los dos píxeles que solían estar a 0 píxeles de distancia verticalmente no son feature_map_width píxeles de distancia horizontalmente! Si bien esto puede no importar demasiado para un algoritmo de clasificación, que favorece la invariancia espacial, esto no sería ni siquiera conceptualmente bueno para otras aplicaciones de visión artificial.

Definamos una pequeña red demostrativa que usa una capa plana con un par de capas densas:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

¿Cómo es el resumen?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 millones de parámetros para una red de juguete y observe cómo explotan los parámetros con una entrada más grande. 11.5 millones de parámetros. EfficientNets, una de las redes de mejor rendimiento jamás diseñadas, funciona con ~6 millones de parámetros y no se puede comparar con este modelo simple en términos de rendimiento real y capacidad para aprender de los datos.

Podríamos reducir este número significativamente al hacer que la red sea más profunda, lo que introduciría una mayor agrupación máxima (y potencialmente una convolución escalonada) para reducir los mapas de características antes de que se aplanen. Sin embargo, considere que estaríamos haciendo la red más compleja para que sea menos costosa desde el punto de vista computacional, todo por el bien de una sola capa que está arruinando los planes.

Profundizar con las capas debería ser extraer relaciones no lineales más significativas entre los puntos de datos, sin reducir el tamaño de entrada para atender a una capa plana.

Aquí hay una red con agrupación global:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

¿Resumen?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

¡Mucho mejor! Si profundizamos con este modelo, el número de parámetros aumentará y podremos capturar patrones de datos más complejos con las nuevas capas. Sin embargo, si se hace de manera ingenua, surgirán los mismos problemas que vincularon a VGGNets.

Yendo más allá: proyecto portátil de extremo a extremo

¿Tu naturaleza inquisitiva te hace querer ir más allá? Recomendamos revisar nuestro Proyecto Guiado: “Redes neuronales convolucionales: más allá de las arquitecturas básicas”.

Consulte nuestra guía práctica y práctica para aprender Git, con las mejores prácticas, los estándares aceptados por la industria y la hoja de trucos incluida. Deja de buscar en Google los comandos de Git y, de hecho, aprenden ella!

Lo llevaré a un pequeño viaje en el tiempo: desde 1998 hasta 2022, destacando las arquitecturas definitorias desarrolladas a lo largo de los años, qué las hizo únicas, cuáles son sus inconvenientes e implementaré las más notables desde cero. No hay nada mejor que tener un poco de suciedad en las manos cuando se trata de estos.

Puede conducir un automóvil sin saber si el motor tiene 4 u 8 cilindros y cuál es la ubicación de las válvulas dentro del motor. Sin embargo, si desea diseñar y apreciar un motor (modelo de visión por computadora), querrá profundizar un poco más. Incluso si no quiere perder tiempo diseñando arquitecturas y quiere construir productos en su lugar, que es lo que más quiere hacer, encontrará información importante en esta lección. Aprenderá por qué el uso de arquitecturas obsoletas como VGGNet dañará su producto y su rendimiento, y por qué debería omitirlas si está construyendo algo moderno, y aprenderá a qué arquitecturas puede acudir para resolver problemas prácticos y qué los pros y los contras son de cada uno.

Si está buscando aplicar la visión por computadora a su campo, utilizando los recursos de esta lección, podrá encontrar los modelos más nuevos, comprender cómo funcionan y con qué criterios puede compararlos y tomar una decisión sobre cuál. usar.

Usted no tienen que buscar en Google las arquitecturas y sus implementaciones; por lo general, se explican muy claramente en los documentos, y los marcos como Keras hacen que estas implementaciones sean más fáciles que nunca. El punto clave de este proyecto guiado es enseñarle cómo encontrar, leer, implementar y comprender arquitecturas y documentos. Ningún recurso en el mundo podrá mantenerse al día con todos los desarrollos más recientes. He incluido los documentos más nuevos aquí, pero en unos meses aparecerán nuevos, y eso es inevitable. Saber dónde encontrar implementaciones creíbles, compararlas con documentos y modificarlas puede brindarle la ventaja competitiva necesaria para muchos productos de visión por computadora que desee crear.

Conclusión

En esta breve guía, hemos analizado una alternativa al aplanamiento en el diseño de la arquitectura CNN. Aunque breve, la guía aborda un problema común al diseñar prototipos o MVP, y le aconseja que utilice una mejor alternativa al aplanamiento.

Cualquier ingeniero de visión artificial experimentado conocerá y aplicará este principio, y la práctica se da por sentada. Desafortunadamente, no parece transmitirse adecuadamente a los nuevos practicantes que recién ingresan al campo y pueden crear hábitos pegajosos que toman un tiempo para deshacerse de ellos.

Si está ingresando a Computer Vision, hágase un favor y no use capas planas antes de los encabezados de clasificación en su viaje de aprendizaje.

Sello de tiempo:

Mas de Abuso de pila