Using the DenseNet Model for Image Classification with PyTorch

        In this tutorial, we'll learn about the DenseNet model and how to use a pre-trained DenseNet121 model for image classification with PyTorch. We'll go through the steps of loading a pre-trained model, preprocessing an image, using the model to predict its class label, and displaying the results. The tutorial covers:

  1. Introduction to DenseNet model
  2. Loading a pre-trained DenseNet121 model
  3. Defining Image Preprocessing
  4. Loading ImageNet Class Labels
  5. Loading and Preprocessing the Image
  6. Making a Prediction
  7. Conclusion
  8. Full code listing

 

Introduction to DenseNet model

    The DenseNet model is a type of deep convolutional network designed to make image recognition more efficient by reusing features. It was introduced by Huang et al. in the paper “Densely Connected Convolutional Networks” (CVPR 2017). Within each dense block, every layer is directly connected to all the layers that come after it. This strengthens feature propagation and gradient flow, encourages feature reuse, and reduces the number of parameters, resulting in a compact and efficient model.
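    To make "every layer is connected to every later layer" concrete: the DenseNet paper notes that a plain chain of L layers has L connections, while dense connectivity gives L(L+1)/2 direct connections. The short sketch below only does that arithmetic, under the assumption that the block input counts as a source feeding every layer (so layer l has l incoming connections). The function names are illustrative.

```python
def plain_connections(num_layers: int) -> int:
    """A plain feed-forward chain: each layer feeds only the next one."""
    return num_layers


def dense_connections(num_layers: int) -> int:
    """Dense connectivity: layer l (1..L) receives the block input plus the
    outputs of all l - 1 earlier layers, i.e. l incoming connections."""
    return sum(range(1, num_layers + 1))  # equals L * (L + 1) // 2


for L in (4, 6, 24):
    print(f"L={L}: plain={plain_connections(L)}, dense={dense_connections(L)}")
# L=4: plain=4, dense=10
# L=6: plain=6, dense=21
# L=24: plain=24, dense=300
```

    The quadratic growth in connections is why DenseNet keeps each layer narrow: the extra inputs are cheap only if every layer adds just a few feature maps.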


Dense Block

    A dense block is the main part of DenseNet, where layers are tightly connected. Each layer receives the feature maps of all preceding layers, concatenated along the channel dimension, and passes its own output on to every subsequent layer. Because the feature maps are concatenated rather than added together (as in ResNet), earlier features stay available unchanged, which helps the network reuse them. Each layer applies the following operations:

  • Batch Normalization: Normalizes the layer's inputs to stabilize and speed up training.
  • ReLU Activation: Adds non-linearity so the model can learn complex patterns.
  • 3x3 Convolution: Produces a small number of new feature maps (set by the growth rate) that capture local patterns.

    Because every layer can reuse the features computed before it, the model does not need to relearn redundant features, so each layer can be narrow, which keeps the parameter count low.
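    The dense block described above can be sketched in a few lines of PyTorch. This is a simplified illustration, not torchvision's implementation: `DenseLayer` and `DenseBlock` are hypothetical names, and each layer here is only BN, ReLU, 3x3 convolution. The layers in torchvision's DenseNet-121 additionally place a 1x1 "bottleneck" convolution (with 4 × growth_rate output channels) before the 3x3 convolution.

```python
import torch
import torch.nn as nn


class DenseLayer(nn.Module):
    """One densely connected layer: BN -> ReLU -> 3x3 conv that produces
    `growth_rate` new feature maps (spatial size preserved by padding=1)."""

    def __init__(self, in_channels: int, growth_rate: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


class DenseBlock(nn.Module):
    """Stack of DenseLayers; layer i sees the block input concatenated with
    the outputs of all earlier layers along the channel axis."""

    def __init__(self, num_layers: int, in_channels: int, growth_rate: int):
        super().__init__()
        self.layers = nn.ModuleList(
            [DenseLayer(in_channels + i * growth_rate, growth_rate) for i in range(num_layers)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = [x]
        for layer in self.layers:
            new_features = layer(torch.cat(features, dim=1))
            features.append(new_features)
        # The block output keeps every feature map: input + all new maps
        return torch.cat(features, dim=1)


block = DenseBlock(num_layers=4, in_channels=64, growth_rate=32)
out = block(torch.randn(2, 64, 56, 56))
print(out.shape)  # torch.Size([2, 192, 56, 56]): 64 + 4 * 32 channels
```

    Note how the input width of each layer grows by `growth_rate`, while its output width stays fixed; this is the mechanism the growth rate controls.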

 

Transition Layer

    A transition layer sits between two dense blocks and keeps the model manageable. It reduces both the number of feature maps and their spatial dimensions to control the model's growth.

  • 1x1 Convolution: Reduces the number of feature maps (by half in DenseNet-121) to save memory and computation.
  • Average Pooling: Shrinks the size of feature maps, lowering computational costs.

    These layers keep the model compact and help reduce overfitting, even as it gets deeper.
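    A transition layer can be sketched the same way. `TransitionLayer` is a hypothetical name; the BN, ReLU, 1x1 convolution, 2x2 average pooling order mirrors torchvision's DenseNet, and the compression factor of 0.5 is the value used by DenseNet-BC models such as DenseNet-121.

```python
import torch
import torch.nn as nn


class TransitionLayer(nn.Module):
    """BN -> ReLU -> 1x1 conv (fewer channels) -> 2x2 average pooling."""

    def __init__(self, in_channels: int, compression: float = 0.5):
        super().__init__()
        out_channels = int(in_channels * compression)
        self.body = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),  # fewer maps
            nn.AvgPool2d(kernel_size=2, stride=2),                            # half H and W
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


transition = TransitionLayer(in_channels=256)
out = transition(torch.randn(2, 256, 56, 56))
print(out.shape)  # torch.Size([2, 128, 28, 28])
```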

 

Key Characteristics of the DenseNet Model

  • Dense Connectivity: Each layer receives input from all preceding layers and contributes its output to all subsequent ones, promoting feature reuse and efficient learning.
  • Growth Rate: Determines the number of new feature maps each layer adds (32 in DenseNet-121), controlling model complexity.
  • Dense Block: A series of densely connected layers concatenating their outputs.
  • Transition Layers: Reduce feature map dimensions using 1x1 convolutions and average pooling, keeping the model efficient.
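  • Global Average Pooling: Averages each final feature map to a single value, so only one linear classifier layer is needed instead of several large fully connected layers, which reduces parameters and overfitting.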
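    The growth rate and transition layers together determine how wide the network gets. The sketch below traces the channel count through DenseNet-121, assuming the configuration torchvision uses for `densenet121` (64 initial feature maps, growth rate 32, dense blocks of 6, 12, 24 and 16 layers) and a transition compression of 0.5. `densenet_channels` is a hypothetical helper.

```python
def densenet_channels(num_init_features=64, growth_rate=32,
                      block_config=(6, 12, 24, 16), compression=0.5):
    """Return (stage, channels) after each dense block and transition layer."""
    channels = num_init_features
    trace = []
    for i, num_layers in enumerate(block_config):
        channels += num_layers * growth_rate  # each layer adds growth_rate maps
        trace.append((f"dense block {i + 1}", channels))
        if i < len(block_config) - 1:         # no transition after the last block
            channels = int(channels * compression)
            trace.append((f"transition {i + 1}", channels))
    return trace


for stage, channels in densenet_channels():
    print(f"{stage}: {channels} channels")
# dense block 1: 256 channels
# transition 1: 128 channels
# dense block 2: 512 channels
# transition 2: 256 channels
# dense block 3: 1024 channels
# transition 3: 512 channels
# dense block 4: 1024 channels
```

    The final 1,024 channels are what global average pooling reduces to a 1,024-value vector, which a single linear layer maps to the 1,000 ImageNet classes.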

 

Limitations   

Despite its advantages, DenseNet has some challenges:

  • High Memory Usage: Even though DenseNet uses fewer parameters, concatenating all previous feature maps can take up a lot of memory during training.
  • More Computation: Connecting every layer adds extra calculations, especially in very deep networks or when working with large images.
  • Scalability Challenges: Using DenseNet for tasks with larger images or complex data needs careful adjustments to settings like growth rate and layer design.
  • Longer Training Time: Training can take longer due to the extra connections.

  

Loading a DenseNet121 Model

    Before starting, make sure you have the following Python libraries installed:

  • torch (PyTorch)
  • torchvision (for pre-trained models and transformations)
  • Pillow (imported as PIL, the Python Imaging Library, to handle image files)
  • matplotlib (for displaying images)
  • requests (for downloading class labels)

    You can install these libraries using pip.

 
pip install torch torchvision pillow matplotlib requests 
 

    PyTorch provides a variety of pre-trained models via the torchvision library. In this tutorial, we use the DenseNet121 model, which has been pre-trained on the ImageNet dataset. The 121 refers to the number of layers with trainable weights (convolutional and fully connected layers) in the model.

    We download the model and set it to evaluation mode, which makes the batch normalization layers use their stored running statistics and disables dropout (both behave differently during training).

 
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch.nn as nn
 
 
# Load pre-trained DenseNet121 model with weights trained on ImageNet
# (the weights argument replaces the deprecated pretrained=True flag)
model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)

# Set the model to evaluation mode
model.eval()
 

    

Defining Image Preprocessing

    To use the DenseNet model, the input image needs to be preprocessed in the same way as the images the model was trained on. For DenseNet, this includes resizing, center-cropping, and normalizing the image. We’ll use torchvision.transforms to define the following transformations:

  1. Resize the image so that its shorter side is 256 pixels (the aspect ratio is preserved).
  2. Center-crop the image to 224x224 pixels (DenseNet's input size).
  3. Convert the image to a tensor.
  4. Normalize the image with the same mean and standard deviation used in ImageNet training.
 
# Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize(256),      # Resize so the shorter side is 256 pixels
    transforms.CenterCrop(224),  # Crop the center 224x224 pixels
    transforms.ToTensor(),       # Convert to a float tensor in [0, 1], shape [C, H, W]
    # Normalize with ImageNet mean and std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

 

Loading ImageNet Class Labels

    The model outputs a tensor of 1,000 raw scores (logits), one for each ImageNet class. To interpret the output, we need the matching list of class labels. We'll fetch the labels from PyTorch's GitHub repository using the requests library and convert them into a Python list.

    Once you download the class label data, you can save it to a file and use it locally.

 
# URL to fetch the ImageNet class labels
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"

# Send a GET request to download the class labels
response = requests.get(url, timeout=10)  # Fail instead of hanging if the server does not respond
response.raise_for_status() # Check if the request was successful

# Convert the response text directly into a list of class labels
class_labels = [line.strip() for line in response.text.splitlines()]

# Print the first 10 class labels as a quick check
print(class_labels[:10])
 

The first 10 class labels printed:

 
['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen', 'ostrich']
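    To follow the earlier tip about saving the labels and using them locally, a small helper can download the file once and read it from disk afterwards. This is a minimal sketch: `load_class_labels` and the default cache filename `imagenet_classes.txt` are hypothetical choices, not part of any library.

```python
from pathlib import Path

import requests

LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"


def load_class_labels(cache_path="imagenet_classes.txt", url=LABELS_URL):
    """Read class labels from a local file, downloading and saving them
    first if the file does not exist yet."""
    path = Path(cache_path)
    if not path.exists():
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        path.write_text(response.text, encoding="utf-8")
    lines = path.read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]  # skip blank lines
```

    Calling `class_labels = load_class_labels()` downloads the file on the first run only; later runs read the saved copy and work offline.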

 

Loading and Preprocessing the Image

    Next, we’ll load a sample image, apply the transformations, and prepare it for the model. The image is loaded using the PIL library.

 
# Path to the local image file
image_path = "/data/sample_images/IMG_1049.JPG" # Replace with your local image path

# Load and preprocess the image
img = Image.open(image_path).convert("RGB")  # Open the image and ensure 3 RGB channels (e.g. for PNG or grayscale files)
img_t = transform(img) # Apply the transformations
img_t = img_t.unsqueeze(0) # Add batch dimension (required by the model)

 

Making a Prediction

    With the image ready, we can pass it through the DenseNet model to get predictions. The output will be a tensor of raw scores for each class. We’ll use the following steps:

  1. Perform a forward pass through the network.
  2. Get the predicted class index using torch.max().
  3. Convert the predicted scores to probabilities using softmax.
  4. Map the predicted index to the corresponding class label.
 
# Forward pass through the network; torch.no_grad() skips gradient
# tracking, which inference does not need, and saves memory
with torch.no_grad():
    output = model(img_t)

# The output is a tensor of raw scores for each class (shape [1, 1000])
# Get the predicted class index
_, predicted = torch.max(output, 1)

# Convert the raw scores to probabilities using softmax
probabilities = nn.Softmax(dim=1)(output)

# Get the predicted class label and its probability
predicted_class_label = class_labels[predicted.item()]
predicted_probability = probabilities[0, predicted].item()

# Print the result
print(f"Predicted: {predicted_class_label}, Probability: {predicted_probability:.4f}")
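    A single top-1 label can hide how confident the model is about close alternatives. A common extension is to list the top-k classes with `torch.topk`. The sketch below uses a hypothetical helper `top_k_predictions` and illustrative logits for a made-up 4-class problem so it runs on its own.

```python
import torch


def top_k_predictions(output, class_labels, k=5):
    """Return the k most likely (label, probability) pairs for the first
    image in a batch of logits with shape [N, num_classes]."""
    probabilities = torch.softmax(output, dim=1)
    top_probs, top_idx = torch.topk(probabilities[0], k)  # sorted, largest first
    return [(class_labels[i], p) for i, p in zip(top_idx.tolist(), top_probs.tolist())]


# Illustrative logits for a hypothetical 4-class problem
labels = ["cat", "dog", "fox", "owl"]
logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
for label, prob in top_k_predictions(logits, labels, k=3):
    print(f"{label}: {prob:.4f}")
```

    With the tutorial's model, `top_k_predictions(output, class_labels)` would return the five most likely ImageNet classes for the loaded image.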

Finally, we’ll display the input image alongside its predicted class label and probability using matplotlib.

 
# Display the image with the predicted class and probability
plt.imshow(img)
plt.title(f'Predicted: {predicted_class_label}, Probability: {predicted_probability:.4f}')
plt.axis('off') # Hide axes for a cleaner display
plt.show()






Conclusion

    This tutorial explained the DenseNet model and showed how to use a pre-trained DenseNet121 model in PyTorch to classify an image. Here, we learned:

  • The architecture of the DenseNet model
  • Loading the DenseNet121 model
  • Preprocessing an image with the correct transformations
  • Making predictions and interpreting the results using class labels

 Complete code for this tutorial is listed below.

 

Full code listing

 
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch.nn as nn

# Define the transformations for the input image
transform = transforms.Compose([
    transforms.Resize(256),      # Resize so the shorter side is 256 pixels
    transforms.CenterCrop(224),  # Crop the center 224x224 region
    transforms.ToTensor(),       # Convert to a float tensor in [0, 1], shape [C, H, W]
    # Normalize with ImageNet mean and std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# URL to fetch the ImageNet class labels
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"

# Send a GET request to download the class labels
response = requests.get(url, timeout=10)  # Fail instead of hanging if the server does not respond
response.raise_for_status() # Check if the request was successful

# Convert the response text directly into a list of class labels
class_labels = [line.strip() for line in response.text.splitlines()]

# Print the first 10 class labels as a quick check
print(class_labels[:10])
 
# Path to the local image file, replace with your local image path
image_path = "/Users/user/data/sample_images/IMG_1049.JPG"
 
# Load and preprocess the image
img = Image.open(image_path).convert("RGB")  # Open the image and ensure 3 RGB channels
img_t = transform(img) # Apply the transformations
img_t = img_t.unsqueeze(0) # Add batch dimension (required by the model)
 
# Load pre-trained DenseNet121 model with weights trained on ImageNet
# (the weights argument replaces the deprecated pretrained=True flag)
model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)

# Set the model to evaluation mode
model.eval()
 
# Forward pass through the network (no gradient tracking needed for inference)
with torch.no_grad():
    output = model(img_t)

# The output is a tensor of raw scores for each class
# Get the predicted class index
_, predicted = torch.max(output, 1)

# Convert the raw scores to probabilities using softmax
probabilities = nn.Softmax(dim=1)(output)

# Get the predicted class label and its probability
predicted_class_label = class_labels[predicted.item()]
predicted_probability = probabilities[0, predicted].item()

# Display the image along with the predicted class and its probability
plt.imshow(img)
plt.title(f'Predicted: {predicted_class_label}, Probability: {predicted_probability:.4f}')
plt.axis('off') # Hide axes for a cleaner image display
plt.show()
 

 

 




