Index

Lighter is a framework for streamlining deep learning experiments with configuration files.

LighterDataModule

Bases: LightningDataModule

A lightweight wrapper for organizing dataloaders in configuration files.

This class exists purely as a convenience helper - it wraps pre-configured PyTorch DataLoaders so you can use Lighter's configuration system without having to write a custom LightningDataModule from scratch.

When to use LighterDataModule:

- Simple datasets that don't need complex preprocessing
- Quick experiments where you want to configure dataloaders in YAML
- Cases where your data pipeline is straightforward

When to write a custom LightningDataModule:

- Complex data preparation (downloading, extraction, processing)
- Multi-process data setup with prepare_data() and setup()
- Advanced preprocessing pipelines
- Data that requires stage-specific transformations
- Sharing reusable data modules across projects

Parameters:

    train_dataloader (DataLoader | None, default None): DataLoader for training (used in fit stage)
    val_dataloader (DataLoader | None, default None): DataLoader for validation (used in fit and validate stages)
    test_dataloader (DataLoader | None, default None): DataLoader for testing (used in test stage)
    predict_dataloader (DataLoader | None, default None): DataLoader for predictions (used in predict stage)
Example
# config.yaml
data:
  _target_: lighter.LighterDataModule
  train_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 32
    shuffle: true
    dataset:
      _target_: torchvision.datasets.CIFAR10
      root: ./data
      train: true
      transform:
        _target_: torchvision.transforms.ToTensor
  val_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 32
    shuffle: false
    dataset:
      _target_: torchvision.datasets.CIFAR10
      root: ./data
      train: false
      transform:
        _target_: torchvision.transforms.ToTensor

model:
  _target_: project.MyModel
  network: ...
  optimizer: ...

trainer:
  _target_: pytorch_lightning.Trainer
  max_epochs: 10
Note

This is just a thin wrapper around PyTorch Lightning's LightningDataModule. It doesn't add any special logic - it simply holds your dataloaders and returns them when Lightning asks for them.

If you need more control (prepare_data, setup, etc.), write a custom LightningDataModule instead.
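
For illustration, a minimal sketch of building the same setup directly in Python instead of YAML. The dataset, transform, and batch size mirror the example below; download=True is added here only for convenience and is not part of that example.

# Hedged sketch: Python equivalent of the YAML config example.
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from lighter import LighterDataModule

train_loader = DataLoader(
    CIFAR10(root="./data", train=True, transform=ToTensor(), download=True),
    batch_size=32, shuffle=True,
)
val_loader = DataLoader(
    CIFAR10(root="./data", train=False, transform=ToTensor(), download=True),
    batch_size=32, shuffle=False,
)
datamodule = LighterDataModule(train_dataloader=train_loader, val_dataloader=val_loader)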

Source code in src/lighter/data.py
class LighterDataModule(LightningDataModule):
    """
    A lightweight wrapper for organizing dataloaders in configuration files.

    This class exists purely as a convenience helper - it wraps pre-configured
    PyTorch DataLoaders so you can use Lighter's configuration system without
    having to write a custom LightningDataModule from scratch.

    When to use LighterDataModule:
    - Simple datasets that don't need complex preprocessing
    - Quick experiments where you want to configure dataloaders in YAML
    - Cases where your data pipeline is straightforward

    When to write a custom LightningDataModule:
    - Complex data preparation (downloading, extraction, processing)
    - Multi-process data setup with prepare_data() and setup()
    - Advanced preprocessing pipelines
    - Data that requires stage-specific transformations
    - Sharing reusable data modules across projects

    Args:
        train_dataloader: DataLoader for training (used in fit stage)
        val_dataloader: DataLoader for validation (used in fit and validate stages)
        test_dataloader: DataLoader for testing (used in test stage)
        predict_dataloader: DataLoader for predictions (used in predict stage)

    Example:
        ```yaml
        # config.yaml
        data:
          _target_: lighter.LighterDataModule
          train_dataloader:
            _target_: torch.utils.data.DataLoader
            batch_size: 32
            shuffle: true
            dataset:
              _target_: torchvision.datasets.CIFAR10
              root: ./data
              train: true
              transform:
                _target_: torchvision.transforms.ToTensor
          val_dataloader:
            _target_: torch.utils.data.DataLoader
            batch_size: 32
            shuffle: false
            dataset:
              _target_: torchvision.datasets.CIFAR10
              root: ./data
              train: false
              transform:
                _target_: torchvision.transforms.ToTensor

        model:
          _target_: project.MyModel
          network: ...
          optimizer: ...

        trainer:
          _target_: pytorch_lightning.Trainer
          max_epochs: 10
        ```

    Note:
        This is just a thin wrapper around PyTorch Lightning's LightningDataModule.
        It doesn't add any special logic - it simply holds your dataloaders and
        returns them when Lightning asks for them.

        If you need more control (prepare_data, setup, etc.), write a custom
        LightningDataModule instead.
    """

    def __init__(
        self,
        train_dataloader: DataLoader | None = None,
        val_dataloader: DataLoader | None = None,
        test_dataloader: DataLoader | None = None,
        predict_dataloader: DataLoader | None = None,
    ) -> None:
        super().__init__()
        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader
        self._test_dataloader = test_dataloader
        self._predict_dataloader = predict_dataloader

    def train_dataloader(self) -> DataLoader | None:
        """Return the training dataloader."""
        return self._train_dataloader

    def val_dataloader(self) -> DataLoader | None:
        """Return the validation dataloader."""
        return self._val_dataloader

    def test_dataloader(self) -> DataLoader | None:
        """Return the test dataloader."""
        return self._test_dataloader

    def predict_dataloader(self) -> DataLoader | None:
        """Return the prediction dataloader."""
        return self._predict_dataloader

predict_dataloader()

Return the prediction dataloader.

Source code in src/lighter/data.py
def predict_dataloader(self) -> DataLoader | None:
    """Return the prediction dataloader."""
    return self._predict_dataloader

test_dataloader()

Return the test dataloader.

Source code in src/lighter/data.py
def test_dataloader(self) -> DataLoader | None:
    """Return the test dataloader."""
    return self._test_dataloader

train_dataloader()

Return the training dataloader.

Source code in src/lighter/data.py
def train_dataloader(self) -> DataLoader | None:
    """Return the training dataloader."""
    return self._train_dataloader

val_dataloader()

Return the validation dataloader.

Source code in src/lighter/data.py
def val_dataloader(self) -> DataLoader | None:
    """Return the validation dataloader."""
    return self._val_dataloader

LighterModule

Bases: LightningModule

Minimal base class for deep learning models in Lighter.

Users should:

- Subclass and implement the step methods they need (training_step, validation_step, etc.)
- Define their own batch processing, loss computation, and metric updates
- Configure data separately using the 'data:' config key

Framework provides:

- Automatic dual logging of losses (step + epoch)
- Automatic dual logging of metrics (step + epoch)
- Optimizer configuration

Parameters:

    network (Module, required): Neural network model
    criterion (Callable | None, default None): Loss function (optional, user can compute loss manually in step)
    optimizer (Optimizer | None, default None): Optimizer (required for training)
    scheduler (LRScheduler | None, default None): Learning rate scheduler (optional)
    train_metrics (Metric | MetricCollection | None, default None): Training metrics (optional, user calls them in step)
    val_metrics (Metric | MetricCollection | None, default None): Validation metrics (optional)
    test_metrics (Metric | MetricCollection | None, default None): Test metrics (optional)
Example

class MyModel(LighterModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)

        # Option 1: Use self.criterion if provided
        loss = self.criterion(pred, y) if self.criterion else F.cross_entropy(pred, y)

        # User calls metrics themselves
        if self.train_metrics:
            self.train_metrics(pred, y)

        return {"loss": loss, "pred": pred, "target": y}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.criterion(pred, y) if self.criterion else F.cross_entropy(pred, y)
        if self.val_metrics:
            self.val_metrics(pred, y)
        return {"loss": loss, "pred": pred, "target": y}

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        if self.test_metrics:
            self.test_metrics(pred, y)
        return {"pred": pred, "target": y}

    def predict_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        return pred
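
To complement the subclass above, a hedged sketch of constructing it directly in Python; the network, optimizer, and metric choices are illustrative placeholders, not defaults.

# Assumes the MyModel subclass from the example above is defined.
import torch
import torchmetrics

net = torch.nn.Linear(28 * 28, 10)
model = MyModel(
    network=net,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(net.parameters(), lr=1e-3),
    train_metrics=torchmetrics.Accuracy(task="multiclass", num_classes=10),
)
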
Source code in src/lighter/model.py
class LighterModule(pl.LightningModule):
    """
    Minimal base class for deep learning models in Lighter.

    Users should:
    - Subclass and implement the step methods they need (training_step, validation_step, etc.)
    - Define their own batch processing, loss computation, metric updates
    - Configure data separately using the 'data:' config key

    Framework provides:
    - Automatic dual logging of losses (step + epoch)
    - Automatic dual logging of metrics (step + epoch)
    - Optimizer configuration

    Args:
        network: Neural network model
        criterion: Loss function (optional, user can compute loss manually in step)
        optimizer: Optimizer (required for training)
        scheduler: Learning rate scheduler (optional)
        train_metrics: Training metrics (optional, user calls them in step)
        val_metrics: Validation metrics (optional)
        test_metrics: Test metrics (optional)

    Example:
        class MyModel(LighterModule):
            def training_step(self, batch, batch_idx):
                x, y = batch
                pred = self(x)

                # Option 1: Use self.criterion if provided
                loss = self.criterion(pred, y) if self.criterion else F.cross_entropy(pred, y)

                # User calls metrics themselves
                if self.train_metrics:
                    self.train_metrics(pred, y)

                return {"loss": loss, "pred": pred, "target": y}

            def validation_step(self, batch, batch_idx):
                x, y = batch
                pred = self(x)
                loss = self.criterion(pred, y) if self.criterion else F.cross_entropy(pred, y)
                if self.val_metrics:
                    self.val_metrics(pred, y)
                return {"loss": loss, "pred": pred, "target": y}

            def test_step(self, batch, batch_idx):
                x, y = batch
                pred = self(x)
                if self.test_metrics:
                    self.test_metrics(pred, y)
                return {"pred": pred, "target": y}

            def predict_step(self, batch, batch_idx):
                x, y = batch
                pred = self(x)
                return pred
    """

    def __init__(
        self,
        network: Module,
        criterion: Callable | None = None,
        optimizer: Optimizer | None = None,
        scheduler: LRScheduler | None = None,
        train_metrics: Metric | MetricCollection | None = None,
        val_metrics: Metric | MetricCollection | None = None,
        test_metrics: Metric | MetricCollection | None = None,
    ) -> None:
        super().__init__()

        # Core components
        self.network = network
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler

        # Metrics (registered as modules)
        self.train_metrics = self._prepare_metrics(train_metrics)
        self.val_metrics = self._prepare_metrics(val_metrics)
        self.test_metrics = self._prepare_metrics(test_metrics)

    def _prepare_metrics(self, metrics: Metric | MetricCollection | None) -> Metric | MetricCollection | None:
        """Validate metrics - must be Metric or MetricCollection."""
        if metrics is None:
            return None

        if isinstance(metrics, (Metric, MetricCollection)):
            return metrics

        raise TypeError(
            f"metrics must be Metric or MetricCollection, got {type(metrics).__name__}.\n\n"
            f"Single metric:\n"
            f"  train_metrics:\n"
            f"    _target_: torchmetrics.Accuracy\n"
            f"    task: multiclass\n\n"
            f"Multiple metrics:\n"
            f"  train_metrics:\n"
            f"    _target_: torchmetrics.MetricCollection\n"
            f"    metrics:\n"
            f"      - _target_: torchmetrics.Accuracy\n"
            f"        task: multiclass\n"
            f"      - _target_: torchmetrics.F1Score\n"
            f"        task: multiclass"
        )

    # ============================================================================
    # Step Methods - Override as Needed
    # ============================================================================

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, Any]:
        """
        Define training logic.

        User responsibilities:
        - Extract data from batch
        - Call self(input) for forward pass
        - Compute loss
        - Call self.train_metrics(pred, target) if configured
        - Return loss tensor or dict with 'loss' key

        Framework automatically logs loss and metrics.

        Returns:
            Either:
                - Tensor: The loss value (simplest option)
                - Dict with required 'loss' key and optional keys:
                    - pred: Model predictions (for callbacks)
                    - target: Target labels (for callbacks)
                    - input: Input data (for callbacks)
                    - Any other keys you need
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement training_step() to use trainer.fit(). "
            f"See https://project-lighter.github.io/lighter/guides/lighter-module/"
        )

    def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, Any]:
        """
        Define validation logic.

        Similar to training_step but typically without gradients.
        Call self.val_metrics(pred, target) if configured.

        Returns:
            Either:
                - Tensor: The loss value
                - Dict with 'loss' key
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement validation_step() to use validation. "
            f"See https://project-lighter.github.io/lighter/guides/lighter-module/"
        )

    def test_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, Any]:
        """
        Define test logic.

        Loss is optional. Call self.test_metrics(pred, target) if configured.

        Returns:
            Either:
                - Tensor: The loss value (optional in test mode)
                - Dict with optional 'loss' key. Can include pred, target, etc.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement test_step() to use trainer.test(). "
            f"See https://project-lighter.github.io/lighter/guides/lighter-module/"
        )

    def predict_step(self, batch: Any, batch_idx: int) -> Any:
        """
        Define prediction logic.

        User responsibilities:
        - Extract data from batch
        - Call self(input) for forward pass
        - Return predictions in desired format

        No automatic logging happens in predict mode.
        Return any format you need (tensor, dict, list, etc.).
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement predict_step() to use trainer.predict(). "
            f"See https://project-lighter.github.io/lighter/guides/lighter-module/"
        )

    # ============================================================================
    # Forward Pass - Simple Delegation
    # ============================================================================

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """
        Forward pass - simply delegates to self.network.

        Override if you need custom forward logic.
        """
        return self.network(*args, **kwargs)

    # ============================================================================
    # Batch-End Hooks - Automatic Logging
    # ============================================================================

    def _on_batch_end(self, outputs: torch.Tensor | dict[str, Any], batch_idx: int) -> None:
        """Common batch-end logic for all modes."""
        outputs = self._normalize_output(outputs)
        self._log_outputs(outputs, batch_idx)

    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
        """Framework hook - automatically logs training outputs."""
        self._on_batch_end(outputs, batch_idx)

    def on_validation_batch_end(
        self,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Framework hook - automatically logs validation outputs."""
        self._on_batch_end(outputs, batch_idx)

    def on_test_batch_end(
        self,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Framework hook - automatically logs test outputs."""
        self._on_batch_end(outputs, batch_idx)

    def _normalize_output(self, output: torch.Tensor | dict[str, Any]) -> dict[str, Any]:
        """
        Normalize step output to dict format.

        Args:
            output: Either:
                - torch.Tensor: Loss value (normalized to {"loss": tensor})
                - dict: Must contain outputs. Can include:
                    - "loss": torch.Tensor or dict with "total" key
                    - "pred", "target", "input": Additional data for callbacks

        Returns:
            Dict with normalized structure

        Raises:
            TypeError: If output is neither Tensor nor dict
            ValueError: If loss dict is missing 'total' key
        """
        if isinstance(output, torch.Tensor):
            return {"loss": output}
        elif isinstance(output, dict):
            # Validate loss structure if present
            if "loss" in output and isinstance(output["loss"], dict):
                if "total" not in output["loss"]:
                    raise ValueError(
                        f"Loss dict must include 'total' key. "
                        f"Got keys: {list(output['loss'].keys())}. "
                        f"Example: {{'loss': {{'total': combined, 'ce': ce_loss, 'reg': reg_loss}}}}"
                    )
            return output
        else:
            raise TypeError(
                f"Step method must return torch.Tensor or dict. "
                f"Got {type(output).__name__} instead. "
                f"Examples:\n"
                f"  - return loss  # Simple tensor\n"
                f'  - return {{"loss": loss, "pred": pred}}'
            )

    def _log_outputs(self, outputs: dict[str, Any], batch_idx: int) -> None:
        """
        Log all outputs from a step.

        Override this method to customize logging behavior.
        Default: dual logging (step + epoch) for loss and metrics.

        Args:
            outputs: Dict from user's step method
            batch_idx: Current batch index
        """
        if self.trainer.logger is None:
            return
        self._log_loss(outputs.get("loss"))
        self._log_metrics()
        self._log_optimizer_stats(batch_idx)

    def _log_loss(self, loss: torch.Tensor | dict[str, Any] | None) -> None:
        """
        Log loss with dual pattern (step + epoch).

        Args:
            loss: Loss tensor or dict from step method.
                If dict, must have 'total' key (validated in _normalize_output).
        """
        if loss is None:
            return

        # Log scalar or dict
        if isinstance(loss, dict):
            for name, value in loss.items():
                name = f"{self.mode}/loss/{name}"
                self._log(name, value, on_step=True)
                self._log(name, value, on_epoch=True, sync_dist=True)
        else:
            name = f"{self.mode}/loss"
            self._log(name, loss, on_step=True)
            self._log(name, loss, on_epoch=True, sync_dist=True)

    def _log_metrics(self) -> None:
        """
        Log metrics with dual pattern (step + epoch).

        User already called metrics in their step method.
        Handles both single Metric and MetricCollection.
        """
        metrics = getattr(self, f"{self.mode}_metrics", None)
        if metrics is None:
            return

        if isinstance(metrics, MetricCollection):
            # MetricCollection - iterate over named metrics
            for name, metric in metrics.items():
                name = f"{self.mode}/metrics/{name}"
                self._log(name, metric, on_step=True)
                self._log(name, metric, on_epoch=True, sync_dist=True)
        else:
            # Single Metric - use class name (consistent with MetricCollection auto-naming)
            name = f"{self.mode}/metrics/{metrics.__class__.__name__}"
            self._log(name, metrics, on_step=True)
            self._log(name, metrics, on_epoch=True, sync_dist=True)

    def _log_optimizer_stats(self, batch_idx: int) -> None:
        """
        Log optimizer stats once per epoch in train mode.

        Args:
            batch_idx: Current batch index
        """
        if self.mode != Mode.TRAIN or batch_idx != 0 or self.optimizer is None:
            return

        # Optimizer stats only logged per epoch
        for name, stat in get_optimizer_stats(self.optimizer).items():
            name = f"{self.mode}/{name}"
            self._log(name, stat, on_epoch=True, sync_dist=False)

    def _log(self, name: str, value: Any, on_step: bool = False, on_epoch: bool = False, sync_dist: bool = False) -> None:
        suffix = "step" if on_step and not on_epoch else "epoch"
        self.log(
            f"{name}/{suffix}",
            value,
            logger=True,
            on_step=on_step,
            on_epoch=on_epoch,
            sync_dist=sync_dist,
        )

    # ============================================================================
    # Lightning Optimizer Configuration
    # ============================================================================

    def configure_optimizers(self):
        """Configure optimizer and scheduler."""
        if self.optimizer is None:
            raise ValueError("Optimizer not configured.")

        if self.scheduler is None:
            return {"optimizer": self.optimizer}
        else:
            return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler}

    # ============================================================================
    # Properties
    # ============================================================================

    @property
    def mode(self) -> str:
        """
        Current execution mode.

        Returns:
            "train", "val", "test", or "predict"

        Raises:
            RuntimeError: If called outside trainer context
        """
        if self.trainer is None:
            raise RuntimeError("LighterModule is not attached to a Trainer.")

        if self.trainer.sanity_checking:
            return Mode.VAL

        if self.trainer.training:
            return Mode.TRAIN
        elif self.trainer.validating:
            return Mode.VAL
        elif self.trainer.testing:
            return Mode.TEST
        elif self.trainer.predicting:
            return Mode.PREDICT
        else:
            raise RuntimeError("Cannot determine mode outside Lightning execution.")

mode property

Current execution mode.

Returns:

    str: "train", "val", "test", or "predict"

Raises:

    RuntimeError: If called outside trainer context

_log_loss(loss)

Log loss with dual pattern (step + epoch).

Parameters:

    loss (Tensor | dict[str, Any] | None, required): Loss tensor or dict from step method. If dict, must have 'total' key (validated in _normalize_output).
Source code in src/lighter/model.py
def _log_loss(self, loss: torch.Tensor | dict[str, Any] | None) -> None:
    """
    Log loss with dual pattern (step + epoch).

    Args:
        loss: Loss tensor or dict from step method.
            If dict, must have 'total' key (validated in _normalize_output).
    """
    if loss is None:
        return

    # Log scalar or dict
    if isinstance(loss, dict):
        for name, value in loss.items():
            name = f"{self.mode}/loss/{name}"
            self._log(name, value, on_step=True)
            self._log(name, value, on_epoch=True, sync_dist=True)
    else:
        name = f"{self.mode}/loss"
        self._log(name, loss, on_step=True)
        self._log(name, loss, on_epoch=True, sync_dist=True)

_log_metrics()

Log metrics with dual pattern (step + epoch).

User already called metrics in their step method. Handles both single Metric and MetricCollection.

Source code in src/lighter/model.py
def _log_metrics(self) -> None:
    """
    Log metrics with dual pattern (step + epoch).

    User already called metrics in their step method.
    Handles both single Metric and MetricCollection.
    """
    metrics = getattr(self, f"{self.mode}_metrics", None)
    if metrics is None:
        return

    if isinstance(metrics, MetricCollection):
        # MetricCollection - iterate over named metrics
        for name, metric in metrics.items():
            name = f"{self.mode}/metrics/{name}"
            self._log(name, metric, on_step=True)
            self._log(name, metric, on_epoch=True, sync_dist=True)
    else:
        # Single Metric - use class name (consistent with MetricCollection auto-naming)
        name = f"{self.mode}/metrics/{metrics.__class__.__name__}"
        self._log(name, metrics, on_step=True)
        self._log(name, metrics, on_epoch=True, sync_dist=True)
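
As an illustration (not part of the source), the dual-logging pattern above yields two keys per metric. For a single torchmetrics.Accuracy(task="multiclass") instance, whose class name is typically MulticlassAccuracy, the logged keys in train mode would look like:

# Reproduces the naming used by _log_metrics and _log, purely for illustration.
mode = "train"
metric_name = "MulticlassAccuracy"  # metrics.__class__.__name__ for a single metric
print(f"{mode}/metrics/{metric_name}/step")   # per-step value
print(f"{mode}/metrics/{metric_name}/epoch")  # epoch aggregate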

_log_optimizer_stats(batch_idx)

Log optimizer stats once per epoch in train mode.

Parameters:

    batch_idx (int, required): Current batch index
Source code in src/lighter/model.py
def _log_optimizer_stats(self, batch_idx: int) -> None:
    """
    Log optimizer stats once per epoch in train mode.

    Args:
        batch_idx: Current batch index
    """
    if self.mode != Mode.TRAIN or batch_idx != 0 or self.optimizer is None:
        return

    # Optimizer stats only logged per epoch
    for name, stat in get_optimizer_stats(self.optimizer).items():
        name = f"{self.mode}/{name}"
        self._log(name, stat, on_epoch=True, sync_dist=False)

_log_outputs(outputs, batch_idx)

Log all outputs from a step.

Override this method to customize logging behavior. Default: dual logging (step + epoch) for loss and metrics.

Parameters:

    outputs (dict[str, Any], required): Dict from user's step method
    batch_idx (int, required): Current batch index
Source code in src/lighter/model.py
def _log_outputs(self, outputs: dict[str, Any], batch_idx: int) -> None:
    """
    Log all outputs from a step.

    Override this method to customize logging behavior.
    Default: dual logging (step + epoch) for loss and metrics.

    Args:
        outputs: Dict from user's step method
        batch_idx: Current batch index
    """
    if self.trainer.logger is None:
        return
    self._log_loss(outputs.get("loss"))
    self._log_metrics()
    self._log_optimizer_stats(batch_idx)

_normalize_output(output)

Normalize step output to dict format.

Parameters:

    output (Tensor | dict[str, Any], required): Either a torch.Tensor loss value (normalized to {"loss": tensor}), or a dict of step outputs that can include "loss" (a torch.Tensor or a dict with a "total" key) and "pred", "target", "input" for callbacks

Returns:

    dict[str, Any]: Dict with normalized structure

Raises:

    TypeError: If output is neither Tensor nor dict
    ValueError: If loss dict is missing 'total' key

Source code in src/lighter/model.py
def _normalize_output(self, output: torch.Tensor | dict[str, Any]) -> dict[str, Any]:
    """
    Normalize step output to dict format.

    Args:
        output: Either:
            - torch.Tensor: Loss value (normalized to {"loss": tensor})
            - dict: Must contain outputs. Can include:
                - "loss": torch.Tensor or dict with "total" key
                - "pred", "target", "input": Additional data for callbacks

    Returns:
        Dict with normalized structure

    Raises:
        TypeError: If output is neither Tensor nor dict
        ValueError: If loss dict is missing 'total' key
    """
    if isinstance(output, torch.Tensor):
        return {"loss": output}
    elif isinstance(output, dict):
        # Validate loss structure if present
        if "loss" in output and isinstance(output["loss"], dict):
            if "total" not in output["loss"]:
                raise ValueError(
                    f"Loss dict must include 'total' key. "
                    f"Got keys: {list(output['loss'].keys())}. "
                    f"Example: {{'loss': {{'total': combined, 'ce': ce_loss, 'reg': reg_loss}}}}"
                )
        return output
    else:
        raise TypeError(
            f"Step method must return torch.Tensor or dict. "
            f"Got {type(output).__name__} instead. "
            f"Examples:\n"
            f"  - return loss  # Simple tensor\n"
            f'  - return {{"loss": loss, "pred": pred}}'
        )
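
To make the multi-loss contract concrete, here is a hedged sketch of a step that returns a loss dict with the required 'total' key. The names ce and reg are illustrative, and how the framework consumes the 'total' entry during optimization is not shown in this excerpt.

# Assumes LighterModule is importable from lighter (import path assumed).
import torch.nn.functional as F

class MultiLossModel(LighterModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        ce = F.cross_entropy(pred, y)
        reg = 1e-4 * sum(p.pow(2).sum() for p in self.network.parameters())
        # 'total' is mandatory; the other entries are logged as train/loss/ce and train/loss/reg.
        return {"loss": {"total": ce + reg, "ce": ce, "reg": reg}, "pred": pred, "target": y}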

_on_batch_end(outputs, batch_idx)

Common batch-end logic for all modes.

Source code in src/lighter/model.py
def _on_batch_end(self, outputs: torch.Tensor | dict[str, Any], batch_idx: int) -> None:
    """Common batch-end logic for all modes."""
    outputs = self._normalize_output(outputs)
    self._log_outputs(outputs, batch_idx)

_prepare_metrics(metrics)

Validate metrics - must be Metric or MetricCollection.

Source code in src/lighter/model.py
def _prepare_metrics(self, metrics: Metric | MetricCollection | None) -> Metric | MetricCollection | None:
    """Validate metrics - must be Metric or MetricCollection."""
    if metrics is None:
        return None

    if isinstance(metrics, (Metric, MetricCollection)):
        return metrics

    raise TypeError(
        f"metrics must be Metric or MetricCollection, got {type(metrics).__name__}.\n\n"
        f"Single metric:\n"
        f"  train_metrics:\n"
        f"    _target_: torchmetrics.Accuracy\n"
        f"    task: multiclass\n\n"
        f"Multiple metrics:\n"
        f"  train_metrics:\n"
        f"    _target_: torchmetrics.MetricCollection\n"
        f"    metrics:\n"
        f"      - _target_: torchmetrics.Accuracy\n"
        f"        task: multiclass\n"
        f"      - _target_: torchmetrics.F1Score\n"
        f"        task: multiclass"
    )

configure_optimizers()

Configure optimizer and scheduler.

Source code in src/lighter/model.py
def configure_optimizers(self):
    """Configure optimizer and scheduler."""
    if self.optimizer is None:
        raise ValueError("Optimizer not configured.")

    if self.scheduler is None:
        return {"optimizer": self.optimizer}
    else:
        return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler}

forward(*args, **kwargs)

Forward pass - simply delegates to self.network.

Override if you need custom forward logic.

Source code in src/lighter/model.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """
    Forward pass - simply delegates to self.network.

    Override if you need custom forward logic.
    """
    return self.network(*args, **kwargs)

on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=0)

Framework hook - automatically logs test outputs.

Source code in src/lighter/model.py
def on_test_batch_end(
    self,
    outputs: Any,
    batch: Any,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Framework hook - automatically logs test outputs."""
    self._on_batch_end(outputs, batch_idx)

on_train_batch_end(outputs, batch, batch_idx)

Framework hook - automatically logs training outputs.

Source code in src/lighter/model.py
def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
    """Framework hook - automatically logs training outputs."""
    self._on_batch_end(outputs, batch_idx)

on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx=0)

Framework hook - automatically logs validation outputs.

Source code in src/lighter/model.py
def on_validation_batch_end(
    self,
    outputs: Any,
    batch: Any,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Framework hook - automatically logs validation outputs."""
    self._on_batch_end(outputs, batch_idx)

predict_step(batch, batch_idx)

Define prediction logic.

User responsibilities:

- Extract data from batch
- Call self(input) for forward pass
- Return predictions in desired format

No automatic logging happens in predict mode. Return any format you need (tensor, dict, list, etc.).

Source code in src/lighter/model.py
def predict_step(self, batch: Any, batch_idx: int) -> Any:
    """
    Define prediction logic.

    User responsibilities:
    - Extract data from batch
    - Call self(input) for forward pass
    - Return predictions in desired format

    No automatic logging happens in predict mode.
    Return any format you need (tensor, dict, list, etc.).
    """
    raise NotImplementedError(
        f"{self.__class__.__name__} must implement predict_step() to use trainer.predict(). "
        f"See https://project-lighter.github.io/lighter/guides/lighter-module/"
    )

test_step(batch, batch_idx)

Define test logic.

Loss is optional. Call self.test_metrics(pred, target) if configured.

Returns:

    Tensor | dict[str, Any]: Either the loss value (optional in test mode), or a dict with an optional 'loss' key that can also include pred, target, etc.
Source code in src/lighter/model.py
def test_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, Any]:
    """
    Define test logic.

    Loss is optional. Call self.test_metrics(pred, target) if configured.

    Returns:
        Either:
            - Tensor: The loss value (optional in test mode)
            - Dict with optional 'loss' key. Can include pred, target, etc.
    """
    raise NotImplementedError(
        f"{self.__class__.__name__} must implement test_step() to use trainer.test(). "
        f"See https://project-lighter.github.io/lighter/guides/lighter-module/"
    )

training_step(batch, batch_idx)

Define training logic.

User responsibilities:

- Extract data from batch
- Call self(input) for forward pass
- Compute loss
- Call self.train_metrics(pred, target) if configured
- Return loss tensor or dict with 'loss' key

Framework automatically logs loss and metrics.

Returns:

    Tensor | dict[str, Any]: Either the loss value (simplest option), or a dict with a required 'loss' key and optional keys: pred (model predictions, for callbacks), target (target labels, for callbacks), input (input data, for callbacks), and any other keys you need
Source code in src/lighter/model.py
def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, Any]:
    """
    Define training logic.

    User responsibilities:
    - Extract data from batch
    - Call self(input) for forward pass
    - Compute loss
    - Call self.train_metrics(pred, target) if configured
    - Return loss tensor or dict with 'loss' key

    Framework automatically logs loss and metrics.

    Returns:
        Either:
            - Tensor: The loss value (simplest option)
            - Dict with required 'loss' key and optional keys:
                - pred: Model predictions (for callbacks)
                - target: Target labels (for callbacks)
                - input: Input data (for callbacks)
                - Any other keys you need
    """
    raise NotImplementedError(
        f"{self.__class__.__name__} must implement training_step() to use trainer.fit(). "
        f"See https://project-lighter.github.io/lighter/guides/lighter-module/"
    )

validation_step(batch, batch_idx)

Define validation logic.

Similar to training_step but typically without gradients. Call self.val_metrics(pred, target) if configured.

Returns:

    Tensor | dict[str, Any]: Either the loss value, or a dict with a 'loss' key
Source code in src/lighter/model.py
def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, Any]:
    """
    Define validation logic.

    Similar to training_step but typically without gradients.
    Call self.val_metrics(pred, target) if configured.

    Returns:
        Either:
            - Tensor: The loss value
            - Dict with 'loss' key
    """
    raise NotImplementedError(
        f"{self.__class__.__name__} must implement validation_step() to use validation. "
        f"See https://project-lighter.github.io/lighter/guides/lighter-module/"
    )

Runner

Orchestrates training stage execution by coordinating helper classes.

Runner delegates responsibilities to specialized helper classes:

- ProjectImporter: Auto-discovers and imports user project modules via the __lighter__.py marker
- ConfigLoader: Loads and validates configurations using Sparkwheel

Runner focuses on resolving and validating components (model, trainer, datamodule) and executing the requested training stage.
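
For orientation, a hedged sketch of programmatic use. Runner is normally driven by the Lighter CLI; the stage value and override handling shown here are assumptions based on the run() documentation below.

from lighter.engine.runner import Runner

runner = Runner()
runner.run(
    stage="fit",  # a Stage whose str() matches the Trainer method name; shown here as a plain string
    inputs=[
        "config.yaml",                    # string without '=' -> treated as a config file path
        {"trainer": {"max_epochs": 20}},  # dict -> merged into the config
    ],
)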

Source code in src/lighter/engine/runner.py
class Runner:
    """
    Orchestrates training stage execution by coordinating helper classes.

    Runner delegates responsibilities to specialized helper classes:
    - ProjectImporter: Auto-discovers and imports user project modules via __lighter__.py marker
    - ConfigLoader: Loads and validates configurations using Sparkwheel

    Runner focuses on resolving and validating components (model, trainer, datamodule)
    and executing the requested training stage.
    """

    def run(
        self,
        stage: Stage,
        inputs: list,
        **stage_kwargs: Any,
    ) -> None:
        """
        Run a training stage with configuration inputs.

        Orchestrates the complete training workflow:
        1. Loads configuration via ConfigLoader (delegates to Sparkwheel for auto-detection)
        2. Auto-discovers and imports project modules via ProjectImporter
        3. Resolves and validates model, trainer, and datamodule components
        4. Saves configuration (to log directory, logger, and model hyperparameters)
        5. Executes the requested training stage

        Args:
            stage: Stage to run (fit, validate, test, predict)
            inputs: List of config file paths, dicts, and/or overrides.
                   Passed to ConfigLoader.load() which delegates to Sparkwheel for auto-detection:
                   - Strings without '=' → file paths
                   - Strings with '=' → overrides
                   - Dicts → merged into config
            **stage_kwargs: Additional keyword arguments from CLI (e.g., ckpt_path, verbose)
                           passed directly to the trainer stage method

        Raises:
            ValueError: If config validation fails or required components are missing
            TypeError: If model or trainer are not the correct type
        """
        seed_everything()

        # 1. Load configuration
        config = ConfigLoader.load(inputs)

        # 2. Auto-discover and import project
        ProjectImporter.auto_discover_and_import()

        # 3. Resolve components
        model = self._resolve_model(config)
        trainer = self._resolve_trainer(config)
        datamodule = self._resolve_datamodule(config, model)

        # 4. Save configuration to trainer's log directory, logger, and model hparams for checkpoint access
        self._save_config(config, trainer, model)

        # 5. Execute stage
        self._execute(stage, model, trainer, datamodule, **stage_kwargs)

    def _resolve_model(self, config: Config) -> LightningModule:
        """Resolve and validate model from config."""
        model = config.resolve("model")
        if not isinstance(model, LightningModule):
            raise TypeError(f"model must be LightningModule or LighterModule, got {type(model)}")
        return model

    def _resolve_trainer(self, config: Config) -> Trainer:
        """Resolve and validate trainer from config."""
        trainer = config.resolve("trainer")
        if not isinstance(trainer, Trainer):
            raise TypeError(f"trainer must be Trainer, got {type(trainer)}")
        return trainer

    def _resolve_datamodule(self, config: Config, model: LightningModule) -> LightningDataModule | None:
        """
        Resolve and validate datamodule from config.

        Args:
            config: Configuration object
            model: Resolved model (checked for built-in dataloaders)

        Returns:
            LightningDataModule instance or None if model defines its own dataloaders

        Raises:
            TypeError: If data key exists but is not a LightningDataModule
        """
        # Data key is optional - plain Lightning modules can define their own dataloaders
        if config.get("data") is None:
            # Check if model has dataloader methods (plain Lightning module)
            has_dataloaders = any(
                hasattr(model, method)
                for method in ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"]
            )
            if not has_dataloaders:
                raise ValueError(
                    "Missing required 'data:' config key and model does not define dataloader methods. "
                    "Either:\n"
                    "1. Add 'data:' config key:\n"
                    "   data:\n"
                    "     _target_: lighter.LighterDataModule\n"
                    "     train_dataloader: ...\n"
                    "2. Or define dataloader methods in your LightningModule (train_dataloader, val_dataloader, etc.)"
                )
            return None

        # Resolve and validate data key
        datamodule = config.resolve("data")
        if not isinstance(datamodule, LightningDataModule):
            raise TypeError(
                f"data must be LightningDataModule (or lighter.LighterDataModule), got {type(datamodule)}. "
                "Example:\n"
                "data:\n"
                "  _target_: lighter.LighterDataModule\n"
                "  train_dataloader:\n"
                "    _target_: torch.utils.data.DataLoader\n"
                "    # ... config ..."
            )

        return datamodule

    def _save_config(self, config: Config, trainer: Trainer, model: LightningModule) -> None:
        """
        Save configuration to multiple destinations.

        Saves the configuration to:
        - Model (for checkpoint access via model.hparams)
        - Logger (for experiment tracking via log_hyperparams)
        - Log directory (as config.yaml file)

        Args:
            config: Configuration object to save
            trainer: Trainer (uses trainer.logger and trainer.log_dir)
            model: Model to save hyperparameters to
        """

        # Save to model checkpoint (for model.hparams access)
        model.save_hyperparameters({"config": config.get()})

        # If no logger, skip other saves
        if not trainer.logger:
            return

        # Save to logger (for experiment tracking)
        trainer.logger.log_hyperparams(config.get())

        # Save as config.yaml to log directory if it exists
        if trainer.log_dir:
            config_file = Path(trainer.log_dir) / "config.yaml"
            config_file.parent.mkdir(parents=True, exist_ok=True)
            with open(config_file, "w") as f:
                yaml.dump(config.get(), f, default_flow_style=False, sort_keys=False, indent=4)
            logger.info(f"Saved config to: {config_file}")

    def _execute(
        self,
        stage: Stage,
        model: LightningModule,
        trainer: Trainer,
        datamodule: LightningDataModule | None,
        **stage_kwargs: Any,
    ) -> None:
        """
        Execute the training stage.

        Args:
            stage: Stage to execute (fit, validate, test, predict)
            model: Resolved model
            trainer: Resolved trainer
            datamodule: Resolved datamodule (None if model defines its own dataloaders)
            **stage_kwargs: Additional keyword arguments from CLI (e.g., ckpt_path, verbose)
        """
        stage_method = getattr(trainer, str(stage))
        if datamodule is not None:
            stage_method(model, datamodule=datamodule, **stage_kwargs)
        else:
            # Plain Lightning module with built-in dataloaders
            stage_method(model, **stage_kwargs)

_execute(stage, model, trainer, datamodule, **stage_kwargs)

Execute the training stage.

Parameters:

    stage (Stage, required): Stage to execute (fit, validate, test, predict)
    model (LightningModule, required): Resolved model
    trainer (Trainer, required): Resolved trainer
    datamodule (LightningDataModule | None, required): Resolved datamodule (None if model defines its own dataloaders)
    **stage_kwargs (Any, default {}): Additional keyword arguments from CLI (e.g., ckpt_path, verbose)
Source code in src/lighter/engine/runner.py
def _execute(
    self,
    stage: Stage,
    model: LightningModule,
    trainer: Trainer,
    datamodule: LightningDataModule | None,
    **stage_kwargs: Any,
) -> None:
    """
    Execute the training stage.

    Args:
        stage: Stage to execute (fit, validate, test, predict)
        model: Resolved model
        trainer: Resolved trainer
        datamodule: Resolved datamodule (None if model defines its own dataloaders)
        **stage_kwargs: Additional keyword arguments from CLI (e.g., ckpt_path, verbose)
    """
    stage_method = getattr(trainer, str(stage))
    if datamodule is not None:
        stage_method(model, datamodule=datamodule, **stage_kwargs)
    else:
        # Plain Lightning module with built-in dataloaders
        stage_method(model, **stage_kwargs)

_resolve_datamodule(config, model)

Resolve and validate datamodule from config.

Parameters:

    config (Config, required): Configuration object
    model (LightningModule, required): Resolved model (checked for built-in dataloaders)

Returns:

    LightningDataModule | None: LightningDataModule instance, or None if model defines its own dataloaders

Raises:

    TypeError: If data key exists but is not a LightningDataModule

Source code in src/lighter/engine/runner.py
def _resolve_datamodule(self, config: Config, model: LightningModule) -> LightningDataModule | None:
    """
    Resolve and validate datamodule from config.

    Args:
        config: Configuration object
        model: Resolved model (checked for built-in dataloaders)

    Returns:
        LightningDataModule instance or None if model defines its own dataloaders

    Raises:
        TypeError: If data key exists but is not a LightningDataModule
    """
    # Data key is optional - plain Lightning modules can define their own dataloaders
    if config.get("data") is None:
        # Check if model has dataloader methods (plain Lightning module)
        has_dataloaders = any(
            hasattr(model, method)
            for method in ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"]
        )
        if not has_dataloaders:
            raise ValueError(
                "Missing required 'data:' config key and model does not define dataloader methods. "
                "Either:\n"
                "1. Add 'data:' config key:\n"
                "   data:\n"
                "     _target_: lighter.LighterDataModule\n"
                "     train_dataloader: ...\n"
                "2. Or define dataloader methods in your LightningModule (train_dataloader, val_dataloader, etc.)"
            )
        return None

    # Resolve and validate data key
    datamodule = config.resolve("data")
    if not isinstance(datamodule, LightningDataModule):
        raise TypeError(
            f"data must be LightningDataModule (or lighter.LighterDataModule), got {type(datamodule)}. "
            "Example:\n"
            "data:\n"
            "  _target_: lighter.LighterDataModule\n"
            "  train_dataloader:\n"
            "    _target_: torch.utils.data.DataLoader\n"
            "    # ... config ..."
        )

    return datamodule
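
For the first branch above, a hedged sketch of a plain LightningModule that defines its own dataloaders, so the 'data:' config key can be omitted; all names and shapes here are illustrative.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class SelfContainedModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.net.parameters(), lr=1e-3)

    def train_dataloader(self):
        # Built-in dataloader: _resolve_datamodule returns None and Runner calls trainer.fit(model).
        x, y = torch.randn(64, 4), torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=16)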

_resolve_model(config)

Resolve and validate model from config.

Source code in src/lighter/engine/runner.py
def _resolve_model(self, config: Config) -> LightningModule:
    """Resolve and validate model from config."""
    model = config.resolve("model")
    if not isinstance(model, LightningModule):
        raise TypeError(f"model must be LightningModule or LighterModule, got {type(model)}")
    return model

_resolve_trainer(config)

Resolve and validate trainer from config.

Source code in src/lighter/engine/runner.py
def _resolve_trainer(self, config: Config) -> Trainer:
    """Resolve and validate trainer from config."""
    trainer = config.resolve("trainer")
    if not isinstance(trainer, Trainer):
        raise TypeError(f"trainer must be Trainer, got {type(trainer)}")
    return trainer

_save_config(config, trainer, model)

Save configuration to multiple destinations.

Saves the configuration to:

- Model (for checkpoint access via model.hparams)
- Logger (for experiment tracking via log_hyperparams)
- Log directory (as a config.yaml file)

Parameters:

    config (Config, required): Configuration object to save
    trainer (Trainer, required): Trainer (uses trainer.logger and trainer.log_dir)
    model (LightningModule, required): Model to save hyperparameters to
Source code in src/lighter/engine/runner.py
def _save_config(self, config: Config, trainer: Trainer, model: LightningModule) -> None:
    """
    Save configuration to multiple destinations.

    Saves the configuration to:
    - Model (for checkpoint access via model.hparams)
    - Logger (for experiment tracking via log_hyperparams)
    - Log directory (as config.yaml file)

    Args:
        config: Configuration object to save
        trainer: Trainer (uses trainer.logger and trainer.log_dir)
        model: Model to save hyperparameters to
    """

    # Save to model checkpoint (for model.hparams access)
    model.save_hyperparameters({"config": config.get()})

    # If no logger, skip other saves
    if not trainer.logger:
        return

    # Save to logger (for experiment tracking)
    trainer.logger.log_hyperparams(config.get())

    # Save as config.yaml to log directory if it exists
    if trainer.log_dir:
        config_file = Path(trainer.log_dir) / "config.yaml"
        config_file.parent.mkdir(parents=True, exist_ok=True)
        with open(config_file, "w") as f:
            yaml.dump(config.get(), f, default_flow_style=False, sort_keys=False, indent=4)
        logger.info(f"Saved config to: {config_file}")

run(stage, inputs, **stage_kwargs)

Run a training stage with configuration inputs.

Orchestrates the complete training workflow:

1. Loads configuration via ConfigLoader (delegates to Sparkwheel for auto-detection)
2. Auto-discovers and imports project modules via ProjectImporter
3. Resolves and validates model, trainer, and datamodule components
4. Saves configuration (to log directory, logger, and model hyperparameters)
5. Executes the requested training stage

Parameters:

    stage (Stage, required): Stage to run (fit, validate, test, predict)
    inputs (list, required): List of config file paths, dicts, and/or overrides. Passed to ConfigLoader.load(), which delegates to Sparkwheel for auto-detection: strings without '=' are file paths, strings with '=' are overrides, and dicts are merged into the config
    **stage_kwargs (Any, default {}): Additional keyword arguments from CLI (e.g., ckpt_path, verbose) passed directly to the trainer stage method

Raises:

    ValueError: If config validation fails or required components are missing
    TypeError: If model or trainer are not the correct type

Source code in src/lighter/engine/runner.py
def run(
    self,
    stage: Stage,
    inputs: list,
    **stage_kwargs: Any,
) -> None:
    """
    Run a training stage with configuration inputs.

    Orchestrates the complete training workflow:
    1. Loads configuration via ConfigLoader (delegates to Sparkwheel for auto-detection)
    2. Auto-discovers and imports project modules via ProjectImporter
    3. Resolves and validates model, trainer, and datamodule components
    4. Saves configuration (to log directory, logger, and model hyperparameters)
    5. Executes the requested training stage

    Args:
        stage: Stage to run (fit, validate, test, predict)
        inputs: List of config file paths, dicts, and/or overrides.
               Passed to ConfigLoader.load() which delegates to Sparkwheel for auto-detection:
               - Strings without '=' → file paths
               - Strings with '=' → overrides
               - Dicts → merged into config
        **stage_kwargs: Additional keyword arguments from CLI (e.g., ckpt_path, verbose)
                       passed directly to the trainer stage method

    Raises:
        ValueError: If config validation fails or required components are missing
        TypeError: If model or trainer are not the correct type
    """
    seed_everything()

    # 1. Load configuration
    config = ConfigLoader.load(inputs)

    # 2. Auto-discover and import project
    ProjectImporter.auto_discover_and_import()

    # 3. Resolve components
    model = self._resolve_model(config)
    trainer = self._resolve_trainer(config)
    datamodule = self._resolve_datamodule(config, model)

    # 4. Save configuration to trainer's log directory, logger, and model hparams for checkpoint access
    self._save_config(config, trainer, model)

    # 5. Execute stage
    self._execute(stage, model, trainer, datamodule, **stage_kwargs)