不要使用 Flatten() - 使用 TensorFlow 和 Keras PlatoBlockchain 数据智能的 CNN 全局池化。 垂直搜索。 哎。

不要使用 Flatten() – 使用 TensorFlow 和 Keras 的 CNN 全局池化

大多数从业者在第一次学习卷积神经网络 (CNN) 架构时 - 了解到它由三个基本部分组成:

  • 卷积层
  • 池化层
  • 全连接层

大多数资源都有 一些 这种细分的变化,包括我自己的书。 尤其是在线——全连接层指的是 平整层 和(通常)多个 致密层.

这曾经是常态,而著名的架构(例如 VGGNets)使用了这种方法,并会以以下方式结束:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

但是,由于某种原因,人们经常忘记 VGGNet 实际上是使用这种方法的最后一个架构,因为它会产生明显的计算瓶颈。 ResNets 在 VGGNets 发布后的一年(以及 7 年前)发布后,所有主流架构都以以下方式结束了其模型定义:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

CNN 中的扁平化已经持续了 7 年。 7 年! 而且似乎没有足够多的人谈论它对您的学习体验和您正在使用的计算资源的破坏性影响。

在许多账户上,全球平均池化比扁平化更可取。 如果您正在制作小型 CNN 的原型,请使用 Global Pooling。 如果您正在教某人有关 CNN 的知识,请使用 Global Pooling。 如果您正在制作 MVP,请使用 Global Pooling。 将展平层用于实际需要的其他用例。

案例研究——扁平化与全局池化

全局池化将所有的特征图浓缩成一个单一的,将所有相关信息汇集到一个单一的地图中,可以通过单个密集分类层而不是多个层轻松理解。 它通常用作平均池(GlobalAveragePooling2D) 或最大池化 (GlobalMaxPooling2D) 并且也可以用于 1D 和 3D 输入。

而不是展平特征图,例如 (7, 7, 32) 成一个长度为 1536 的向量,并训练一个或多个层来从这个长向量中辨别模式:我们可以将它压缩成一个 (7, 7) 向量并直接从那里分类。 就是这么简单!

请注意,像 ResNets 这样的网络的瓶颈层数以万计,而不仅仅是 1536 个。扁平化时,您正在折磨您的网络以非常低效的方式从形状奇特的向量中学习。 想象一个 2D 图像在每个像素行上被切片,然后连接成一个平面向量。 曾经垂直相隔 0 像素的两个像素现在不是 feature_map_width 像素水平相距! 虽然这对于有利于空间不变性的分类算法来说可能无关紧要——但这对于计算机视觉的其他应用在概念上甚至都不是很好。

让我们定义一个小型演示网络,该网络使用带有几个密集层的展平层:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

摘要看起来像什么?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

玩具网络的 11.5 万个参数——观察参数随着更大的输入而爆炸。 11.5M 参数. EfficientNets 是有史以来设计性能最好的网络之一,在大约 6M 参数下工作,在实际性能和从数据中学习的能力方面无法与这个简单的模型进行比较。

我们可以通过使网络更深来显着减少这个数字,这将引入更多的最大池(和潜在的跨步卷积)以在特征图被展平之前减少它们。 然而,考虑到我们会让网络变得更复杂,以降低计算成本,所有这些都是为了让计划中的单层受到影响。

更深入的层应该是提取数据点之间更有意义的非线性关系,而不是减少输入大小以迎合扁平化层。

这是一个具有全局池的网络:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

概括?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

好多了! 如果我们对这个模型进行更深入的研究,参数数量将会增加,并且我们可能能够使用新层捕获更复杂的数据模式。 但是,如果天真地完成,将出现与 VGGNets 相同的问题。

走得更远——手持式端到端项目

你好奇的天性让你想走得更远? 我们建议查看我们的 指导项目: “卷积神经网络——超越基本架构”.

查看我们的 Git 学习实践指南,其中包含最佳实践、行业认可的标准以及随附的备忘单。 停止谷歌搜索 Git 命令,实际上 学习 它!

我将带您进行一些时间旅行——从 1998 年到 2022 年,重点介绍多年来开发的定义架构,是什么让它们与众不同,它们的缺点是什么,并从头开始实施值得注意的架构。 当涉及到这些时,没有什么比手上有一些污垢更好的了。

您可以在不知道发动机是 4 缸还是 8 缸以及发动机内阀门位置的情况下驾驶汽车。 然而——如果你想设计和欣赏一个引擎(计算机视觉模型),你会想要更深入一点。 即使您不想花时间设计架构,而是想构建产品,这是最想做的——您将在本课中找到重要信息。 您将了解为什么使用 VGGNet 等过时的架构会损害您的产品和性能,以及在构建任何现代产品时为什么应该跳过它们,您将了解可以使用哪些架构来解决实际问题以及什么利弊各有千秋。

如果您希望将计算机视觉应用到您的领域,使用本课程中的资源 - 您将能够找到最新的模型,了解它们的工作原理以及您可以根据哪些标准来比较它们并决定要使用哪个模型利用。

完全 必须在谷歌上搜索架构及其实现——它们通常在论文中得到非常清楚的解释,像 Keras 这样的框架使这些实现比以往任何时候都容易。 这个引导项目的主要内容是教你如何查找、阅读、实施和理解架构和论文。 世界上没有任何资源能够跟上所有最新的发展。 我已经在这里包含了最新的论文——但几个月后,新论文会出现,这是不可避免的。 知道在哪里可以找到可靠的实现,将它们与论文进行比较并调整它们可以为您提供您可能想要构建的许多计算机视觉产品所需的竞争优势。

结论

在这个简短的指南中,我们研究了 CNN 架构设计中扁平化的替代方案。 尽管很短——该指南解决了设计原型或 MVP 时的一个常见问题,并建议您使用更好的扁平化替代方案。

任何经验丰富的计算机视觉工程师都会知道并应用这一原则,而且这种做法是理所当然的。 不幸的是,它似乎没有正确地传达给刚刚进入该领域的新从业者,并且可能会产生需要一段时间才能摆脱的粘性习惯。

如果您要进入计算机视觉领域,请帮自己一个忙,不要在学习过程中在分类头之前使用展平层。

时间戳记:

更多来自 堆栈滥用