callbacks

FileWriter

Bases: BaseWriter

Writer for saving predictions to files in various formats including tensors, images, videos, and ITK images. Custom writer functions can be provided to extend supported formats.

Parameters:

- path: Directory path where output files will be saved. Required.
- writer: Either a string specifying a built-in writer or a custom writer function. Required. Built-in writers:
    - "tensor": Saves raw tensor data (.pt)
    - "image": Saves as image file (.png)
    - "video": Saves as video file
    - "itk_nrrd": Saves as ITK NRRD file (.nrrd)
    - "itk_seg_nrrd": Saves as ITK segmentation NRRD file (.seg.nrrd)
    - "itk_nifti": Saves as ITK NIfTI file (.nii.gz)

  Custom writers must:
    - Accept (path, tensor) arguments
    - Handle single tensor input (no batch dimension)
    - Save output to the specified path

Source code in src/lighter/callbacks/writer/file.py
class FileWriter(BaseWriter):
    """
    Writer for saving predictions to files in various formats including tensors, images, videos, and ITK images.
    Custom writer functions can be provided to extend supported formats.
    Args:
        path: Directory path where output files will be saved.
        writer: Either a string specifying a built-in writer or a custom writer function.
            Built-in writers:
                - "tensor": Saves raw tensor data (.pt)
                - "image": Saves as image file (.png)
                - "video": Saves as video file
                - "itk_nrrd": Saves as ITK NRRD file (.nrrd)
                - "itk_seg_nrrd": Saves as ITK segmentation NRRD file (.seg.nrrd)
                - "itk_nifti": Saves as ITK NIfTI file (.nii.gz)
            Custom writers must:
                - Accept (path, tensor) arguments
                - Handle single tensor input (no batch dimension)
                - Save output to the specified path
    """

    @property
    def writers(self) -> dict[str, Callable]:
        return {
            "tensor": write_tensor,
            "image": write_image,
            "video": write_video,
            "itk_nrrd": partial(write_itk_image, suffix=".nrrd"),
            "itk_seg_nrrd": partial(write_itk_image, suffix=".seg.nrrd"),
            "itk_nifti": partial(write_itk_image, suffix=".nii.gz"),
        }

    def write(self, tensor: Tensor, identifier: int | str) -> None:
        """
        Writes the tensor to a file using the specified writer.

        Args:
            tensor: The tensor to write.
            identifier: Identifier for naming the file.
        """
        if not self.path.is_dir():
            raise RuntimeError(f"FileWriter expects a directory path, got {self.path}")

        # Determine the file path from the identifier. The suffix must be added by the writer function.
        path = self.path / str(identifier)
        # Write the tensor to the file.
        self.writer(path, tensor)
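
Example: a minimal sketch of attaching FileWriter to a Lightning Trainer for prediction. The pytorch_lightning import, the system object, and the predict_dataloader are assumptions standing in for a configured setup; the FileWriter import mirrors the source location above.

from pytorch_lightning import Trainer

from lighter.callbacks.writer.file import FileWriter

# Save each prediction as a raw tensor file, e.g. predictions/0.pt, predictions/1.pt, ...
# write() expects the "predictions" directory to already exist.
writer = FileWriter(path="predictions", writer="tensor")
trainer = Trainer(callbacks=[writer])
trainer.predict(system, dataloaders=predict_dataloader)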

write(tensor, identifier)

Writes the tensor to a file using the specified writer.

Parameters:

- tensor (Tensor): The tensor to write. Required.
- identifier (int | str): Identifier for naming the file. Required.

Source code in src/lighter/callbacks/writer/file.py
def write(self, tensor: Tensor, identifier: int | str) -> None:
    """
    Writes the tensor to a file using the specified writer.

    Args:
        tensor: The tensor to write.
        identifier: Identifier for naming the file.
    """
    if not self.path.is_dir():
        raise RuntimeError(f"FileWriter expects a directory path, got {self.path}")

    # Determine the file path from the identifier. The suffix must be added by the writer function.
    path = self.path / str(identifier)
    # Write the tensor to the file.
    self.writer(path, tensor)
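
A custom writer only needs to satisfy the contract above: accept (path, tensor), handle a single un-batched tensor, and append its own suffix. A hypothetical sketch (write_numpy is illustrative, not part of the library):

from pathlib import Path

import numpy as np
from torch import Tensor

def write_numpy(path: Path, tensor: Tensor) -> None:
    # FileWriter passes a suffix-less path; the writer adds its own extension.
    np.save(f"{path}.npy", tensor.cpu().numpy())

writer = FileWriter(path="predictions", writer=write_numpy)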

Freezer

Bases: Callback

Callback to freeze model parameters during training. Parameters can be frozen by exact name or prefix. Freezing can be applied indefinitely or until a specified step/epoch.

Parameters:

- names (str | list[str] | None): Full names of parameters to freeze. Default: None.
- name_starts_with (str | list[str] | None): Prefixes of parameter names to freeze. Default: None.
- except_names (str | list[str] | None): Names of parameters to exclude from freezing. Default: None.
- except_name_starts_with (str | list[str] | None): Prefixes of parameter names to exclude from freezing. Default: None.
- until_step (int | None): Maximum step to freeze parameters until. Default: None.
- until_epoch (int | None): Maximum epoch to freeze parameters until. Default: None.

Raises:

- ValueError: If neither names nor name_starts_with are specified.
- ValueError: If both until_step and until_epoch are specified.

Source code in src/lighter/callbacks/freezer.py
class Freezer(Callback):
    """
    Callback to freeze model parameters during training. Parameters can be frozen by exact name or prefix.
    Freezing can be applied indefinitely or until a specified step/epoch.

    Args:
        names: Full names of parameters to freeze.
        name_starts_with: Prefixes of parameter names to freeze.
        except_names: Names of parameters to exclude from freezing.
        except_name_starts_with: Prefixes of parameter names to exclude from freezing.
        until_step: Maximum step to freeze parameters until.
        until_epoch: Maximum epoch to freeze parameters until.

    Raises:
        ValueError: If neither `names` nor `name_starts_with` are specified.
        ValueError: If both `until_step` and `until_epoch` are specified.

    """

    def __init__(
        self,
        names: str | list[str] | None = None,
        name_starts_with: str | list[str] | None = None,
        except_names: str | list[str] | None = None,
        except_name_starts_with: str | list[str] | None = None,
        until_step: int | None = None,
        until_epoch: int | None = None,
    ) -> None:
        super().__init__()

        if names is None and name_starts_with is None:
            raise ValueError("At least one of `names` or `name_starts_with` must be specified.")

        if until_step is not None and until_epoch is not None:
            raise ValueError("Only one of `until_step` or `until_epoch` can be specified.")

        self.names = ensure_list(names)
        self.name_starts_with = ensure_list(name_starts_with)
        self.except_names = ensure_list(except_names)
        self.except_name_starts_with = ensure_list(except_name_starts_with)
        self.until_step = until_step
        self.until_epoch = until_epoch

        self._frozen_state = False

    def on_train_batch_start(self, trainer: Trainer, pl_module: System, batch: Any, batch_idx: int) -> None:
        """
        Called at the start of each training batch to potentially freeze parameters.

        Args:
            trainer: The trainer instance.
            pl_module: The System instance.
            batch: The current batch.
            batch_idx: The index of the batch.
        """
        self._on_batch_start(trainer, pl_module)

    def on_validation_batch_start(
        self, trainer: Trainer, pl_module: System, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self._on_batch_start(trainer, pl_module)

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: System, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self._on_batch_start(trainer, pl_module)

    def on_predict_batch_start(
        self, trainer: Trainer, pl_module: System, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self._on_batch_start(trainer, pl_module)

    def _on_batch_start(self, trainer: Trainer, pl_module: System) -> None:
        """
        Freezes or unfreezes model parameters based on the current step or epoch.

        Args:
            trainer: The trainer instance.
            pl_module: The System instance.
        """
        current_step = trainer.global_step
        current_epoch = trainer.current_epoch

        if self.until_step is not None and current_step >= self.until_step:
            if self._frozen_state:
                logger.info(f"Reached step {self.until_step} - unfreezing the previously frozen layers.")
                self._set_model_requires_grad(pl_module, True)
            return

        if self.until_epoch is not None and current_epoch >= self.until_epoch:
            if self._frozen_state:
                logger.info(f"Reached epoch {self.until_epoch} - unfreezing the previously frozen layers.")
                self._set_model_requires_grad(pl_module, True)
            return

        if not self._frozen_state:
            self._set_model_requires_grad(pl_module, False)

    def _set_model_requires_grad(self, model: Module | System, requires_grad: bool) -> None:
        """
        Sets the requires_grad attribute for model parameters, effectively freezing or unfreezing them.

        Args:
            model: The model whose parameters to modify.
            requires_grad: Whether to allow gradients (unfreeze) or not (freeze).
        """
        # If the model is a `System`, get the underlying PyTorch model.
        if isinstance(model, System):
            model = model.model

        frozen_layers = []
        # Freeze the specified parameters.
        for name, param in model.named_parameters():
            # Leave the excluded-from-freezing parameters trainable.
            if self.except_names and name in self.except_names:
                param.requires_grad = True
                continue
            if self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with):
                param.requires_grad = True
                continue

            # Freeze/unfreeze the specified parameters, based on the `requires_grad` argument.
            if self.names and name in self.names:
                param.requires_grad = requires_grad
                frozen_layers.append(name)
                continue
            if self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with):
                param.requires_grad = requires_grad
                frozen_layers.append(name)
                continue

            # Otherwise, leave the parameter trainable.
            param.requires_grad = True

        self._frozen_state = not requires_grad
        # Log only when freezing the parameters.
        if self._frozen_state:
            logger.info(
                f"Setting requires_grad={requires_grad} the following layers"
                + (f" until step {self.until_step}" if self.until_step is not None else "")
                + (f" until epoch {self.until_epoch}" if self.until_epoch is not None else "")
                + f": {frozen_layers}"
            )
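
Example: a sketch that freezes all parameters of a hypothetical encoder submodule for the first 1000 steps while keeping its head trainable. The parameter-name prefixes and the pytorch_lightning import are assumptions.

from pytorch_lightning import Trainer

from lighter.callbacks.freezer import Freezer

freezer = Freezer(
    name_starts_with="encoder.",            # freeze every parameter under the encoder...
    except_name_starts_with="encoder.head", # ...except the head, which stays trainable
    until_step=1000,                        # unfreeze once trainer.global_step reaches 1000
)
trainer = Trainer(callbacks=[freezer])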

_on_batch_start(trainer, pl_module)

Freezes or unfreezes model parameters based on the current step or epoch.

Parameters:

- trainer (Trainer): The trainer instance. Required.
- pl_module (System): The System instance. Required.

Source code in src/lighter/callbacks/freezer.py
def _on_batch_start(self, trainer: Trainer, pl_module: System) -> None:
    """
    Freezes or unfreezes model parameters based on the current step or epoch.

    Args:
        trainer: The trainer instance.
        pl_module: The System instance.
    """
    current_step = trainer.global_step
    current_epoch = trainer.current_epoch

    if self.until_step is not None and current_step >= self.until_step:
        if self._frozen_state:
            logger.info(f"Reached step {self.until_step} - unfreezing the previously frozen layers.")
            self._set_model_requires_grad(pl_module, True)
        return

    if self.until_epoch is not None and current_epoch >= self.until_epoch:
        if self._frozen_state:
            logger.info(f"Reached epoch {self.until_epoch} - unfreezing the previously frozen layers.")
            self._set_model_requires_grad(pl_module, True)
        return

    if not self._frozen_state:
        self._set_model_requires_grad(pl_module, False)

_set_model_requires_grad(model, requires_grad)

Sets the requires_grad attribute for model parameters, effectively freezing or unfreezing them.

Parameters:

- model (Module | System): The model whose parameters to modify. Required.
- requires_grad (bool): Whether to allow gradients (unfreeze) or not (freeze). Required.

Source code in src/lighter/callbacks/freezer.py
def _set_model_requires_grad(self, model: Module | System, requires_grad: bool) -> None:
    """
    Sets the requires_grad attribute for model parameters, effectively freezing or unfreezing them.

    Args:
        model: The model whose parameters to modify.
        requires_grad: Whether to allow gradients (unfreeze) or not (freeze).
    """
    # If the model is a `System`, get the underlying PyTorch model.
    if isinstance(model, System):
        model = model.model

    frozen_layers = []
    # Freeze the specified parameters.
    for name, param in model.named_parameters():
        # Leave the excluded-from-freezing parameters trainable.
        if self.except_names and name in self.except_names:
            param.requires_grad = True
            continue
        if self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with):
            param.requires_grad = True
            continue

        # Freeze/unfreeze the specified parameters, based on the `requires_grad` argument.
        if self.names and name in self.names:
            param.requires_grad = requires_grad
            frozen_layers.append(name)
            continue
        if self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with):
            param.requires_grad = requires_grad
            frozen_layers.append(name)
            continue

        # Otherwise, leave the parameter trainable.
        param.requires_grad = True

    self._frozen_state = not requires_grad
    # Log only when freezing the parameters.
    if self._frozen_state:
        logger.info(
            f"Setting requires_grad={requires_grad} the following layers"
            + (f" until step {self.until_step}" if self.until_step is not None else "")
            + (f" until epoch {self.until_epoch}" if self.until_epoch is not None else "")
            + f": {frozen_layers}"
        )

on_train_batch_start(trainer, pl_module, batch, batch_idx)

Called at the start of each training batch to potentially freeze parameters.

Parameters:

- trainer (Trainer): The trainer instance. Required.
- pl_module (System): The System instance. Required.
- batch (Any): The current batch. Required.
- batch_idx (int): The index of the batch. Required.

Source code in src/lighter/callbacks/freezer.py
def on_train_batch_start(self, trainer: Trainer, pl_module: System, batch: Any, batch_idx: int) -> None:
    """
    Called at the start of each training batch to potentially freeze parameters.

    Args:
        trainer: The trainer instance.
        pl_module: The System instance.
        batch: The current batch.
        batch_idx: The index of the batch.
    """
    self._on_batch_start(trainer, pl_module)

TableWriter

Bases: BaseWriter

Writer for saving predictions in a table format, such as CSV.

Parameters:

- path (str | Path): CSV filepath. Required.
- writer (str | Callable): Writer function or name of a registered writer. Required.

Source code in src/lighter/callbacks/writer/table.py
class TableWriter(BaseWriter):
    """
    Writer for saving predictions in a table format, such as CSV.

    Args:
        path: CSV filepath.
        writer: Writer function or name of a registered writer.
    """

    def __init__(self, path: str | Path, writer: str | Callable) -> None:
        super().__init__(path, writer)
        self.csv_records = []

    @property
    def writers(self) -> dict[str, Callable]:
        return {
            "tensor": lambda tensor: tensor.item() if tensor.numel() == 1 else tensor.tolist(),
        }

    def write(self, tensor: Any, identifier: int | str) -> None:
        """
        Writes the tensor as a table record using the specified writer.

        Args:
            tensor: The tensor to record. Should not have a batch dimension.
            identifier: Identifier for the record.
        """
        self.csv_records.append({"identifier": identifier, "pred": self.writer(tensor)})

    def on_predict_epoch_end(self, trainer: Trainer, pl_module: System) -> None:
        """
        Called at the end of the prediction epoch to save predictions to a CSV file.

        Args:
            trainer: The trainer instance.
            pl_module: The System instance.
        """
        # If in distributed data parallel mode, gather records from all processes to rank 0.
        if trainer.world_size > 1:
            gather_csv_records = [None] * trainer.world_size if trainer.is_global_zero else None
            torch.distributed.gather_object(self.csv_records, gather_csv_records, dst=0)
            if trainer.is_global_zero:
                self.csv_records = list(itertools.chain(*gather_csv_records))

        # Save the records to a CSV file
        if trainer.is_global_zero:
            df = pd.DataFrame(self.csv_records)
            try:
                df = df.sort_values("identifier")
            except TypeError:
                pass
            df = df.set_index("identifier")
            df.to_csv(self.path)

        # Clear the records after saving
        self.csv_records = []
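
Example: a sketch using the built-in "tensor" writer, which records scalars via .item() and larger tensors via .tolist(). The system object, predict_dataloader, and pytorch_lightning import are assumptions.

from pytorch_lightning import Trainer

from lighter.callbacks.writer.table import TableWriter

# Collect one record per prediction and write them all to predictions.csv
# when the prediction epoch ends (on rank 0 in distributed runs).
table_writer = TableWriter(path="predictions.csv", writer="tensor")
trainer = Trainer(callbacks=[table_writer])
trainer.predict(system, dataloaders=predict_dataloader)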

on_predict_epoch_end(trainer, pl_module)

Called at the end of the prediction epoch to save predictions to a CSV file.

Parameters:

- trainer (Trainer): The trainer instance. Required.
- pl_module (System): The System instance. Required.

Source code in src/lighter/callbacks/writer/table.py
def on_predict_epoch_end(self, trainer: Trainer, pl_module: System) -> None:
    """
    Called at the end of the prediction epoch to save predictions to a CSV file.

    Args:
        trainer: The trainer instance.
        pl_module: The System instance.
    """
    # If in distributed data parallel mode, gather records from all processes to rank 0.
    if trainer.world_size > 1:
        gather_csv_records = [None] * trainer.world_size if trainer.is_global_zero else None
        torch.distributed.gather_object(self.csv_records, gather_csv_records, dst=0)
        if trainer.is_global_zero:
            self.csv_records = list(itertools.chain(*gather_csv_records))

    # Save the records to a CSV file
    if trainer.is_global_zero:
        df = pd.DataFrame(self.csv_records)
        try:
            df = df.sort_values("identifier")
        except TypeError:
            pass
        df = df.set_index("identifier")
        df.to_csv(self.path)

    # Clear the records after saving
    self.csv_records = []
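
The resulting file has one row per record, indexed by identifier. With scalar predictions it might look like this (values illustrative):

identifier,pred
0,0.87
1,0.12
2,0.95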

write(tensor, identifier)

Writes the tensor as a table record using the specified writer.

Parameters:

- tensor (Any): The tensor to record. Should not have a batch dimension. Required.
- identifier (int | str): Identifier for the record. Required.

Source code in src/lighter/callbacks/writer/table.py
def write(self, tensor: Any, identifier: int | str) -> None:
    """
    Writes the tensor as a table record using the specified writer.

    Args:
        tensor: The tensor to record. Should not have a batch dimension.
        identifier: Identifier for the record.
    """
    self.csv_records.append({"identifier": identifier, "pred": self.writer(tensor)})