Image Classification
In this tutorial we will learn how to:
- Set up the project folder
- Implement a custom CNN model
- Define the config for training and testing on CIFAR10 dataset
- Train and test the model using Lighter
Setting up the Project
First, create a new project directory named image_classification
with the following structure:
Warning
Do not forget the __init__.py
files. For more details, refer to the Project Module guide.
Setting up Dataloaders
system
's dataloaders
section defines dataloaders for train
, val
, test
, and predict
stages. Let's start by configuring the training dataloader for CIFAR10.
Note
The complete configuration is provided few sections later.
system:
# ...
dataloaders:
train:
_target_: torch.utils.data.DataLoader
batch_size: 32
shuffle: True
num_workers: 4
dataset:
_target_: torchvision.datasets.CIFAR10
root: cifar10/
download: True
train: True
transform:
_target_: torchvision.transforms.Compose
transforms:
- _target_: torchvision.transforms.ToTensor
- _target_: torchvision.transforms.Normalize
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
This is equivalent to the following Python code:
import torch
import torchvision
transforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_dataset = torchvision.datasets.CIFAR10(
root="cifar10/",
download=True,
train=True,
transform=transforms
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=32,
shuffle=True,
num_workers=4
)
Setting up the Model
Defining a Custom Model
We will use a simple CNN for image classification. Define this model in image_classification/models/simple_cnn.py
.
import torch.nn as nn
class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc = nn.Linear(64 * 8 * 8, num_classes) # Assuming 32x32 images
def forward(self, x):
x = self.pool1(self.relu1(self.conv1(x)))
x = self.pool2(self.relu2(self.conv2(x)))
x = self.flatten(x)
x = self.fc(x)
return x
Reference the Custom Model in config.yaml
Now that we have defined the model, let's specify it in the config.yaml
file.
project: /path/to/image_classification
system:
model:
_target_: project.models.simple_cnn.SimpleCNN
num_classes: 10 # Matches CIFAR10 classes
The project
section tells Lighter where to import the project module from. This allows us to use our SimpleCNN
class by referencing project.models.simple_cnn.SimpleCNN
.
Complete Configuration
Now, let's put together the complete config.yaml
file for training the SimpleCNN
on CIFAR10:
project: /path/to/image_classification
trainer:
_target_: pytorch_lightning.Trainer
accelerator: "auto" # Use GPU if available, else CPU
max_epochs: 10
system:
_target_: lighter.System
model:
_target_: project.models.simple_cnn.SimpleCNN
num_classes: 10
criterion:
_target_: torch.nn.CrossEntropyLoss
optimizer:
_target_: torch.optim.Adam
lr: 1.0e-3
metrics:
train:
- _target_: torchmetrics.Accuracy
task: "multiclass"
num_classes: 10
test: "%#train"
dataloaders:
train:
_target_: torch.utils.data.DataLoader
batch_size: 32
shuffle: True
num_workers: 4
dataset:
_target_: torchvision.datasets.CIFAR10
root: cifar10/
download: True
train: True
transform:
_target_: torchvision.transforms.Compose
transforms:
- _target_: torchvision.transforms.ToTensor
- _target_: torchvision.transforms.Normalize
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
test:
_target_: torch.utils.data.DataLoader
batch_size: 32
num_workers: 4
dataset:
_target_: torchvision.datasets.CIFAR10
root: cifar10/
download: True
train: False
transform:
_target_: torchvision.transforms.Compose
transforms:
- _target_: torchvision.transforms.ToTensor
- _target_: torchvision.transforms.Normalize
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
This configuration defines all the necessary components for training and testing:
trainer
: Configures the PyTorch Lightning Trainer to use automatic accelerator selection and train for a maximum of 10 epochs.system
: Defines the Lighter System.model
: Specifies theSimpleCNN
model, a custom model you defined inimage_classification/models/simple_cnn.py
.criterion
: Sets the loss function toCrossEntropyLoss
.optimizer
: Uses theAdam
optimizer with a learning rate of 1.0e-3.metrics
: Defines accuracy metrics for training and testing stages.dataloaders
: ConfiguresDataLoader
s fortrain
andtest
stages, using the CIFAR10 dataset and appropriate transforms.
Training Execution
To start training, save the above configuration as config.yaml
in your project directory. Ensure that you have created the image_classification/models/simple_cnn.py
file as well. Then, open your terminal, navigate to your project directory, and run the following command:
Lighter will parse your config.yaml
, initialize all the components, and start the training process using PyTorch Lightning. You will see the training progress, including loss and metrics, logged in your terminal.
Evaluation
After training, you can evaluate your model on the test set:
Lighter will load the best checkpoint saved during training (if a ModelCheckpoint
callback is used in the configuration, which is often the default in more complex setups) and evaluate the model on the specified dataloader, reporting the metrics defined in the system.metrics
section for thetest
stage, respectively.
Next Steps
In this tutorial, you have successfully trained and evaluated an image classification model on the CIFAR10 dataset using Lighter.
You now have a solid foundation for building more complex experiments with Lighter. Head over to the How-To guides to explore Lighter's features in more detail.