RetinaNet Object Detection med PyTorch och Torchvision PlatoBlockchain Data Intelligence. Vertikal sökning. Ai.

RetinaNet Object Detection med PyTorch och torchvision

Beskrivning

Objektdetektering är ett stort fält inom datorseende och en av de viktigaste tillämpningarna för datorseende "i det vilda". Å ena sidan kan den användas för att bygga autonoma system som navigerar agenter genom miljöer – oavsett om det är robotar som utför uppgifter eller självkörande bilar, men detta kräver korsning med andra fält. Avvikelsedetektering (som defekta produkter på en linje), lokalisering av objekt i bilder, ansiktsdetektering och olika andra tillämpningar av objektdetektering kan dock göras utan att skära andra fält.

Objektdetektering är inte lika standardiserad som bildklassificering, främst eftersom de flesta av de nya utvecklingarna vanligtvis görs av enskilda forskare, underhållare och utvecklare, snarare än stora bibliotek och ramverk. Det är svårt att paketera de nödvändiga verktygsskripten i ett ramverk som TensorFlow eller PyTorch och behålla API-riktlinjerna som väglett utvecklingen hittills.

Detta gör objektdetektering något mer komplex, vanligtvis mer omfattande (men inte alltid) och mindre lättillgänglig än bildklassificering. En av de stora fördelarna med att vara i ett ekosystem är att det ger dig ett sätt att inte söka efter användbar information om god praxis, verktyg och metoder att använda. Med objektdetektering – de flesta måste göra mycket mer forskning om fältets landskap för att få ett bra grepp.

Objektdetektion med PyTorch/TorchVisions RetinaNet

torchvision är PyTorchs Computer Vision-projekt, och syftar till att göra utvecklingen av PyTorch-baserade CV-modeller enklare, genom att tillhandahålla transformations- och förstärkningsskript, en modellzoo med förtränade vikter, datauppsättningar och verktyg som kan vara användbara för en utövare.

Medan den fortfarande är i beta och mycket experimentell – torchvision erbjuder ett relativt enkelt Object Detection API med några modeller att välja mellan:

  • Snabbare R-CNN
  • RetinaNet
  • FCOS (Fully convolutional RetinaNet)
  • SSD (VGG16 ryggrad... yikes)
  • SSDLite (MobileNetV3-ryggraden)

Även om API:et inte är lika polerat eller enkelt som vissa andra API:er från tredje part, är det en mycket anständig utgångspunkt för dem som fortfarande föredrar säkerheten att vara i ett ekosystem de är bekanta med. Innan du går vidare, se till att du installerar PyTorch och Torchvision:

$ pip install torch torchvision

Låt oss ladda in några av hjälpfunktionerna, som t.ex read_image(), draw_bounding_boxes() och to_pil_image() för att göra det lättare att läsa, rita på och skriva ut bilder, följt av import av RetinaNet och dess förtränade vikter (MS COCO):

from torchvision.io.image import read_image
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights

import matplotlib.pyplot as plt

RetinaNet använder ett ResNet50-stamnät och ett Feature Pyramid Network (FPN) ovanpå det. Även om namnet på klassen är utförligt, är det en indikation på arkitekturen. Låt oss hämta en bild med hjälp av requests bibliotek och spara det som en fil på vår lokala enhet:

import requests
response = requests.get('https://i.ytimg.com/vi/q71MCWAEfL8/maxresdefault.jpg')
open("obj_det.jpeg", "wb").write(response.content)

img = read_image("obj_det.jpeg")

Med en bild på plats kan vi instansiera vår modell och vikter:

weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.35)

model.eval()

preprocess = weights.transforms()

Smakämnen score_thresh argument definierar tröskeln vid vilken ett objekt detekteras som ett objekt i en klass. Intuitivt är det konfidensgränsen, och vi klassificerar inte ett objekt som tillhör en klass om modellen är mindre än 35 % säker på att den tillhör en klass.

Låt oss förbehandla bilden med hjälp av transformationerna från våra vikter, skapa en batch och köra slutledning:

batch = [preprocess(img)]
prediction = model(batch)[0]

Det är det, vår prediction ordboken innehåller de antagna objektklasserna och platserna! Nu är resultaten inte särskilt användbara för oss i det här formuläret – vi vill extrahera etiketterna med avseende på metadata från vikterna och rita avgränsningsrutor, vilket kan göras via draw_bounding_boxes():

labels = [weights.meta["categories"][i] for i in prediction["labels"]]

box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="cyan",
                          width=2, 
                          font_size=30,
                          font='Arial')

im = to_pil_image(box.detach())

fig, ax = plt.subplots(figsize=(16, 12))
ax.imshow(im)
plt.show()

Detta resulterar i:

RetinaNet klassificerade faktiskt personen som kikade bakom bilen! Det är en ganska svår klassificering.

