misc

This module contains miscellaneous utility functions for handling lists, callable names, optimizer hyperparameters, and function arguments.

ensure_list(input)

Ensures that the input is wrapped in a list. If the input is None, returns an empty list.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| input | Any | The input to wrap in a list. | required |

Returns:

| Name | Type | Description |
| ---- | ---- | ----------- |
| List | list | The input wrapped in a list, or an empty list if input is None. |

Source code in src/lighter/utils/misc.py
def ensure_list(input: Any) -> list:
    """
    Ensures that the input is wrapped in a list. If the input is None, returns an empty list.

    Args:
        input: The input to wrap in a list.

    Returns:
        List: The input wrapped in a list, or an empty list if input is None.
    """
    if isinstance(input, list):
        return input
    if isinstance(input, tuple):
        return list(input)
    if input is None:
        return []
    return [input]
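
A quick usage sketch; the expected results follow directly from the source above:

from lighter.utils.misc import ensure_list

ensure_list([1, 2])   # [1, 2] -- lists are returned unchanged
ensure_list((1, 2))   # [1, 2] -- tuples are converted to lists
ensure_list(None)     # []     -- None becomes an empty list
ensure_list("abc")    # ["abc"] -- any other value, including a string, is wrapped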

get_name(_callable, include_module_name=False)

Retrieves the name of a callable, optionally including the module name.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| _callable | Callable | The callable whose name to retrieve. | required |
| include_module_name | bool | Whether to include the module name in the result. | False |

Returns:

| Name | Type | Description |
| ---- | ---- | ----------- |
| str | str | The name of the callable, optionally prefixed with the module name. |

Source code in src/lighter/utils/misc.py
def get_name(_callable: Callable, include_module_name: bool = False) -> str:
    """
    Retrieves the name of a callable, optionally including the module name.

    Args:
        _callable: The callable whose name to retrieve.
        include_module_name: Whether to include the module name in the result.

    Returns:
        str: The name of the callable, optionally prefixed with the module name.
    """
    # Get the name directly from the callable's __name__ attribute
    name = getattr(_callable, "__name__", type(_callable).__name__)

    if include_module_name:
        # Get the module name directly from the callable's __module__ attribute
        module = getattr(_callable, "__module__", type(_callable).__module__)
        name = f"{module}.{name}"

    return name
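
An illustrative sketch (the Accuracy class is a hypothetical example to show the fallback for objects without a __name__ attribute):

from lighter.utils.misc import get_name

get_name(print)                             # "print"
get_name(print, include_module_name=True)   # "builtins.print"

class Accuracy:
    pass

get_name(Accuracy())  # "Accuracy" -- instances fall back to their type's name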

get_optimizer_stats(optimizer)

Extract hyperparameters from a PyTorch optimizer.

Collects learning rate and other key hyperparameters from each parameter group in the optimizer and returns them in a dictionary. Keys are formatted to show the optimizer type and group number (if multiple groups exist).

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| optimizer | Optimizer | The PyTorch optimizer to extract values from. | required |

Returns:

| Type | Description |
| ---- | ----------- |
| dict[str, float] | Dictionary of optimizer hyperparameters, with keys formatted as follows: |

- Learning rate: "optimizer/{name}/lr[/group{N}]"
- Momentum: "optimizer/{name}/momentum[/group{N}]" (SGD, RMSprop)
- Beta1: "optimizer/{name}/beta1[/group{N}]" (Adam variants)
- Beta2: "optimizer/{name}/beta2[/group{N}]" (Adam variants)
- Weight decay: "optimizer/{name}/weight_decay[/group{N}]"

Where [/group{N}] is only added for optimizers with multiple parameter groups.

Source code in src/lighter/utils/misc.py
def get_optimizer_stats(optimizer: Optimizer) -> dict[str, float]:
    """
    Extract hyperparameters from a PyTorch optimizer.

    Collects learning rate and other key hyperparameters from each parameter group
    in the optimizer and returns them in a dictionary. Keys are formatted to show
    the optimizer type and group number (if multiple groups exist).

    Args:
        optimizer: The PyTorch optimizer to extract values from.

    Returns:
        dict[str, float]: dictionary containing optimizer hyperparameters:
            - Learning rate: "optimizer/{name}/lr[/group{N}]"
            - Momentum: "optimizer/{name}/momentum[/group{N}]" (SGD, RMSprop)
            - Beta1: "optimizer/{name}/beta1[/group{N}]" (Adam variants)
            - Beta2: "optimizer/{name}/beta2[/group{N}]" (Adam variants)
            - Weight decay: "optimizer/{name}/weight_decay[/group{N}]"

            Where [/group{N}] is only added for optimizers with multiple groups.
    """
    stats_dict = {}
    for group_idx, group in enumerate(optimizer.param_groups):
        base_key = f"optimizer/{optimizer.__class__.__name__}"

        # Add group index suffix if there are multiple parameter groups
        suffix = f"/group{group_idx + 1}" if len(optimizer.param_groups) > 1 else ""

        # Always extract learning rate (present in all optimizers)
        stats_dict[f"{base_key}/lr{suffix}"] = group["lr"]

        # Extract momentum (SGD, RMSprop)
        if "momentum" in group:
            stats_dict[f"{base_key}/momentum{suffix}"] = group["momentum"]

        # Extract betas (Adam, AdamW, NAdam, RAdam, etc.)
        if "betas" in group:
            stats_dict[f"{base_key}/beta1{suffix}"] = group["betas"][0]
            if len(group["betas"]) > 1:
                stats_dict[f"{base_key}/beta2{suffix}"] = group["betas"][1]

        # Extract weight decay if non-zero
        if "weight_decay" in group and group["weight_decay"] != 0:
            stats_dict[f"{base_key}/weight_decay{suffix}"] = group["weight_decay"]

    return stats_dict
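
A minimal sketch with a single parameter group, so no /group{N} suffix is added to the keys:

import torch
from torch.optim import SGD
from lighter.utils.misc import get_optimizer_stats

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

get_optimizer_stats(optimizer)
# {'optimizer/SGD/lr': 0.01,
#  'optimizer/SGD/momentum': 0.9,
#  'optimizer/SGD/weight_decay': 0.0001}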

hasarg(fn, arg_name)

Checks if a callable (function, method, or class) has a specific argument.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| fn | Callable | The callable to inspect. | required |
| arg_name | str | The name of the argument to check for. | required |

Returns:

| Name | Type | Description |
| ---- | ---- | ----------- |
| bool | bool | True if the argument exists, False otherwise. |

Source code in src/lighter/utils/misc.py
def hasarg(fn: Callable, arg_name: str) -> bool:
    """
    Checks if a callable (function, method, or class) has a specific argument.

    Args:
        fn: The callable to inspect.
        arg_name: The name of the argument to check for.

    Returns:
        bool: True if the argument exists, False otherwise.
    """
    args = inspect.signature(fn).parameters.keys()
    return arg_name in args
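
A usage sketch; note that hasarg also works on classes, since inspect.signature inspects a class's __init__ (the scale function and Scaler class are hypothetical examples):

from lighter.utils.misc import hasarg

def scale(value, factor=2.0):
    return value * factor

hasarg(scale, "factor")  # True
hasarg(scale, "offset")  # False

class Scaler:
    def __init__(self, factor):
        self.factor = factor

hasarg(Scaler, "factor")  # True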