Jangan Gunakan Flatten() - Pengumpulan Global untuk CNN dengan TensorFlow dan Keras PlatoBlockchain Data Intelligence. Pencarian Vertikal. Ai.

Jangan Gunakan Flatten() – Global Pooling untuk CNN dengan TensorFlow dan Keras

Sebagian besar praktisi, saat pertama kali mempelajari arsitektur Convolutional Neural Network (CNN) – mengetahui bahwa arsitektur itu terdiri dari tiga segmen dasar:

  • Lapisan Konvolusi
  • Lapisan Pengumpulan
  • Lapisan Terhubung Sepenuhnya

Sebagian besar sumber daya memiliki beberapa variasi pada segmentasi ini, termasuk buku saya sendiri. Khususnya online – lapisan yang terhubung penuh mengacu pada a lapisan merata dan (biasanya) banyak lapisan padat.

Ini dulu norma, dan arsitektur terkenal seperti VGGNets menggunakan pendekatan ini, dan akan berakhir dengan:

model = keras.Sequential([
    
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Meskipun, untuk beberapa alasan – seringkali dilupakan bahwa VGGNet praktis merupakan arsitektur terakhir yang menggunakan pendekatan ini, karena hambatan komputasi yang jelas yang diciptakannya. Segera setelah ResNets, diterbitkan hanya setahun setelah VGGNets (dan 7 tahun yang lalu), semua arsitektur arus utama mengakhiri definisi model mereka dengan:

model = keras.Sequential([
    
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

Perataan di CNN telah bertahan selama 7 tahun. 7 tahun! Dan tampaknya tidak cukup banyak orang yang berbicara tentang efek merusak yang ditimbulkannya pada pengalaman belajar Anda dan sumber daya komputasi yang Anda gunakan.

Pengumpulan Rata-Rata Global lebih disukai di banyak akun daripada perataan. Jika Anda membuat prototipe CNN kecil – gunakan Global Pooling. Jika Anda mengajar seseorang tentang CNN – gunakan Global Pooling. Jika Anda membuat MVP – gunakan Global Pooling. Gunakan lapisan perataan untuk kasus penggunaan lain di mana mereka benar-benar dibutuhkan.

Studi Kasus – Perataan vs Penggabungan Global

Global Pooling memadatkan semua peta fitur menjadi satu, menyatukan semua informasi yang relevan ke dalam satu peta yang dapat dengan mudah dipahami oleh satu lapisan klasifikasi padat alih-alih beberapa lapisan. Ini biasanya diterapkan sebagai penyatuan rata-rata (GlobalAveragePooling2D) atau pengumpulan maksimum (GlobalMaxPooling2D) dan dapat bekerja untuk input 1D dan 3D juga.

Alih-alih meratakan peta fitur seperti (7, 7, 32) menjadi vektor dengan panjang 1536 dan melatih satu atau beberapa lapisan untuk membedakan pola dari vektor panjang ini: kita dapat memadatkannya menjadi (7, 7) vektor dan mengklasifikasikan langsung dari sana. Sesederhana itu!

Perhatikan bahwa lapisan bottleneck untuk jaringan seperti ResNets menghitung puluhan ribu fitur, bukan hanya 1536. Saat merata, Anda menyiksa jaringan Anda untuk belajar dari vektor berbentuk aneh dengan cara yang sangat tidak efisien. Bayangkan gambar 2D diiris pada setiap baris piksel dan kemudian digabungkan menjadi vektor datar. Dua piksel yang dulunya 0 piksel terpisah secara vertikal tidak feature_map_width piksel menjauh secara horizontal! Meskipun ini mungkin tidak terlalu menjadi masalah untuk algoritme klasifikasi, yang mendukung invarian spasial – ini bahkan tidak akan baik secara konseptual untuk aplikasi visi komputer lainnya.

Mari kita definisikan jaringan demonstratif kecil yang menggunakan lapisan perataan dengan beberapa lapisan padat:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

Seperti apa ringkasannya?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5M parameter untuk jaringan mainan – dan saksikan parameter meledak dengan input yang lebih besar. 11.5M parameter. EfficientNets, salah satu jaringan berkinerja terbaik yang pernah dirancang bekerja pada parameter ~6 juta, dan tidak dapat dibandingkan dengan model sederhana ini dalam hal kinerja aktual dan kapasitas untuk belajar dari data.

Kami dapat mengurangi jumlah ini secara signifikan dengan membuat jaringan lebih dalam, yang akan memperkenalkan lebih banyak pengumpulan maksimal (dan berpotensi konvolusi bertahap) untuk mengurangi peta fitur sebelum diratakan. Namun, pertimbangkan bahwa kami akan membuat jaringan lebih kompleks untuk membuatnya lebih murah secara komputasi, semua demi satu lapisan yang melemparkan kunci pas dalam rencana.

Masuk lebih dalam dengan lapisan harus mengekstraksi hubungan non-linier yang lebih bermakna antara titik data, tidak mengurangi ukuran input untuk memenuhi lapisan yang rata.

Berikut adalah jaringan dengan penyatuan global:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Ringkasan?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Jauh lebih baik! Jika kita mendalami model ini, jumlah parameter akan meningkat, dan kita mungkin dapat menangkap pola data yang lebih rumit dengan lapisan baru. Jika dilakukan secara naif, masalah yang sama yang mengikat VGGNet akan muncul.

Melangkah Lebih Jauh – Proyek Hand-Held End-to-End

Sifat ingin tahu Anda membuat Anda ingin melangkah lebih jauh? Kami merekomendasikan untuk memeriksa kami Proyek Terpandu: “Jaringan Saraf Konvolusi – Melampaui Arsitektur Dasar”.

Lihat panduan praktis dan praktis kami untuk mempelajari Git, dengan praktik terbaik, standar yang diterima industri, dan termasuk lembar contekan. Hentikan perintah Googling Git dan sebenarnya belajar itu!

Saya akan mengajak Anda melakukan sedikit perjalanan waktu – mulai dari tahun 1998 hingga 2022, menyoroti arsitektur yang berkembang selama bertahun-tahun, apa yang membuatnya unik, apa kekurangannya, dan menerapkan arsitektur penting dari awal. Tidak ada yang lebih baik daripada memiliki beberapa kotoran di tangan Anda ketika datang ke ini.

Anda dapat mengendarai mobil tanpa mengetahui apakah mesin memiliki 4 atau 8 silinder dan apa penempatan katup di dalam mesin. Namun – jika Anda ingin merancang dan menghargai mesin (model visi komputer), Anda harus masuk lebih dalam. Bahkan jika Anda tidak ingin menghabiskan waktu merancang arsitektur dan sebaliknya ingin membangun produk, itulah yang paling ingin Anda lakukan – Anda akan menemukan informasi penting dalam pelajaran ini. Anda akan mempelajari mengapa menggunakan arsitektur usang seperti VGGNet akan merusak produk dan kinerja Anda, dan mengapa Anda harus melewatinya jika Anda sedang membangun sesuatu yang modern, dan Anda akan mempelajari arsitektur mana yang dapat Anda gunakan untuk memecahkan masalah praktis dan apa pro dan kontra adalah untuk masing-masing.

Jika Anda ingin menerapkan visi komputer ke bidang Anda, menggunakan sumber daya dari pelajaran ini – Anda akan dapat menemukan model terbaru, memahami cara kerjanya dan dengan kriteria apa Anda dapat membandingkannya dan membuat keputusan yang mana menggunakan.

Kamu tidak harus Google untuk arsitektur dan implementasinya – biasanya dijelaskan dengan sangat jelas di makalah, dan kerangka kerja seperti Keras membuat implementasi ini lebih mudah dari sebelumnya. Kunci utama dari Proyek Terpandu ini adalah untuk mengajari Anda cara menemukan, membaca, menerapkan, dan memahami arsitektur dan makalah. Tidak ada sumber daya di dunia yang dapat mengikuti semua perkembangan terbaru. Saya telah menyertakan makalah terbaru di sini – tetapi dalam beberapa bulan, yang baru akan muncul, dan itu tidak dapat dihindari. Mengetahui di mana menemukan implementasi yang kredibel, membandingkannya dengan kertas dan menyesuaikannya dapat memberi Anda keunggulan kompetitif yang diperlukan untuk banyak produk visi komputer yang mungkin ingin Anda buat.

Kesimpulan

Dalam panduan singkat ini, kami telah melihat alternatif perataan dalam desain arsitektur CNN. Meskipun singkat – panduan ini membahas masalah umum saat merancang prototipe atau MVP, dan menyarankan Anda untuk menggunakan alternatif yang lebih baik untuk meratakan.

Setiap Insinyur Penglihatan Komputer yang berpengalaman akan mengetahui dan menerapkan prinsip ini, dan praktiknya diterima begitu saja. Sayangnya, hal itu tampaknya tidak tersampaikan dengan baik kepada praktisi baru yang baru memasuki lapangan, dan dapat menciptakan kebiasaan lengket yang perlu waktu lama untuk dihilangkan.

Jika Anda masuk ke Computer Vision – bantulah diri Anda sendiri dan jangan gunakan lapisan perataan sebelum kepala klasifikasi dalam perjalanan belajar Anda.

Stempel Waktu:

Lebih dari penyalahgunaan