Não use Flatten() - Pooling global para CNNs com TensorFlow e Keras PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.

Não use Flatten() – Global Pooling para CNNs com TensorFlow e Keras

A maioria dos praticantes, ao aprender pela primeira vez sobre arquiteturas de Rede Neural Convolucional (CNN) – aprende que ela é composta por três segmentos básicos:

  • Camadas Convolucionais
  • Camadas de pool
  • Camadas totalmente conectadas

A maioria dos recursos tem alguns variação dessa segmentação, incluindo meu próprio livro. Especialmente online – camadas totalmente conectadas referem-se a um camada de achatamento e (geralmente) múltiplos camadas densas.

Isso costumava ser a norma, e arquiteturas bem conhecidas, como VGGNets, usavam essa abordagem e terminariam em:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Porém, por alguma razão – muitas vezes é esquecido que a VGGNet foi praticamente a última arquitetura a usar essa abordagem, devido ao óbvio gargalo computacional que ela cria. Assim que ResNets, publicado apenas um ano após o VGGNets (e 7 anos atrás), todas as arquiteturas convencionais encerraram suas definições de modelo com:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

O achatamento nas CNNs dura há 7 anos. 7 anos! E poucas pessoas parecem estar falando sobre o efeito prejudicial que isso tem em sua experiência de aprendizado e nos recursos computacionais que você está usando.

O Global Average Pooling é preferível em muitas contas ao achatamento. Se você estiver criando um protótipo de uma CNN pequena – use o Global Pooling. Se você está ensinando alguém sobre CNNs – use o Global Pooling. Se você estiver fazendo um MVP – use o Global Pooling. Use camadas de nivelamento para outros casos de uso em que elas sejam realmente necessárias.

Estudo de caso - achatamento vs agrupamento global

O Global Pooling condensa todos os mapas de recursos em um único, agrupando todas as informações relevantes em um único mapa que pode ser facilmente entendido por uma única camada de classificação densa em vez de várias camadas. Normalmente é aplicado como pool médio (GlobalAveragePooling2D) ou pool máximo (GlobalMaxPooling2D) e também pode funcionar para entrada 1D e 3D.

Em vez de achatar um mapa de recursos, como (7, 7, 32) em um vetor de comprimento 1536 e treinando uma ou várias camadas para discernir padrões desse vetor longo: podemos condensá-lo em um (7, 7) vetor e classifique diretamente de lá. É simples assim!

Observe que as camadas de gargalo para redes como ResNets contam em dezenas de milhares de recursos, não meros 1536. Ao nivelar, você está torturando sua rede para aprender com vetores de formatos estranhos de uma maneira muito ineficiente. Imagine uma imagem 2D sendo fatiada em cada linha de pixel e depois concatenada em um vetor plano. Os dois pixels que costumavam estar separados verticalmente por 0 pixels não são feature_map_width pixels de distância horizontalmente! Embora isso possa não importar muito para um algoritmo de classificação, que favorece a invariância espacial – isso não seria nem conceitualmente bom para outras aplicações de visão computacional.

Vamos definir uma pequena rede demonstrativa que usa uma camada de nivelamento com algumas camadas densas:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

Como é o resumo?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5 milhões de parâmetros para uma rede de brinquedos – e veja os parâmetros explodirem com entradas maiores. 11.5 milhões de parâmetros. As EfficientNets, uma das redes com melhor desempenho já projetadas, funcionam com parâmetros de ~6M e não podem ser comparadas com esse modelo simples em termos de desempenho real e capacidade de aprender com os dados.

Poderíamos reduzir esse número significativamente tornando a rede mais profunda, o que introduziria mais pooling máximo (e convolução potencialmente strided) para reduzir os mapas de recursos antes de serem achatados. No entanto, considere que estaríamos tornando a rede mais complexa para torná-la menos dispendiosa computacionalmente, tudo por causa de uma única camada que está atrapalhando os planos.

Aprofundar as camadas deve ser extrair relacionamentos não lineares mais significativos entre os pontos de dados, não reduzindo o tamanho da entrada para atender a uma camada de nivelamento.

Aqui está uma rede com pool global:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Resumo?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Muito melhor! Se aprofundarmos esse modelo, a contagem de parâmetros aumentará e poderemos capturar padrões de dados mais complexos com as novas camadas. Se feito de forma ingênua, porém, os mesmos problemas que limitam as VGGNets surgirão.

Indo além - Projeto de ponta a ponta manual

Sua natureza curiosa faz você querer ir mais longe? Recomendamos verificar nosso Projeto Guiado: “Redes Neurais Convolucionais – Além das Arquiteturas Básicas”.

Confira nosso guia prático e prático para aprender Git, com práticas recomendadas, padrões aceitos pelo setor e folha de dicas incluída. Pare de pesquisar comandos Git no Google e realmente aprender -lo!

Vou levá-lo em uma pequena viagem no tempo – indo de 1998 a 2022, destacando as arquiteturas definidoras desenvolvidas ao longo dos anos, o que as tornou únicas, quais são suas desvantagens e implementando as notáveis ​​​​do zero. Não há nada melhor do que ter um pouco de sujeira nas mãos quando se trata disso.

Você pode dirigir um carro sem saber se o motor tem 4 ou 8 cilindros e qual é a colocação das válvulas dentro do motor. No entanto – se você quiser projetar e apreciar um motor (modelo de visão computacional), você vai querer ir um pouco mais fundo. Mesmo que você não queira gastar tempo projetando arquiteturas e, em vez disso, queira construir produtos, que é o que a maioria quer fazer – você encontrará informações importantes nesta lição. Você aprenderá por que usar arquiteturas desatualizadas como VGGNet prejudicará seu produto e desempenho, e por que você deve ignorá-las se estiver construindo algo moderno, e aprenderá quais arquiteturas você pode usar para resolver problemas práticos e quais os prós e contras são para cada um.

Se você deseja aplicar a visão computacional ao seu campo, usando os recursos desta lição, você poderá encontrar os modelos mais recentes, entender como eles funcionam e por quais critérios você pode compará-los e tomar uma decisão sobre quais usar.

Vocês não precisam do Google para arquiteturas e suas implementações – elas geralmente são explicadas com muita clareza nos documentos, e frameworks como Keras tornam essas implementações mais fáceis do que nunca. A principal lição deste projeto guiado é ensiná-lo a encontrar, ler, implementar e entender arquiteturas e documentos. Nenhum recurso no mundo será capaz de acompanhar todos os desenvolvimentos mais recentes. Incluí os artigos mais recentes aqui – mas em poucos meses, novos aparecerão, e isso é inevitável. Saber onde encontrar implementações confiáveis, compará-las com documentos e ajustá-las pode fornecer a vantagem competitiva necessária para muitos produtos de visão computacional que você deseja construir.

Conclusão

Neste pequeno guia, analisamos uma alternativa ao achatamento no design da arquitetura da CNN. Embora curto – o guia aborda um problema comum ao projetar protótipos ou MVPs e aconselha você a usar uma alternativa melhor ao achatamento.

Qualquer Engenheiro de Visão Computacional experiente conhecerá e aplicará este princípio, e a prática é tida como certa. Infelizmente, não parece ser devidamente retransmitido para novos praticantes que estão apenas entrando no campo, e pode criar hábitos pegajosos que demoram um pouco para se livrar.

Se você está entrando em Visão Computacional – faça um favor a si mesmo e não use camadas de nivelamento antes dos chefes de classificação em sua jornada de aprendizado.

Carimbo de hora:

Mais de Abuso de pilha