RetinaNet Object Detection with PyTorch and torchvision

Introduction

Object detection is a large field in computer vision, and one of the more important applications of computer vision “in the wild”. On one end, it can be used to build autonomous systems that navigate agents through environments – be it robots performing tasks or self-driving cars – but this requires intersection with other fields. On the other end, anomaly detection (such as spotting defective products on a line), locating objects within images, facial detection and various other applications of object detection can stand on their own.

Object detection isn’t as standardized as image classification, mainly because most new developments come from individual researchers, maintainers and developers rather than from large libraries and frameworks. It’s difficult to package the necessary utility scripts in a framework like TensorFlow or PyTorch while still maintaining the API guidelines that have guided their development so far.

This makes object detection somewhat more complex, typically more verbose (but not always), and less approachable than image classification. One of the major benefits of working within an ecosystem is that it spares you from having to search for useful information on good practices, tools and approaches. With object detection, most people have to do far more research on the landscape of the field to get a good grip.

Object Detection with PyTorch/TorchVision’s RetinaNet

torchvision is PyTorch’s Computer Vision project, and it aims to make the development of PyTorch-based CV models easier by providing transformation and augmentation scripts, a model zoo with pre-trained weights, datasets and utilities that are useful to practitioners.

While still in beta and very much experimental – torchvision offers a relatively simple Object Detection API with a few models to choose from:

  • Faster R-CNN
  • RetinaNet
  • FCOS (Fully Convolutional One-Stage detection – anchor-free, RetinaNet-style)
  • SSD (VGG16 backbone… yikes)
  • SSDLite (MobileNetV3 backbone)

While the API isn’t as polished or simple as some other third-party APIs, it’s a very decent starting point for those who’d still prefer the safety of being in an ecosystem they’re familiar with. Before going forward, make sure you install PyTorch and torchvision:

$ pip install torch torchvision
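
To sanity-check the installation – the multi-weight API used below (e.g. RetinaNet_ResNet50_FPN_V2_Weights) requires a reasonably recent torchvision (roughly 0.13 and up), so it’s worth printing the versions first:

import torch
import torchvision

# The multi-weight API used in this guide ships with recent torchvision releases
print(torch.__version__)
print(torchvision.__version__)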

Let’s load in some of the utility functions, such as read_image(), draw_bounding_boxes() and to_pil_image() to make it easier to read, draw on and output images, followed by importing RetinaNet and its pre-trained weights (MS COCO):

from torchvision.io.image import read_image
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights

import torch
import matplotlib.pyplot as plt

RetinaNet uses a ResNet50 backbone and a Feature Pyramid Network (FPN) on top of it. While the name of the class is verbose, it’s indicative of the architecture. Let’s fetch an image using the requests library and save it as a file on our local drive:

import requests

response = requests.get('https://i.ytimg.com/vi/q71MCWAEfL8/maxresdefault.jpg')
# Write the downloaded bytes to disk, then load them back as a uint8 tensor
with open("obj_det.jpeg", "wb") as f:
    f.write(response.content)

img = read_image("obj_det.jpeg")

With an image in place – we can instantiate our model and weights:

weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.35)

# Switch the model into inference mode
model.eval()

# The weights bundle the exact preprocessing pipeline they were trained with
preprocess = weights.transforms()
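
The weights also bundle useful metadata, including the list of MS COCO category names we’ll use for labeling later. As a quick aside, you can peek at what the model is able to detect:

# The MS COCO category names ship in the weights' metadata
print(len(weights.meta["categories"]))
print(weights.meta["categories"][:5])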

The score_thresh argument defines the minimum score at which a detection is kept. Intuitively, it’s a confidence threshold – a box won’t be returned as a detection of a class if the model is less than 35% confident that it belongs to that class.
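
Alternatively, you can leave score_thresh at its default and filter detections yourself, since every prediction carries a parallel "scores" tensor. A minimal sketch, assuming a prediction dictionary like the one we produce below:

# Post-hoc filtering - equivalent in effect to passing score_thresh=0.35
keep = prediction["scores"] > 0.35
boxes = prediction["boxes"][keep]
labels = prediction["labels"][keep]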

Let’s preprocess the image using the transforms from our weights, create a batch and run inference:

# Detection models take a list of 3D tensors - a "batch" of images
batch = [preprocess(img)]
# Disable gradient tracking, since we're only running inference
with torch.no_grad():
    prediction = model(batch)[0]

That’s it – our prediction dictionary holds the inferred object classes and locations! Now, the results aren’t very useful to us in this form – we’ll want to extract the labels using the metadata from the weights and draw bounding boxes, which can be done via draw_bounding_boxes():
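
Before drawing, it helps to know the shape of what came back. Each prediction is a dictionary of three tensors – "boxes" holds [x1, y1, x2, y2] pixel coordinates, "labels" holds COCO category indices, and "scores" holds the confidence for each detection:

# Inspect the structure of a single prediction
print(prediction["boxes"].shape)   # (num_detections, 4)
print(prediction["labels"])        # indices into weights.meta["categories"]
print(prediction["scores"])        # confidence per detection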

labels = [weights.meta["categories"][i] for i in prediction["labels"]]

# 'font' must resolve to a TrueType font available on your system;
# drop 'font' and 'font_size' to fall back to a default bitmap font
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="cyan",
                          width=2,
                          font_size=30,
                          font='Arial')

im = to_pil_image(box.detach())

fig, ax = plt.subplots(figsize=(16, 12))
ax.imshow(im)
plt.show()

This results in:

RetinaNet actually detected the person peeking out behind the car! That’s a pretty difficult detection.
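
If you’d rather persist the annotated image than just display it, the PIL image can be written straight to disk (the filename here is just an example):

# Save the annotated image next to the original
im.save("obj_det_annotated.jpeg")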

You can switch RetinaNet out for FCOS (a fully convolutional, anchor-free detector in the RetinaNet style) by replacing retinanet_resnet50_fpn_v2 with fcos_resnet50_fpn, and using the FCOS_ResNet50_FPN_Weights weights:

from torchvision.io.image import read_image
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
from torchvision.models.detection import fcos_resnet50_fpn, FCOS_ResNet50_FPN_Weights

import matplotlib.pyplot as plt
import torch
import requests

response = requests.get('https://i.ytimg.com/vi/q71MCWAEfL8/maxresdefault.jpg')
with open("obj_det.jpeg", "wb") as f:
    f.write(response.content)

img = read_image("obj_det.jpeg")
weights = FCOS_ResNet50_FPN_Weights.DEFAULT
model = fcos_resnet50_fpn(weights=weights, score_thresh=0.35)
model.eval()

preprocess = weights.transforms()
batch = [preprocess(img)]
with torch.no_grad():
    prediction = model(batch)[0]

labels = [weights.meta["categories"][i] for i in prediction["labels"]]

box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="cyan",
                          width=2, 
                          font_size=30,
                          font='Arial')

im = to_pil_image(box.detach())

fig, ax = plt.subplots(figsize=(16, 12))
ax.imshow(im)
plt.show()
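
Since these detection models accept a list of image tensors, you can also run several images through in one call. A minimal sketch, assuming a second image file on disk (the filename is a placeholder):

# Detection models take a list of tensors, one per image - the images
# don't need to share spatial dimensions
imgs = [read_image("obj_det.jpeg"), read_image("second_image.jpeg")]
batch = [preprocess(im) for im in imgs]

with torch.no_grad():
    predictions = model(batch)  # one prediction dict per input image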

Going Further – Practical Deep Learning for Computer Vision

Does your inquisitive nature make you want to go further? We recommend checking out our Course: “Practical Deep Learning for Computer Vision with Python”.

Another Computer Vision Course?

We won’t be doing classification of MNIST digits or MNIST fashion. They served their part a long time ago. Too many learning resources are focusing on basic datasets and basic architectures before letting advanced black-box architectures shoulder the burden of performance.

We want to focus on demystification, practicality, understanding, intuition and real projects. Want to learn how you can make a difference? We’ll take you on a ride from the way our brains process images, to writing a research-grade deep learning classifier for breast cancer, to deep learning networks that “hallucinate” – teaching you the principles and theory through practical work, and equipping you with the know-how and tools to become an expert at applying deep learning to solve computer vision problems.

What’s inside?

  • The first principles of vision and how computers can be taught to “see”
  • Different tasks and applications of computer vision
  • The tools of the trade that will make your work easier
  • Finding, creating and utilizing datasets for computer vision
  • The theory and application of Convolutional Neural Networks
  • Handling domain shift, co-occurrence, and other biases in datasets
  • Transfer Learning and utilizing others’ training time and computational resources for your benefit
  • Building and training a state-of-the-art breast cancer classifier
  • How to apply a healthy dose of skepticism to mainstream ideas and understand the implications of widely adopted techniques
  • Visualizing a ConvNet’s “concept space” using t-SNE and PCA
  • Case studies of how companies use computer vision techniques to achieve better results
  • Proper model evaluation, latent space visualization and identifying the model’s attention
  • Performing domain research, processing your own datasets and establishing model tests
  • Cutting-edge architectures, the progression of ideas, what makes them unique and how to implement them
  • KerasCV – a WIP library for creating state of the art pipelines and models
  • How to parse and read papers and implement them yourself
  • Selecting models depending on your application
  • Creating an end-to-end machine learning pipeline
  • Landscape and intuition on object detection with Faster R-CNNs, RetinaNets, SSDs and YOLO
  • Instance and semantic segmentation
  • Real-Time Object Recognition with YOLOv5
  • Training YOLOv5 Object Detectors
  • Working with Transformers using KerasNLP (industry-strength WIP library)
  • Integrating Transformers with ConvNets to generate captions of images
  • DeepDream

Conclusion

Object Detection is an important field of Computer Vision, and one that’s unfortunately less approachable than it should be.

In this short guide, we’ve taken a look at how torchvision, PyTorch’s Computer Vision package, makes it easier to perform object detection on images, using RetinaNet.
