
Training Guide

Run experiments, track results, and save outputs with Lighter.

This guide covers the full training workflow from config to results.

Basic Commands

Lighter provides four main commands:

# Train and validate
lighter fit config.yaml

# Validate only (requires checkpoint)
lighter validate config.yaml

# Test only (requires checkpoint)
lighter test config.yaml

# Run inference
lighter predict config.yaml

All commands use the same config structure.
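Each stage reads a config with the same three top-level sections. A minimal sketch (class paths such as models.MyModel stand in for your own code; a full example appears under Merging Configs below):

trainer:
  max_epochs: 10
  accelerator: auto

model:
  _target_: models.MyModel        # your LightningModule

data:
  _target_: lighter.LighterDataModule
  train_dataloader:
    batch_size: 32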

The Fit Command

Train your model with automatic validation:

lighter fit config.yaml

What Happens

  1. Loads config from YAML
  2. Instantiates trainer, model, and data
  3. Runs training loop with validation
  4. Saves checkpoints automatically
  5. Logs metrics to configured logger

Output Structure

outputs/
└── YYYY-MM-DD/
    └── HH-MM-SS/
        ├── config.yaml          # Copy of config used
        ├── checkpoints/
        │   ├── last.ckpt       # Latest checkpoint
        │   └── epoch=09-step=1000.ckpt
        └── logs/               # Tensorboard/CSV logs

Resuming Training

Resume from a saved checkpoint:

lighter fit config.yaml --ckpt_path path/to/checkpoint.ckpt

Overriding from CLI

Change any config value without editing files:

Single Override

# Change learning rate
lighter fit config.yaml model::optimizer::lr=0.01

# Train longer
lighter fit config.yaml trainer::max_epochs=100

# Use more GPUs
lighter fit config.yaml trainer::devices=4

Multiple Overrides

lighter fit config.yaml \
  model::optimizer::lr=0.01 \
  trainer::max_epochs=100 \
  data::train_dataloader::batch_size=64 \
  trainer::devices=4

Nested Overrides

# Override nested values
lighter fit config.yaml \
  model::network::num_classes=100 \
  model::optimizer::weight_decay=0.0001

Complex Overrides

# Add callbacks from CLI
lighter fit config.yaml \
  'trainer::callbacks=[{_target_: pytorch_lightning.callbacks.EarlyStopping, monitor: val/loss}]'

Merging Configs

Combine multiple YAML files:

lighter fit base.yaml,experiment.yaml

Example: Base + Experiment

base.yaml:

trainer:
  max_epochs: 100
  accelerator: auto
  devices: 1

model:
  _target_: models.MyModel
  network:
    _target_: torchvision.models.resnet18
    num_classes: 10

data:
  _target_: lighter.LighterDataModule
  train_dataloader:
    batch_size: 32

experiment.yaml:

# Override specific values
trainer:
  max_epochs: 200  # Override
  devices: 4       # Add

model:
  optimizer:
    lr: 0.01  # Add optimizer config

Result: a merged config with max_epochs=200, devices=4, and the new optimizer settings.
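Assuming a standard deep merge (values in later files override or extend those in earlier ones), the effective config is roughly:

trainer:
  max_epochs: 200
  accelerator: auto
  devices: 4

model:
  _target_: models.MyModel
  network:
    _target_: torchvision.models.resnet18
    num_classes: 10
  optimizer:
    lr: 0.01

data:
  _target_: lighter.LighterDataModule
  train_dataloader:
    batch_size: 32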

Merge Operators

Control how configs merge:

Replace with =:

# experiment.yaml
trainer:
  =callbacks:  # Replace entire list
    - _target_: pytorch_lightning.callbacks.EarlyStopping
      monitor: val/loss

Delete with ~:

# experiment.yaml
trainer:
  ~callbacks: null  # Remove callbacks entirely

data:
  ~test_dataloader: null  # Remove test dataloader

Checkpointing

Automatic Checkpointing

Lightning saves last.ckpt automatically. For more control:

trainer:
  callbacks:
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      dirpath: checkpoints
      filename: 'epoch{epoch:02d}-loss{val/loss:.4f}'
      monitor: val/loss
      mode: min
      save_top_k: 3          # Keep best 3
      save_last: true        # Keep last checkpoint
      every_n_epochs: 1      # Save every epoch

Save Based on Metric

# Save best validation accuracy
- _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: val/acc
  mode: max
  save_top_k: 1
  filename: 'best-acc{val/acc:.4f}'

Multiple Checkpointers

Save different metrics:

trainer:
  callbacks:
    # Best accuracy
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      monitor: val/acc
      mode: max
      save_top_k: 1
      filename: 'best-acc'

    # Best loss
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      monitor: val/loss
      mode: min
      save_top_k: 1
      filename: 'best-loss'

    # Regular saves
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      every_n_epochs: 10
      filename: 'epoch{epoch:02d}'

Loading Checkpoints

For validation/testing:

lighter validate config.yaml --ckpt_path checkpoints/best.ckpt
lighter test config.yaml --ckpt_path checkpoints/best.ckpt

For inference:

lighter predict config.yaml --ckpt_path checkpoints/best.ckpt

To resume training:

lighter fit config.yaml --ckpt_path checkpoints/last.ckpt

Logging

TensorBoard (Default)

trainer:
  logger:
    _target_: pytorch_lightning.loggers.TensorBoardLogger
    save_dir: logs
    name: my_experiment

View logs:

tensorboard --logdir logs

CSV Logger

trainer:
  logger:
    _target_: pytorch_lightning.loggers.CSVLogger
    save_dir: logs
    name: my_experiment

Results saved to logs/my_experiment/version_0/metrics.csv.
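For a quick look at the logged metrics you can read that CSV directly, for example with pandas (assumed installed; column names depend on what you log):

import pandas as pd

# Each row is one logging event; columns are the logged metric names.
metrics = pd.read_csv("logs/my_experiment/version_0/metrics.csv")
print(metrics.columns.tolist())
print(metrics.tail())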

Weights & Biases

trainer:
  logger:
    _target_: pytorch_lightning.loggers.WandbLogger
    project: my_project
    name: experiment_1
    save_dir: logs

Multiple Loggers

Use all at once:

trainer:
  logger:
    - _target_: pytorch_lightning.loggers.TensorBoardLogger
      save_dir: logs

    - _target_: pytorch_lightning.loggers.CSVLogger
      save_dir: logs

    - _target_: pytorch_lightning.loggers.WandbLogger
      project: my_project

No Logging

Disable logging:

trainer:
  logger: false

Saving Predictions

Use Writers to save predictions to files.

CSV Writer

Save predictions to CSV:

trainer:
  callbacks:
    - _target_: lighter.callbacks.CSVWriter
      write_interval: batch  # or 'epoch'

Your predict_step should return a dict:

def predict_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)

    return {
        "prediction": pred.argmax(dim=1),
        "probability": pred.max(dim=1).values,
        "target": y,
    }

Output: predictions.csv with columns for each key.

File Writer

Save predictions to individual files:

trainer:
  callbacks:
    - _target_: lighter.callbacks.FileWriter
      write_interval: batch

Return dict with data and filenames:

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    images, paths = batch

    predictions = self(images)

    # Save each prediction
    results = []
    for i, (pred, path) in enumerate(zip(predictions, paths)):
        results.append({
            "prediction": pred.cpu().numpy(),
            "$id": f"pred_{batch_idx}_{i}",  # Unique filename
        })

    return results

Saves: predictions/pred_0_0.npz, pred_0_1.npz, etc.
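To peek at one of the saved files (assuming the .npz layout shown above, with arrays stored under the keys returned from predict_step):

import numpy as np

data = np.load("predictions/pred_0_0.npz")
print(data.files)                    # names of the stored arrays
for name in data.files:
    print(name, data[name].shape)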

Custom Writer

Create your own:

import pickle

from lighter.callbacks import BaseWriter

class CustomWriter(BaseWriter):
    def write(self, data):
        """Save data however you want."""
        # data is what you returned from predict_step
        output_path = self.output_dir / f"{data['$id']}.pkl"

        with open(output_path, 'wb') as f:
            pickle.dump(data, f)

Use in config:

trainer:
  callbacks:
    - _target_: my_project.writers.CustomWriter
      write_interval: batch

Debugging

Fast Dev Run

Run 1 batch of train/val/test to catch bugs:

lighter fit config.yaml trainer::fast_dev_run=true

Or specify number of batches:

lighter fit config.yaml trainer::fast_dev_run=5

Overfit on Small Batch

Test if model can overfit (sanity check):

lighter fit config.yaml trainer::overfit_batches=10

Trains on same 10 batches repeatedly.

Limit Batches

Run partial epoch:

# Train on 10% of data
lighter fit config.yaml \
  trainer::limit_train_batches=0.1 \
  trainer::limit_val_batches=0.1

Or specific number:

lighter fit config.yaml trainer::limit_train_batches=100

Profiler

Profile your code:

lighter fit config.yaml trainer::profiler=simple

Options:

  • simple - Basic profiling
  • advanced - Detailed profiling
  • pytorch - PyTorch profiler

Results saved to logs directory.

Find Learning Rate

Automatically find optimal LR:

trainer:
  _target_: pytorch_lightning.Trainer
  callbacks:
    - _target_: pytorch_lightning.callbacks.LearningRateFinder
      min_lr: 1e-6
      max_lr: 1.0

Or, on Lightning 1.x, use the Trainer's built-in tuner flag (removed in Lightning 2.0 in favor of the LearningRateFinder callback above):

lighter fit config.yaml trainer::auto_lr_find=true

Multi-GPU Training

Single Machine, Multiple GPUs

trainer:
  devices: 4  # Use 4 GPUs
  strategy: ddp  # Distributed Data Parallel

Or use all available GPUs:

trainer:
  devices: -1  # All GPUs
  strategy: ddp

Strategy Options

DDP (Recommended):

trainer:
  strategy: ddp

DDP Spawn:

trainer:
  strategy: ddp_spawn

DeepSpeed:

trainer:
  strategy:
    _target_: pytorch_lightning.strategies.DeepSpeedStrategy
    stage: 2

FSDP (Fully Sharded):

trainer:
  strategy: fsdp

Batch Size Adjustment

Scale batch size with GPUs:

vars:
  num_gpus: 4
  per_gpu_batch: 32

data:
  train_dataloader:
    batch_size: "$%vars::per_gpu_batch * %vars::num_gpus"

Or keep per-GPU batch size:

# Each GPU gets batch_size=32
data:
  train_dataloader:
    batch_size: 32

Mixed Precision Training

Use 16-bit precision for faster training:

trainer:
  precision: 16

Or BFloat16:

trainer:
  precision: "bf16-mixed"

Automatic mixed precision (AMP) is handled by Lightning.
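Like any other trainer option, precision can also be switched from the CLI for a single run:

lighter fit config.yaml trainer::precision=bf16-mixed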

Gradient Accumulation

Simulate larger batch sizes:

trainer:
  accumulate_grad_batches: 4

Effective batch size = batch_size × accumulate_grad_batches.

Example:

# Effective batch size = 32 × 4 = 128
data:
  train_dataloader:
    batch_size: 32

trainer:
  accumulate_grad_batches: 4

Early Stopping

Stop training when metric stops improving:

trainer:
  callbacks:
    - _target_: pytorch_lightning.callbacks.EarlyStopping
      monitor: val/loss
      patience: 10
      mode: min
      verbose: true

Parameters:

  • monitor: Metric to track
  • patience: Epochs to wait before stopping
  • mode: min or max
  • min_delta: Minimum change to qualify as improvement
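For example, to stop when val/loss has not improved by at least 0.001 for 5 consecutive epochs:

trainer:
  callbacks:
    - _target_: pytorch_lightning.callbacks.EarlyStopping
      monitor: val/loss
      mode: min
      patience: 5
      min_delta: 0.001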

Progress Bars

Default Progress Bar

Shows by default. Disable with:

trainer:
  enable_progress_bar: false

Custom Progress Bar

trainer:
  callbacks:
    - _target_: pytorch_lightning.callbacks.RichProgressBar

Or:

trainer:
  callbacks:
    - _target_: pytorch_lightning.callbacks.TQDMProgressBar
      refresh_rate: 10

Validation

Validate Only

Run validation on a checkpoint:

lighter validate config.yaml --ckpt_path checkpoints/best.ckpt

Validation Frequency

Validate every N epochs:

trainer:
  check_val_every_n_epoch: 5

Or as a fraction of an epoch:

trainer:
  val_check_interval: 0.5  # Validate twice per epoch

Or specific number of steps:

trainer:
  val_check_interval: 100  # Every 100 training steps

Skip Validation

trainer:
  limit_val_batches: 0  # No validation

Testing

Run final test after training:

# Train (produces the checkpoints to test against)
lighter fit config.yaml

# Test separately
lighter test config.yaml --ckpt_path checkpoints/best.ckpt

Test During Fit

Not recommended, but possible: load the checkpoint saved at the end of fit and run the test command on it, as shown above.

Prediction/Inference

Run inference on data:

lighter predict config.yaml --ckpt_path checkpoints/best.ckpt

Requires:

  1. predict_step in your module
  2. predict_dataloader in your data config
  3. Optional: Writer callback to save results

Example config:

data:
  predict_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 32
    dataset:
      _target_: my_project.data.PredictionDataset
      root: ./inference_data

trainer:
  callbacks:
    - _target_: lighter.callbacks.FileWriter
      write_interval: batch

Example predict_step:

def predict_step(self, batch, batch_idx):
    images = batch
    predictions = self(images)

    return {
        "predictions": predictions.cpu(),
        "batch_idx": batch_idx,
    }

Experiment Organization

my_project/
├── __lighter__.py
├── models.py
├── data.py
├── configs/
│   ├── base.yaml           # Baseline config
│   ├── resnet50.yaml       # Architecture variants
│   ├── augmented.yaml      # Augmentation experiments
│   └── ablation/
│       ├── no_dropout.yaml
│       └── no_batchnorm.yaml
└── outputs/                # Generated by Lighter
    └── YYYY-MM-DD/
        └── HH-MM-SS/

Config Naming

Use descriptive names:

configs/
├── baseline-resnet18.yaml
├── baseline-resnet50.yaml
├── lr0.01-batch128.yaml
├── augment-strong.yaml
└── finetune-imagenet.yaml

Version Control

Track configs in git:

git add configs/
git commit -m "Add strong augmentation experiment"

Compare experiments:

git diff configs/baseline.yaml configs/improved.yaml

Common Workflows

Workflow 1: Hyperparameter Sweeps

Create configs or CLI overrides for different hyperparameters:

# Try different learning rates
lighter fit base.yaml model::optimizer::lr=0.001
lighter fit base.yaml model::optimizer::lr=0.01
lighter fit base.yaml model::optimizer::lr=0.1

# Try different architectures
lighter fit base.yaml model::network::_target_=torchvision.models.resnet18
lighter fit base.yaml model::network::_target_=torchvision.models.resnet50
lighter fit base.yaml model::network::_target_=torchvision.models.efficientnet_b0
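Sweeps like this are easy to script, for example with a shell loop:

for lr in 0.001 0.01 0.1; do
  lighter fit base.yaml model::optimizer::lr=$lr
done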

Workflow 2: Resume Failed Training

Training crashed? Resume:

lighter fit config.yaml --ckpt_path outputs/2024-01-15/10-30-45/checkpoints/last.ckpt

Workflow 3: Incremental Training

Train, then finetune:

# Initial training
lighter fit pretrain.yaml

# Finetune with lower LR
lighter fit finetune.yaml \
  --ckpt_path outputs/.../checkpoints/last.ckpt \
  model::optimizer::lr=0.0001

Workflow 4: Cross-Validation

Run multiple folds:

for fold in {0..4}; do
  lighter fit config.yaml data::fold=$fold
done

With a dataset that accepts a fold argument:

# data.py
from torch.utils.data import Dataset

class CVDataset(Dataset):
    def __init__(self, root, fold, num_folds=5):
        # Split data by fold
        ...
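A sketch of how the fold index can flow through the config, assuming a datamodule (the hypothetical my_project.data.CVDataModule below) that accepts fold and builds CVDataset internally:

data:
  _target_: my_project.data.CVDataModule  # hypothetical wrapper around CVDataset
  root: ./data
  fold: 0          # overridden per run via data::fold=$fold
  num_folds: 5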

Output Management

Change Output Directory

# In config
trainer:
  default_root_dir: ./my_outputs

Or CLI:

lighter fit config.yaml trainer::default_root_dir=./my_outputs

Disable Checkpoints

trainer:
  enable_checkpointing: false

Save Frequency

Save less often:

trainer:
  callbacks:
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      every_n_epochs: 10  # Save every 10 epochs

Or based on steps:

trainer:
  callbacks:
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      every_n_train_steps: 1000

Troubleshooting

Out of Memory

Solutions:

  1. Reduce batch size:

    lighter fit config.yaml data::train_dataloader::batch_size=16
    

  2. Use gradient accumulation:

    trainer:
      accumulate_grad_batches: 4
    

  3. Use mixed precision:

    trainer:
      precision: 16
    

  4. Reduce model size:

    lighter fit config.yaml model::network::_target_=torchvision.models.resnet18
    

Training Too Slow

Solutions:

  1. Use more workers:

    data:
      train_dataloader:
        num_workers: 8
    

  2. Pin memory:

    data:
      train_dataloader:
        pin_memory: true
    

  3. Use multiple GPUs:

    trainer:
      devices: 4
      strategy: ddp
    

  4. Mixed precision:

    trainer:
      precision: 16
    

Model Not Learning

Debug steps:

  1. Overfit on small batch:

    lighter fit config.yaml trainer::overfit_batches=10
    

  2. Check the learning rate (see Find Learning Rate above; auto_lr_find applies to Lightning 1.x only):

    lighter fit config.yaml trainer::auto_lr_find=true
    

  3. Visualize data:

    # In training_step
    x, y = batch
    if batch_idx == 0:
        self.logger.experiment.add_images("train/batch", x[:8])
    

  4. Profile:

    lighter fit config.yaml trainer::profiler=simple
    


Quick Reference

# Basic commands
lighter fit config.yaml
lighter validate config.yaml
lighter test config.yaml
lighter predict config.yaml

# Override from CLI
lighter fit config.yaml key::path=value

# Merge configs
lighter fit base.yaml,experiment.yaml

# Resume training
lighter fit config.yaml --ckpt_path path/to/last.ckpt

# Multi-GPU
lighter fit config.yaml trainer::devices=4 trainer::strategy=ddp

# Debug
lighter fit config.yaml trainer::fast_dev_run=true
lighter fit config.yaml trainer::overfit_batches=10