Image Classification
In this tutorial we will learn how to:
- Set up the project folder
- Implement a custom CNN model
- Define the config for training and testing on CIFAR10 dataset
- Train and test the model using Lighter
Setting up the Project
First, create a new project directory named image_classification
with the following structure:
image_classification/
├── __init__.py
├── experiments/
│   └── config.yaml
└── models/
    ├── __init__.py
    └── simple_cnn.py
If you're using Unix/Linux, you can create this structure with the following command:
mkdir -p image_classification/{models,experiments} && touch image_classification/__init__.py image_classification/experiments/config.yaml image_classification/models/__init__.py image_classification/models/simple_cnn.py
Warning
Do not forget to add __init__.py
to folders that contain Python modules. For more details, refer to the Project Module guide.
Setting up Dataloaders
The system's dataloaders section defines dataloaders for the train, val, test, and predict stages. Let's start by configuring the training dataloader for CIFAR10.
Note
The complete configuration is provided a few sections later.
system:
  # ...
  dataloaders:
    train:
      _target_: torch.utils.data.DataLoader
      batch_size: 32
      shuffle: True
      num_workers: 4
      dataset:
        _target_: torchvision.datasets.CIFAR10
        root: cifar10/
        download: True
        train: True
        transform:
          _target_: torchvision.transforms.Compose
          transforms:
            - _target_: torchvision.transforms.ToTensor
            - _target_: torchvision.transforms.Normalize
              mean: [0.5, 0.5, 0.5]
              std: [0.5, 0.5, 0.5]
This is equivalent to the following Python code:
import torch
import torchvision

transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = torchvision.datasets.CIFAR10(
    root="cifar10/",
    download=True,
    train=True,
    transform=transforms
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
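Dataloaders for the other stages follow the same pattern under system's dataloaders section. As an illustrative sketch (not part of the final configuration shown later), a validation dataloader could be added like this; note that CIFAR10 ships only train and test splits, so the test split stands in for validation here, whereas in practice you would typically hold out part of the training data:

system:
  # ...
  dataloaders:
    # ...
    val:
      _target_: torch.utils.data.DataLoader
      batch_size: 32
      shuffle: False
      num_workers: 4
      dataset:
        _target_: torchvision.datasets.CIFAR10
        root: cifar10/
        download: True
        train: False # CIFAR10 has no dedicated validation split; the test split is reused here purely for illustration
        transform:
          _target_: torchvision.transforms.Compose
          transforms:
            - _target_: torchvision.transforms.ToTensor
            - _target_: torchvision.transforms.Normalize
              mean: [0.5, 0.5, 0.5]
              std: [0.5, 0.5, 0.5]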
Setting up the Model
Defining a Custom Model
We will use a simple CNN for image classification. Define this model in image_classification/models/simple_cnn.py.
import torch.nn as nn


class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # CIFAR10 images are 32x32; two 2x2 poolings reduce them to 8x8 feature maps with 64 channels
        self.fc = nn.Linear(64 * 8 * 8, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.fc(x)
        return x
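As an optional sanity check (this snippet is illustrative and not part of the project files), you can confirm that the model maps CIFAR10-sized inputs to one logit per class. Run it from inside the image_classification/ directory:

import torch

from models.simple_cnn import SimpleCNN

model = SimpleCNN(num_classes=10)
dummy_batch = torch.randn(8, 3, 32, 32)  # a fake batch of eight 32x32 RGB images
logits = model(dummy_batch)
print(logits.shape)  # expected: torch.Size([8, 10])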
Reference the Custom Model in config.yaml
Now that we have defined the model, let's specify it in the config.yaml
file.
project: /path/to/image_classification
system:
  model:
    _target_: project.models.simple_cnn.SimpleCNN
    num_classes: 10 # Matches CIFAR10 classes
The project setting specifies the root directory that Lighter will import as a Python module, allowing access to custom code and models. This is why we needed to set up __init__.py files. As a result, we can reference the custom model we defined simply as project.models.simple_cnn.SimpleCNN.
Note
Learn more about the Project Module in the How-To guide.
Warning
Change /path/to/image_classification
to the location of your project directory. Both absolute and relative paths work. To get the absolute path of your image_classification/
directory, run:
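realpath image_classification   # or: cd image_classification && pwd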
If you're using a relative path, pay attention to where you are executing lighter
from. For example, if running lighter
while in the image_classification/
folder, you can simply set project: .
to reference the current directory.
Complete Configuration
Now, let's put together the complete config.yaml
file for training the SimpleCNN
on CIFAR10:
project: /path/to/image_classification # Update as noted above

trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: "auto" # Use GPU if available, else CPU
  max_epochs: 10

system:
  _target_: lighter.System
  model:
    _target_: project.models.simple_cnn.SimpleCNN
    num_classes: 10
  criterion:
    _target_: torch.nn.CrossEntropyLoss
  optimizer:
    _target_: torch.optim.Adam
    params: "$@system#model.parameters()" # Link to model's learnable parameters
    lr: 1.0e-3
  metrics:
    train:
      - _target_: torchmetrics.Accuracy
        task: "multiclass"
        num_classes: 10
    test: "%#train"
  dataloaders:
    train:
      _target_: torch.utils.data.DataLoader
      batch_size: 32
      shuffle: True
      num_workers: 4
      dataset:
        _target_: torchvision.datasets.CIFAR10
        root: cifar10/
        download: True
        train: True
        transform:
          _target_: torchvision.transforms.Compose
          transforms:
            - _target_: torchvision.transforms.ToTensor
            - _target_: torchvision.transforms.Normalize
              mean: [0.5, 0.5, 0.5]
              std: [0.5, 0.5, 0.5]
    test:
      _target_: torch.utils.data.DataLoader
      batch_size: 32
      num_workers: 4
      dataset:
        _target_: torchvision.datasets.CIFAR10
        root: cifar10/
        download: True
        train: False
        transform:
          _target_: torchvision.transforms.Compose
          transforms:
            - _target_: torchvision.transforms.ToTensor
            - _target_: torchvision.transforms.Normalize
              mean: [0.5, 0.5, 0.5]
              std: [0.5, 0.5, 0.5]
This configuration defines all the necessary components for training and testing:
- trainer: Configures the PyTorch Lightning Trainer to use automatic accelerator selection and train for a maximum of 10 epochs.
- system: Defines the Lighter System.
- model: Specifies the SimpleCNN model, a custom model you defined in image_classification/models/simple_cnn.py.
- criterion: Sets the loss function to CrossEntropyLoss.
- optimizer: Uses the Adam optimizer with a learning rate of 1.0e-3.
- metrics: Defines accuracy metrics for the training and testing stages (the test entry reuses the train definition, as expanded below).
- dataloaders: Configures DataLoaders for the train and test stages, using the CIFAR10 dataset and appropriate transforms.
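The test: "%#train" entry is a configuration reference that reuses the metric list defined under train, so the metrics block above behaves as if the accuracy metric were written out for both stages explicitly:

metrics:
  train:
    - _target_: torchmetrics.Accuracy
      task: "multiclass"
      num_classes: 10
  test:
    - _target_: torchmetrics.Accuracy
      task: "multiclass"
      num_classes: 10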
Training Execution
To start training, save the above configuration to experiments/config.yaml in your project directory (the file created earlier) and ensure that image_classification/models/simple_cnn.py exists as well. Then open your terminal, navigate to your project directory, and run the training command (shown below for a recent Lighter release; the exact config argument may differ slightly between versions):
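# Adjust the config path/argument if your Lighter version expects a different form.
lighter fit experiments/config.yaml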
Lighter will parse your config.yaml, initialize all the components, and start the training process using PyTorch Lightning. You will see the training progress, including loss and metrics, logged in your terminal.
Evaluation
After training, you can evaluate your model on the test set by running the test stage (again, the exact config argument may vary between Lighter versions):
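# Adjust the config path/argument if your Lighter version expects a different form.
lighter test experiments/config.yaml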
Lighter will evaluate the model on the test dataloader and report the metrics defined for the test stage in the system.metrics section. If your configuration includes a ModelCheckpoint callback (as is common in more complete setups), the best checkpoint saved during training can be loaded for this evaluation.
Next Steps
In this tutorial, you have successfully trained and evaluated an image classification model on the CIFAR10 dataset using Lighter.
You now have a solid foundation for building more complex experiments with Lighter. Head over to the How-To guides to explore Lighter's features in more detail.