Skip to content

Index

CsvWriter

Bases: BaseWriter

Writer for saving predictions in a CSV format. It accumulates predictions in a temporary file and saves them to the final destination at the end of the prediction epoch.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `path` | `str \| Path` | Path to save the final CSV file. | required |
| `keys` | `list[str]` | A list of keys to be included in the CSV file. These keys must be present in the `outputs` dictionary from the prediction step. | required |
Example:

```yaml
trainer:
  callbacks:
    - _target_: lighter.callbacks.CsvWriter
      path: predictions.csv
      keys: [id, pred, target]
```
Source code in src/lighter/callbacks/csv_writer.py
class CsvWriter(BaseWriter):
    """
    Writer for saving predictions in a CSV format. Each rank accumulates predictions in its
    own temporary file; at the end of the prediction epoch, rank 0 merges those files into
    the final destination.

    Files are written and read as UTF-8. The per-rank files are merged row by row with the
    `csv` module, so cell text is copied verbatim (no re-parsing or type inference).

    Args:
        path (str | Path): Path to save the final CSV file.
        keys (list[str]): A list of keys to be included in the CSV file.
                          These keys must be present in the `outputs` dictionary
                          from the prediction step.

    Example:
        ```yaml
        trainer:
          callbacks:
            - _target_: lighter.callbacks.CsvWriter
              path: predictions.csv
              keys: [id, pred, target]
        ```
    """

    def __init__(self, path: str | Path, keys: list[str]) -> None:
        super().__init__(path)
        self.keys = keys
        self._temp_path: Path | None = None
        self._csv_writer: Any = None  # csv.writer type is not easily annotated
        self._csv_file: TextIOWrapper | None = None

    def _close_file(self) -> None:
        """Close the CSV file if it's open and reset related state."""
        if self._csv_file is not None and not self._csv_file.closed:
            self._csv_file.close()
        self._csv_file = None
        self._csv_writer = None

    def setup(self, trainer: Trainer, pl_module: LighterModule, stage: str) -> None:
        """Open this rank's temporary CSV file and write the header (predict stage only)."""
        if stage != Stage.PREDICT:
            return
        super().setup(trainer, pl_module, stage)

        # Don't leak a handle if setup runs again before the previous file was closed.
        self._close_file()

        # One temporary file per rank; rank 0 merges them in `on_predict_epoch_end`.
        self._temp_path = self.path.with_suffix(f".tmp_rank{trainer.global_rank}{self.path.suffix}")
        # Explicit UTF-8 so the bytes on disk do not depend on the platform's locale
        # encoding and match the encoding used when the rank files are merged.
        self._csv_file = open(self._temp_path, "w", newline="", encoding="utf-8")
        self._csv_writer = csv.writer(self._csv_file)
        # Write header
        self._csv_writer.writerow(self.keys)

    def _get_sequence_length(self, value: Any) -> int | None:
        """Return how many samples `value` holds, or None if it is not a list/tuple/tensor.

        A 0-d tensor counts as a single sample.
        """
        if isinstance(value, (list, tuple)):
            return len(value)
        elif isinstance(value, torch.Tensor):
            if value.ndim == 0:  # Scalar tensor
                return 1
            else:
                return len(value)  # For non-scalar tensors, len() works
        return None  # Not a sequence type we care about

    def _get_record_value(self, value: Any, index: int) -> Any:
        """Return the CSV cell for sample `index` of `value`.

        Lists/tuples are indexed. A 0-d tensor yields its Python scalar for every index.
        Other tensors yield the indexed row as a Python scalar or nested list. Any other
        value is returned unchanged for every sample.
        """
        if isinstance(value, (list, tuple)):
            return value[index]
        elif isinstance(value, torch.Tensor):
            if value.ndim == 0:  # Scalar tensor
                return value.item()  # Get Python scalar
            else:
                # For non-scalar tensors, get the item at index.
                # If the item itself is a scalar tensor, convert to Python scalar.
                # Otherwise, convert to a list (e.g., for image data).
                item = value[index]
                return item.item() if item.ndim == 0 else item.tolist()
        else:
            return value  # Non-sequence value, return as is (assumed to be for all samples)

    def write(self, outputs: dict[str, Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Append one CSV row per sample in `outputs`.

        Raises:
            KeyError: If none, or only some, of the configured keys are in `outputs`.
            ValueError: If list/tuple/tensor values disagree on the number of samples.
        """
        if self._csv_writer is None:
            return

        # Validate keys up front so an error is raised even for empty batches.
        missing_keys = [key for key in self.keys if key not in outputs]
        if len(missing_keys) == len(self.keys):
            raise KeyError(
                f"CsvWriter: none of the configured keys {self.keys} were found in outputs. "
                f"Available keys in outputs: {list(outputs.keys())}"
            )
        if missing_keys:
            raise KeyError(f"CsvWriter expected key '{missing_keys[0]}' in outputs but it was missing.")

        # The first sequence-like value defines the batch size. If there is none, every
        # value is treated as belonging to a single sample.
        lengths = [(key, self._get_sequence_length(outputs[key])) for key in self.keys]
        num_samples = next((length for _, length in lengths if length is not None), 1)

        # Validate that all list-like or tensor outputs have the same length
        for key, current_len in lengths:
            if current_len is not None and current_len != num_samples:
                raise ValueError(
                    f"CsvWriter found inconsistent lengths for keys: "
                    f"expected {num_samples}, but found {current_len} for key '{key}'."
                )

        # Transpose the dictionary of lists into per-sample records and write them.
        for index in range(num_samples):
            self._csv_writer.writerow([self._get_record_value(outputs[key], index) for key in self.keys])

    def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterModule) -> None:
        """
        At the end of the prediction epoch, merge the per-rank temporary files into the final CSV.

        All ranks exchange their temporary file paths. Rank 0 then concatenates the files
        (writing the header once) into `self.path` and deletes the temporary files.
        """
        if self._csv_file is None:
            return

        # Close (and thereby flush) the temporary file before it is read back.
        self._close_file()

        all_temp_paths: list[Path | None] = [None] * trainer.world_size
        if dist.is_initialized():
            dist.all_gather_object(all_temp_paths, self._temp_path)
        else:
            all_temp_paths = [self._temp_path]

        if trainer.is_global_zero:
            temp_paths = [path for path in all_temp_paths if path is not None]
            if temp_paths:
                # Copy rows verbatim with the csv module rather than round-tripping through
                # pandas. Pandas' type inference rewrites cell text (e.g. "007" -> "7",
                # "NA"/"None" -> empty).
                with open(self.path, "w", newline="", encoding="utf-8") as final_file:
                    final_writer = csv.writer(final_file)
                    header_written = False
                    for temp_path in temp_paths:
                        with open(temp_path, newline="", encoding="utf-8") as temp_file:
                            reader = csv.reader(temp_file)
                            header = next(reader, None)
                            if header is None:
                                continue  # empty file: nothing to copy
                            if not header_written:
                                final_writer.writerow(header)
                                header_written = True
                            final_writer.writerows(reader)

                # Delete the temporary files only once the final file is complete.
                for temp_path in temp_paths:
                    temp_path.unlink()

        # Reset temporary path (also on rank 0 when there was nothing to merge).
        self._temp_path = None

    def on_exception(self, trainer: Trainer, pl_module: LighterModule, exception: BaseException) -> None:
        """Close the file on errors to prevent file handle leaks."""
        self._close_file()

    def teardown(self, trainer: Trainer, pl_module: LighterModule, stage: str) -> None:
        """Guarantee cleanup when stage is PREDICT."""
        if stage == Stage.PREDICT:
            self._close_file()

_close_file()

Close the CSV file if it's open and reset related state.

Source code in src/lighter/callbacks/csv_writer.py
def _close_file(self) -> None:
    """Release the CSV file handle (if any) and forget the associated writer.

    Calling this repeatedly is safe: a missing or already-closed handle is not closed again.
    """
    handle = self._csv_file
    if handle is not None and not handle.closed:
        handle.close()
    self._csv_writer = None
    self._csv_file = None

on_exception(trainer, pl_module, exception)

Close the file on errors to prevent file handle leaks.

Source code in src/lighter/callbacks/csv_writer.py
def on_exception(self, trainer: Trainer, pl_module: LighterModule, exception: BaseException) -> None:
    """Make sure a failed run does not leave the temporary CSV handle open.

    The exception itself is not handled here; only the file handle is released.
    """
    # Delegates to the shared cleanup helper; safe even if no file was ever opened.
    self._close_file()

on_predict_epoch_end(trainer, pl_module)

At the end of the prediction epoch, it saves the temporary file to the final destination.

Source code in src/lighter/callbacks/csv_writer.py
def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterModule) -> None:
    """
    At the end of the prediction epoch, merge the per-rank temporary files into the final CSV.

    All ranks exchange their temporary file paths. Rank 0 then concatenates the files
    (writing the header once) into `self.path` and deletes the temporary files.
    """
    import csv

    if self._csv_file is None:
        return

    # Close (and thereby flush) the temporary file before it is read back.
    self._close_file()

    all_temp_paths: list[Path | None] = [None] * trainer.world_size
    if dist.is_initialized():
        dist.all_gather_object(all_temp_paths, self._temp_path)
    else:
        all_temp_paths = [self._temp_path]

    if trainer.is_global_zero:
        temp_paths = [path for path in all_temp_paths if path is not None]
        if temp_paths:
            # Copy rows verbatim with the csv module rather than round-tripping through
            # pandas. Pandas' type inference rewrites cell text (e.g. "007" -> "7",
            # "NA"/"None" -> empty).
            with open(self.path, "w", newline="", encoding="utf-8") as final_file:
                final_writer = csv.writer(final_file)
                header_written = False
                for temp_path in temp_paths:
                    with open(temp_path, newline="", encoding="utf-8") as temp_file:
                        reader = csv.reader(temp_file)
                        header = next(reader, None)
                        if header is None:
                            continue  # empty file: nothing to copy
                        if not header_written:
                            final_writer.writerow(header)
                            header_written = True
                        final_writer.writerows(reader)

            # Delete the temporary files only once the final file is complete.
            for temp_path in temp_paths:
                temp_path.unlink()

    # Reset temporary path (also on rank 0 when there was nothing to merge).
    self._temp_path = None

teardown(trainer, pl_module, stage)

Guarantee cleanup when stage is PREDICT.

Source code in src/lighter/callbacks/csv_writer.py
def teardown(self, trainer: Trainer, pl_module: LighterModule, stage: str) -> None:
    """Release the temporary CSV handle when the PREDICT stage is torn down.

    Teardown of any other stage is a no-op.
    """
    if stage != Stage.PREDICT:
        return
    self._close_file()

FileWriter

Bases: BaseWriter

Persist a prediction value per sample to disk.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `directory` | `str \| Path` | Directory to save prediction files. | required |
| `value_key` | `str` | Key in the prediction outputs dict containing the values to save. | required |
| `writer_fn` | `str \| Callable[[Path, Any], None]` | Writer function name (e.g., `"tensor"`, `"image_2d"`, `"text"`) or a callable. | required |
| `name_key` | `str \| None` | Optional key for custom file names. If `None`, files are numbered sequentially. | `None` |
Example:

```yaml
trainer:
  callbacks:
    - _target_: lighter.callbacks.FileWriter
      directory: predictions/
      value_key: pred
      writer_fn: tensor
```
Source code in src/lighter/callbacks/file_writer.py
class FileWriter(BaseWriter):
    """
    Write one file per predicted sample into a directory.

    Args:
        directory: Destination directory for the per-sample files.
        value_key: Key of the prediction outputs dict holding the values to persist.
        writer_fn: Either a registered writer name (e.g. "tensor", "image_2d", "text")
            or a callable invoked as ``writer_fn(path, value)``.
        name_key: Optional outputs key supplying per-sample file names. When None,
            files are named by a running sample index.

    Example:
        ```yaml
        trainer:
          callbacks:
            - _target_: lighter.callbacks.FileWriter
              directory: predictions/
              value_key: pred
              writer_fn: tensor
        ```
    """

    def __init__(
        self,
        directory: str | Path,
        value_key: str,
        writer_fn: str | Callable[[Path, Any], None],
        name_key: str | None = None,
    ) -> None:
        super().__init__(directory)
        self.value_key = value_key
        self.name_key = name_key

        # Resolve registry names to callables; accept callables as-is.
        if isinstance(writer_fn, str):
            resolved = writer_registry.get(writer_fn)
        elif callable(writer_fn):
            resolved = writer_fn
        else:
            raise TypeError("writer_fn must be a string or a callable")
        self.writer_fn = resolved

        # Next index this rank will assign (None until setup) and the stride between indices.
        self._counter: int | None = None
        self._step: int = 1

    def setup(self, trainer: Trainer, pl_module: LighterModule, stage: str) -> None:
        """Validate the output directory and initialise this rank's index sequence."""
        super().setup(trainer, pl_module, stage)
        if stage != Stage.PREDICT:
            return

        if self.path.suffix:
            raise ValueError("FileWriter expects 'directory' to be a directory path, not a file path")

        if trainer.is_global_zero:
            self.path.mkdir(parents=True, exist_ok=True)

        # Interleave indices across ranks: rank r assigns r, r + W, r + 2W, ... (W = world size).
        distributed = trainer.world_size > 1
        self._step = trainer.world_size if distributed else 1
        self._counter = trainer.global_rank if distributed else 0

    def write(self, outputs: dict[str, Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:  # noqa: ARG002
        """Persist every sample under `value_key`, one file each, via `writer_fn`."""
        if self._counter is None:
            logger.debug("FileWriter received outputs before setup; skipping batch")
            return

        values = self._to_sequence(outputs, self.value_key)
        if not values:
            logger.debug("FileWriter value key '{}' yielded no samples; skipping batch", self.value_key)
            return

        names: list = []
        if self.name_key is not None:
            names = self._to_sequence(outputs, self.name_key)
            if len(names) != len(values):
                raise ValueError(
                    "Length mismatch between value key "
                    f"'{self.value_key}' ({len(values)}) and name key "
                    f"'{self.name_key}' ({len(names)})."
                )

        start, stride = self._counter, self._step
        for position, sample in enumerate(values):
            index = start + position * stride
            file_name = self._prepare_name(names[position]) if names else index

            destination = self.path / str(file_name)
            # Names may contain sub-directories, so create the parent per file.
            destination.parent.mkdir(parents=True, exist_ok=True)
            self.writer_fn(destination, self._prepare_value(sample))

        self._counter = start + len(values) * stride

    @staticmethod
    def _prepare_value(value: Any) -> Any:
        """Detach tensors and move them to CPU; pass any other value through untouched."""
        return value.detach().cpu() if isinstance(value, torch.Tensor) else value

    @staticmethod
    def _prepare_name(value: Any) -> Any:
        """Turn a tensor name into a Python scalar (0-d) or nested list; other values pass through."""
        if not isinstance(value, torch.Tensor):
            return value
        host = value.detach().cpu()
        return host.item() if host.ndim == 0 else host.tolist()

    @staticmethod
    def _to_sequence(outputs: dict[str, Any], key: str) -> list:
        """Split `outputs[key]` into a list of per-sample items.

        Tensors are split along their first dimension (a 0-d tensor is one item), and
        non-string sequences become lists. Any other value is wrapped as a single item.
        """
        if key not in outputs:
            raise KeyError(f"FileWriter expected key '{key}' in outputs but it was missing.")

        value = outputs[key]
        if isinstance(value, torch.Tensor):
            return [value] if value.ndim == 0 else list(value)
        if isinstance(value, (str, bytes, bytearray)):
            return [value]
        if isinstance(value, Sequence):
            return list(value)
        return [value]

Freezer

Bases: Callback

Callback to freeze model parameters during training. Parameters can be frozen by exact name or prefix. Freezing can be applied indefinitely or until a specified step/epoch.

Parameters:

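| Name | Type | Description | Default |
|------|------|-------------|---------|
| `names` | `str \| list[str] \| None` | Full names of parameters to freeze. | `None` |
| `name_starts_with` | `str \| list[str] \| None` | Prefixes of parameter names to freeze. | `None` |
| `except_names` | `str \| list[str] \| None` | Names of parameters to exclude from freezing (they are kept trainable). | `None` |
| `except_name_starts_with` | `str \| list[str] \| None` | Prefixes of parameter names to exclude from freezing (they are kept trainable). | `None` |
| `until_step` | `int \| None` | Global step at which the parameters are unfrozen; they stay frozen while the step is below this value. | `None` |
| `until_epoch` | `int \| None` | Epoch at which the parameters are unfrozen; they stay frozen while the epoch is below this value. | `None` |

If both `until_step` and `until_epoch` are `None`, the parameters stay frozen indefinitely.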

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If neither `names` nor `name_starts_with` is specified. |
| `ValueError` | If both `until_step` and `until_epoch` are specified. |

Source code in src/lighter/callbacks/freezer.py
class Freezer(Callback):
    """
    Freeze selected model parameters during training, either for the whole run or until a
    given step/epoch. Parameters are selected by exact name and/or by name prefix.

    Args:
        names: Full names of parameters to freeze.
        name_starts_with: Prefixes of parameter names to freeze.
        except_names: Names of parameters that are never frozen (kept trainable).
        except_name_starts_with: Prefixes of parameter names that are never frozen.
        until_step: Global step at which the parameters are unfrozen
            (they stay frozen while `trainer.global_step < until_step`).
        until_epoch: Epoch at which the parameters are unfrozen
            (they stay frozen while `trainer.current_epoch < until_epoch`).

    Raises:
        ValueError: If neither `names` nor `name_starts_with` are specified.
        ValueError: If both `until_step` and `until_epoch` are specified.

    """

    def __init__(
        self,
        names: str | list[str] | None = None,
        name_starts_with: str | list[str] | None = None,
        except_names: str | list[str] | None = None,
        except_name_starts_with: str | list[str] | None = None,
        until_step: int | None = None,
        until_epoch: int | None = None,
    ) -> None:
        super().__init__()

        # Argument validation happens before any attribute is stored.
        if names is None and name_starts_with is None:
            raise ValueError("At least one of `names` or `name_starts_with` must be specified.")
        if until_step is not None and until_epoch is not None:
            raise ValueError("Only one of `until_step` or `until_epoch` can be specified.")

        self.names = ensure_list(names)
        self.name_starts_with = ensure_list(name_starts_with)
        self.except_names = ensure_list(except_names)
        self.except_name_starts_with = ensure_list(except_name_starts_with)
        self.until_step = until_step
        self.until_epoch = until_epoch

        # True while this callback has the selected parameters frozen.
        self._frozen_state = False

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        """
        Freeze the selected parameters until the configured limit is reached, then unfreeze them.

        Parameters are only touched on a state transition, so repeated calls are cheap.

        Args:
            trainer: The trainer instance.
            pl_module: The LightningModule instance.
            batch: The current batch (unused).
            batch_idx: The index of the batch (unused).
        """
        step, epoch = trainer.global_step, trainer.current_epoch
        limit_reached = (self.until_step is not None and step >= self.until_step) or (
            self.until_epoch is not None and epoch >= self.until_epoch
        )

        if limit_reached:
            if self._frozen_state:
                logger.info("Unfreezing the model.")
                self._set_model_requires_grad(pl_module, requires_grad=True)
                self._frozen_state = False
        elif not self._frozen_state:
            logger.info("Freezing the model.")
            self._set_model_requires_grad(pl_module, requires_grad=False)
            self._frozen_state = True

    def _set_model_requires_grad(self, model: LightningModule, requires_grad: bool) -> None:
        """
        Apply `requires_grad` to the selected parameters and keep every other parameter trainable.

        Excepted parameters are always set trainable; they are reported only when freezing.
        Selected (non-excepted) parameters get `requires_grad`. All remaining parameters are
        set to `requires_grad=True`.

        Args:
            model: The model whose parameters to modify. For a `LighterModule`, its
                `network` is used, so names are given without the "network." prefix.
            requires_grad: True to unfreeze, False to freeze.
        """
        from lighter import LighterModule

        target = model.network if isinstance(model, LighterModule) else model

        # str.startswith accepts a tuple; an empty tuple never matches.
        freeze_prefixes = tuple(self.name_starts_with)
        keep_prefixes = tuple(self.except_name_starts_with)

        frozen_layers = []
        unfrozen_layers = []

        for name, param in target.named_parameters():
            if name in self.except_names or name.startswith(keep_prefixes):
                param.requires_grad = True
                if not requires_grad:
                    unfrozen_layers.append(name)
            elif name in self.names or name.startswith(freeze_prefixes):
                param.requires_grad = requires_grad
                if requires_grad:
                    unfrozen_layers.append(name)
                else:
                    frozen_layers.append(name)
            else:
                param.requires_grad = True

        if frozen_layers:
            limit = ""
            if self.until_step is not None:
                limit += f" until step {self.until_step}"
            if self.until_epoch is not None:
                limit += f" until epoch {self.until_epoch}"
            logger.info(f"Froze layers: {frozen_layers}{limit}")
        if unfrozen_layers:
            note = "" if requires_grad else " (excepted from freeze)"
            logger.info(f"Unfroze layers: {unfrozen_layers}{note}")

_set_model_requires_grad(model, requires_grad)

Sets the requires_grad attribute for model parameters.

When freezing (`requires_grad=False`):

- Freeze the specified parameters.
- Keep all other parameters trainable (`requires_grad=True`).
- Respect the exception rules (excepted parameters stay trainable).

When unfreezing (`requires_grad=True`):

- Unfreeze the specified parameters.
- Keep all other parameters trainable.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `LightningModule` | The model whose parameters to modify. | required |
| `requires_grad` | `bool` | Whether to allow gradients (unfreeze) or not (freeze). | required |
Source code in src/lighter/callbacks/freezer.py
def _set_model_requires_grad(self, model: LightningModule, requires_grad: bool) -> None:
    """
    Apply `requires_grad` to the selected parameters and keep every other parameter trainable.

    Excepted parameters are always set trainable; they are reported only when freezing.
    Selected (non-excepted) parameters get `requires_grad`. All remaining parameters are
    set to `requires_grad=True`.

    Args:
        model: The model whose parameters to modify. For a `LighterModule`, its
            `network` is used, so names are given without the "network." prefix.
        requires_grad: True to unfreeze, False to freeze.
    """
    from lighter import LighterModule

    target = model.network if isinstance(model, LighterModule) else model

    # str.startswith accepts a tuple; an empty tuple never matches.
    freeze_prefixes = tuple(self.name_starts_with)
    keep_prefixes = tuple(self.except_name_starts_with)

    frozen_layers = []
    unfrozen_layers = []

    for name, param in target.named_parameters():
        if name in self.except_names or name.startswith(keep_prefixes):
            param.requires_grad = True
            if not requires_grad:
                unfrozen_layers.append(name)
        elif name in self.names or name.startswith(freeze_prefixes):
            param.requires_grad = requires_grad
            if requires_grad:
                unfrozen_layers.append(name)
            else:
                frozen_layers.append(name)
        else:
            param.requires_grad = True

    if frozen_layers:
        limit = ""
        if self.until_step is not None:
            limit += f" until step {self.until_step}"
        if self.until_epoch is not None:
            limit += f" until epoch {self.until_epoch}"
        logger.info(f"Froze layers: {frozen_layers}{limit}")
    if unfrozen_layers:
        note = "" if requires_grad else " (excepted from freeze)"
        logger.info(f"Unfroze layers: {unfrozen_layers}{note}")

on_train_batch_start(trainer, pl_module, batch, batch_idx)

Called at the start of each training batch to freeze or unfreeze model parameters.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `trainer` | `Trainer` | The trainer instance. | required |
| `pl_module` | `LightningModule` | The LightningModule instance. | required |
| `batch` | `Any` | The current batch. | required |
| `batch_idx` | `int` | The index of the batch. | required |
Source code in src/lighter/callbacks/freezer.py
def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
    """
    Freeze the selected parameters until the configured limit is reached, then unfreeze them.

    Parameters are only touched on a state transition, so repeated calls are cheap.

    Args:
        trainer: The trainer instance.
        pl_module: The LightningModule instance.
        batch: The current batch (unused).
        batch_idx: The index of the batch (unused).
    """
    step, epoch = trainer.global_step, trainer.current_epoch
    limit_reached = (self.until_step is not None and step >= self.until_step) or (
        self.until_epoch is not None and epoch >= self.until_epoch
    )

    if limit_reached:
        if self._frozen_state:
            logger.info("Unfreezing the model.")
            self._set_model_requires_grad(pl_module, requires_grad=True)
            self._frozen_state = False
    elif not self._frozen_state:
        logger.info("Freezing the model.")
        self._set_model_requires_grad(pl_module, requires_grad=False)
        self._frozen_state = True