Kolla in vår praktiska, praktiska guide för att lära dig Git, med bästa praxis, branschaccepterade standarder och medföljande fuskblad. Sluta googla Git-kommandon och faktiskt lära Det!

Du kan byta ut RetinaNet till ett FCOS (fullständigt konvolutionellt RetinaNet) genom att ersätta det retinanet_resnet50_fpn_v2 med fcos_resnet50_fpn, och använd FCOS_ResNet50_FPN_Weights vikter:

from torchvision.io.image import read_image
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
from torchvision.models.detection import fcos_resnet50_fpn, FCOS_ResNet50_FPN_Weights

import matplotlib.pyplot as plt
import requests
response = requests.get('https://i.ytimg.com/vi/q71MCWAEfL8/maxresdefault.jpg')
open("obj_det.jpeg", "wb").write(response.content)

img = read_image("obj_det.jpeg")
weights = FCOS_ResNet50_FPN_Weights.DEFAULT
model = fcos_resnet50_fpn(weights=weights, score_thresh=0.35)
model.eval()

preprocess = weights.transforms()
batch = [preprocess(img)]
prediction = model(batch)[0]

labels = [weights.meta["categories"][i] for i in prediction["labels"]]

box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="cyan",
                          width=2, 
                          font_size=30,
                          font='Arial')

im = to_pil_image(box.detach())

fig, ax = plt.subplots(figsize=(16, 12))
ax.imshow(im)
plt.show()

Going Further – Praktisk djupinlärning för datorseende

Din nyfikna natur gör att du vill gå längre? Vi rekommenderar att du kollar in vår Kurs: "Praktisk djupinlärning för datorseende med Python".

RetinaNet Object Detection med PyTorch och Torchvision PlatoBlockchain Data Intelligence. Vertikal sökning. Ai.

Ännu en kurs i datorseende?

Vi kommer inte att göra klassificering av MNIST-siffror eller MNIST-mode. De tjänade sin del för länge sedan. Alltför många inlärningsresurser fokuserar på grundläggande datamängder och grundläggande arkitekturer innan de låter avancerade blackbox-arkitekturer bära bördan av prestanda.

Vi vill fokusera på avmystifiering, praktiskhet, förståelse, intuition och riktiga projekt. Vill lära sig hur du kan göra skillnad? Vi tar dig med på en tur från hur våra hjärnor bearbetar bilder till att skriva en klassificerare för djupinlärning för bröstcancer i forskningsklass till nätverk för djupinlärning som "hallucinerar", lär dig principer och teorier genom praktiskt arbete, och utrustar dig med kunskap och verktyg för att bli expert på att tillämpa djupinlärning för att lösa datorseende.

Vad är inuti?

  • De första principerna för syn och hur datorer kan läras att "se"
  • Olika uppgifter och tillämpningar av datorseende
  • Branschens verktyg som gör ditt arbete enklare
  • Hitta, skapa och använda datauppsättningar för datorseende
  • Teorin och tillämpningen av Convolutional Neural Networks
  • Hantera domänskifte, samtidig förekomst och andra fördomar i datamängder
  • Överför Lärande och utnyttja andras träningstid och beräkningsresurser till din fördel
  • Bygga och träna en toppmodern klassificerare för bröstcancer
  • Hur man applicerar en hälsosam dos av skepsis på mainstream idéer och förstår implikationerna av allmänt använda tekniker
  • Visualisera ett ConvNets "konceptutrymme" med t-SNE och PCA
  • Fallstudier av hur företag använder datorseendetekniker för att uppnå bättre resultat
  • Korrekt modellutvärdering, latent rumsvisualisering och identifiering av modellens uppmärksamhet
  • Utföra domänforskning, bearbeta dina egna datamängder och upprätta modelltester
  • Banbrytande arkitekturer, utvecklingen av idéer, vad som gör dem unika och hur man implementerar dem
  • KerasCV – ett WIP-bibliotek för att skapa toppmoderna pipelines och modeller
  • Hur man analyserar och läser uppsatser och implementerar dem själv
  • Välja modeller beroende på din applikation
  • Skapa en komplett maskininlärningspipeline
  • Landskap och intuition på objektdetektering med snabbare R-CNN, RetinaNets, SSD och YOLO
  • Instans och semantisk segmentering
  • Objektigenkänning i realtid med YOLOv5
  • Träning av YOLOv5-objektdetektorer
  • Arbeta med transformatorer med KerasNLP (industristarkt WIP-bibliotek)
  • Integrering av Transformers med ConvNets för att generera bildtexter
  • DeepDream

Slutsats

Objektdetektion är ett viktigt område inom datorseende, och ett område som tyvärr är mindre lättillgängligt än det borde vara.

I den här korta guiden har vi tagit en titt på hur torchvision, PyTorchs Computer Vision-paket, gör det lättare att utföra objektdetektering på bilder, med hjälp av RetinaNet.

Tidsstämpel:

Mer från Stackabuse