Using ResNet for Image Classification with PyTorch

      In this tutorial, we'll learn about the ResNet model and how to use a pre-trained ResNet-50 model for image classification with PyTorch. We'll go through the steps of loading a pre-trained model, preprocessing an image, and using the model to predict the image's class label, as well as displaying the results. The tutorial covers:

  1. Introduction to the ResNet model
  2. Loading a Pre-Trained ResNet-50 Model
  3. Defining Image Preprocessing
  4. Loading ImageNet Class Labels
  5. Loading and Preprocessing the Image
  6. Making a Prediction
  7. Conclusion
  8. Full code listing

 

Introduction to the ResNet model

    ResNet, short for Residual Network, is a deep convolutional neural network (CNN) architecture that addresses a key problem in very deep networks: the vanishing gradient problem, where gradients shrink as they’re back-propagated through layers, making it hard to train deeper networks effectively. ResNet enables the training of extremely deep networks by using residual connections, which allow gradients to flow more easily through the network.


Residual blocks

    In traditional CNNs, layers are arranged in a sequence of convolutions, batch normalization, and activation functions. This setup can make training very deep networks difficult. In ResNet, a residual block adds a shortcut (or skip connection) that allows the input to jump over one or more layers. This shortcut helps the network learn a residual mapping instead of trying to learn the entire transformation.

    The output of the residual block is calculated as:

output = activation(F(x) + x)

    where:

  • x is the input to the block
  • F(x) is the transformation done by the convolutional layers in the block

    The output combines the original input and the result from the convolutional layers.
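
    To make the formula concrete, here is a minimal sketch of a residual block in PyTorch. The block below is illustrative (the channel counts and layer choices are assumptions for the example, not torchvision's exact implementation); it simply shows F(x) being computed by two convolutions and then added back to the input x before the final activation.

 
import torch
import torch.nn as nn

class SimpleResidualBlock(nn.Module):
    """Illustrative residual block: output = ReLU(F(x) + x)."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x                               # the skip connection keeps the original input
        out = self.relu(self.bn1(self.conv1(x)))   # first conv -> BN -> ReLU
        out = self.bn2(self.conv2(out))            # second conv -> BN, i.e. F(x)
        out = out + identity                       # add the shortcut back in
        return self.relu(out)                      # final activation

# Quick shape check with a random input
x = torch.randn(1, 64, 56, 56)
print(SimpleResidualBlock(64)(x).shape)            # torch.Size([1, 64, 56, 56])
 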

 

Key Characteristics of ResNet

  • Deep Architecture: ResNet comes in various depths, such as ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152, where the number denotes the count of weighted layers. This depth allows the model to capture intricate patterns and features in data.

  • Residual Connections: The hallmark of ResNet is its use of skip connections (or shortcuts) that bypass one or more layers. This allows the network to learn residual mappings, which helps in mitigating the vanishing gradient problem and enables effective training of very deep networks.

  • Bottleneck Design: In deeper variants like ResNet-50 and above, the architecture employs bottleneck blocks that consist of three convolutional layers (1x1, 3x3, and 1x1). This design reduces the number of parameters while maintaining performance (a quick inspection sketch follows this list).

  • Global Average Pooling: Instead of fully connected layers, ResNet often utilizes global average pooling, which significantly reduces the number of parameters and helps combat overfitting.
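
    To see the bottleneck design and the global average pooling for yourself, you can build an (untrained) ResNet-50 from torchvision and print a few of its modules. This is just an inspection sketch; the layer names shown in the comments come from torchvision's implementation.

 
from torchvision import models

# Build ResNet-50 with random weights, purely to inspect its structure
resnet = models.resnet50()

# The first block of stage 1 is a bottleneck: 1x1 -> 3x3 -> 1x1 convolutions
print(resnet.layer1[0])

# Global average pooling followed by a single fully connected classifier
print(resnet.avgpool)   # AdaptiveAvgPool2d(output_size=(1, 1))
print(resnet.fc)        # Linear(in_features=2048, out_features=1000, bias=True)
 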

 

Limitations

    The primary challenge is the increased computational cost of very deep networks, which can lead to longer training times and higher resource consumption. Additionally, while residual connections alleviate the vanishing gradient problem, they do not entirely eliminate it, and training very deep networks can still be complex.

  

Loading a Pre-Trained ResNet-50 Model

    Before starting, make sure you have the following Python libraries installed:

  • torch (PyTorch)
  • torchvision (for pre-trained models and transformations)
  • PIL (Python Imaging Library to handle image files)
  • matplotlib (for displaying images)
  • requests (for downloading class labels)

    You can install these libraries using pip.

 
pip install torch torchvision pillow matplotlib requests 
 

    PyTorch provides a variety of pre-trained models via the torchvision library. In this tutorial, we use the ResNet-50 model, which has been pre-trained on the ImageNet dataset. We’ll load the model and set it to evaluation mode (which disables certain layers like dropout that are used only during training).

 
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch.nn as nn


# Load the pre-trained ResNet-50 model
model = models.resnet50(pretrained=True)

# Set the model to evaluation mode (disables dropout; batch norm uses its running statistics)
model.eval()
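 

    If you are running a newer torchvision release (0.13 or later), the pretrained=True argument is deprecated in favor of a weights argument, which also bundles the matching preprocessing transforms and class names. A rough equivalent of the loading step, assuming that newer API is available, looks like this:

 
from torchvision.models import resnet50, ResNet50_Weights

# Load the best available ImageNet weights for ResNet-50
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()

# The weights object also carries the matching preprocessing and class names
preprocess = weights.transforms()           # resize/crop/normalize pipeline
class_labels = weights.meta["categories"]   # list of ImageNet class names
 

    The rest of this tutorial sticks with the pretrained=True form shown above.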
 

    

Defining Image Preprocessing

    To use the ResNet model, the input image needs to be preprocessed in the same way the model was trained. For ResNet, this includes resizing, center-cropping, and normalizing the image. We’ll use torchvision.transforms to define the following transformations:

  1. Resize the image so that its shorter side is 256 pixels.
  2. Center-crop the image to 224x224 pixels (ResNet's input size).
  3. Convert the image to a tensor.
  4. Normalize the image with the same mean and standard deviation used in ImageNet training.
 
# Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize(256),        # Resize so the shorter side is 256 pixels
    transforms.CenterCrop(224),    # Crop the center 224x224 pixels
    transforms.ToTensor(),         # Convert the image to a tensor
    # Normalize with ImageNet mean and std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
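
    As an optional sanity check (not part of the original pipeline), you can run the transform on a blank PIL image and confirm that it produces a 3-channel 224x224 tensor:

 
# Optional check: the pipeline should output a tensor of shape [3, 224, 224]
dummy = Image.new("RGB", (640, 480))
print(transform(dummy).shape)   # torch.Size([3, 224, 224])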

 

Loading ImageNet Class Labels

    The model outputs a tensor of raw scores corresponding to ImageNet class labels. We need to download these labels to interpret the output. We'll fetch the class labels from PyTorch's GitHub repository using the requests library and convert them into a Python list.

 
# URL to fetch the ImageNet class labels
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"

# Send a GET request to download the class labels
response = requests.get(url)
response.raise_for_status() # Check if the request was successful

# Convert the response text directly into a list of class labels
class_labels = [line.strip() for line in response.text.splitlines()]

# Print the first 10 class labels as a quick check
print(class_labels[:10])
 

The first ten entries of class_labels:

 
['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen', 'ostrich']

 

Loading and Preprocessing the Image

    Next, we’ll load a sample image, apply the transformations, and prepare it for the model. The image is loaded using the PIL library.

 
# Path to the local image file
image_path = "/test/sample_images/eagle.jpg" # Replace with your local image path

# Load and preprocess the image
img = Image.open(image_path).convert("RGB") # Open the image and ensure it has 3 RGB channels
img_t = transform(img) # Apply the transformations
img_t = img_t.unsqueeze(0) # Add batch dimension (required by the model)

 

 Making a Prediction

    Now that the image is ready, we can pass it through the ResNet-50 model to get predictions. The output will be a tensor of raw scores (logits) for each class. We'll use the following steps:

  1. Perform a forward pass through the network.
  2. Get the predicted class index using torch.max().
  3. Convert the predicted scores to probabilities using softmax.
  4. Map the predicted index to the corresponding class label.
 
# Forward pass through the network (no gradient tracking needed for inference)
with torch.no_grad():
    output = model(img_t)

# The output is a tensor of raw scores for each class
# Get the predicted class index
_, predicted = torch.max(output, 1)

# Convert the predicted scores to probabilities using softmax
probabilities = nn.Softmax(dim=1)(output)

# Get the predicted class label and its probability
predicted_class_label = class_labels[predicted.item()]
predicted_probability = probabilities[0, predicted].item()

# Print the result
print(f"Predicted: {predicted_class_label}, Probability: {predicted_probability:.4f}")

Finally, we’ll display the input image alongside its predicted class label and probability using matplotlib.

 
# Display the image with the predicted class and probability
plt.imshow(img)
plt.title(f'Predicted: {predicted_class_label}, Probability: {predicted_probability:.4f}')
plt.axis('off') # Hide axes for a cleaner display
plt.show()
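
    If a GPU is available, inference can be moved onto it. A minimal sketch, assuming CUDA is set up and the variables from the steps above are still in scope:

 
# Optional: run inference on a GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

with torch.no_grad():
    output = model(img_t.to(device))
 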


Conclusion

    This tutorial explained the ResNet model and showed how to use a pre-trained ResNet-50 model in PyTorch to classify an image. Here, we covered:

  • The architecture of the ResNet model
  • Loading the pre-trained ResNet-50 model
  • Preprocessing an image with the correct transformations
  • Making predictions and interpreting the results using class labels

 Complete code for this tutorial is listed below.

 

Full code listing

 
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch.nn as nn

# Define the transformations for the input image
transform = transforms.Compose([
    transforms.Resize(256),        # Resize so the shorter side is 256 pixels
    transforms.CenterCrop(224),    # Crop the center 224x224 region
    transforms.ToTensor(),         # Convert the image to a tensor
    # Normalize with ImageNet mean and std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# URL to fetch the ImageNet class labels
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"

# Send a GET request to download the class labels
response = requests.get(url)
response.raise_for_status() # Check if the request was successful

# Convert the response text directly into a list of class labels
class_labels = [line.strip() for line in response.text.splitlines()]

# Print the first 10 class labels as a quick check
print(class_labels[:10])

# Load the pre-trained ResNet model
model = models.resnet50(pretrained=True) # Using ResNet-50
model.eval() # Set the model to evaluation mode

# Path to the local image file
image_path = "/test/sample_images/eagle.jpg" # Replace with your local image path

# Load and preprocess the image
img = Image.open(image_path).convert("RGB") # Open the image and ensure it has 3 RGB channels
img_t = transform(img) # Apply the transformations
img_t = img_t.unsqueeze(0) # Add batch dimension (required by the model)

# Forward pass through the network (no gradient tracking needed for inference)
with torch.no_grad():
    output = model(img_t)

# The output is a tensor of raw scores for each class
# Get the predicted class index
_, predicted = torch.max(output, 1)

# Convert the predicted scores to probabilities using softmax
probabilities = nn.Softmax(dim=1)(output)

# Get the predicted class label and its probability
predicted_class_label = class_labels[predicted.item()]
predicted_probability = probabilities[0, predicted].item()

# Display the image along with the predicted class and its probability
plt.imshow(img)
plt.title(f'Predicted: {predicted_class_label}, Probability: {predicted_probability:.4f}')
plt.axis('off') # Hide axes for a cleaner image display
plt.show()
 

 

 


