Skip to content

containers

Adapters dataclass

Root configuration class that holds one adapter sub-configuration per mode: `train`, `val`, `test`, and `predict`.

Source code in lighter/utils/types/containers.py
@nested
@dataclass
class Adapters:
    """
    Top-level adapter configuration holding one sub-configuration per mode.

    Every attribute defaults to a freshly constructed instance of its mode-specific
    dataclass. The ``@nested`` decorator allows a mapping to be supplied for any of
    these keyword arguments in place of a ready-made instance.
    """

    # Adapters applied while training.
    train: Train = field(
        default_factory=Train,
    )
    # Adapters applied during validation.
    val: Val = field(
        default_factory=Val,
    )
    # Adapters applied during testing.
    test: Test = field(
        default_factory=Test,
    )
    # Adapters applied during prediction.
    predict: Predict = field(
        default_factory=Predict,
    )

Predict dataclass

Predict mode sub-dataclass for Adapters.

Source code in lighter/utils/types/containers.py
@dataclass
class Predict:
    """
    Predict-mode adapters, used as the ``predict`` entry of ``Adapters``.

    The batch adapter's ``input_accessor`` is the identity function, so the whole
    batch is handed over as the input; no ``target_accessor`` is configured here.
    Only batch and logging adapters are defined for this mode.
    """

    # Identity accessor: the batch itself is the input.
    batch: BatchAdapter = field(
        default_factory=lambda: BatchAdapter(input_accessor=lambda sample: sample),
    )
    logging: LoggingAdapter = field(
        default_factory=LoggingAdapter,
    )

Test dataclass

Test mode sub-dataclass for Adapters.

Source code in lighter/utils/types/containers.py
@dataclass
class Test:
    """
    Test-mode adapters, used as the ``test`` entry of ``Adapters``.

    Defaults: the batch adapter uses ``input_accessor=0`` and ``target_accessor=1``;
    the metrics adapter uses ``pred_argument=0`` and ``target_argument=1``
    (presumably positional indices — confirm against the adapter classes).
    No criterion adapter is defined for this mode.
    """

    batch: BatchAdapter = field(
        default_factory=lambda: BatchAdapter(
            input_accessor=0,
            target_accessor=1,
        ),
    )
    metrics: MetricsAdapter = field(
        default_factory=lambda: MetricsAdapter(
            pred_argument=0,
            target_argument=1,
        ),
    )
    logging: LoggingAdapter = field(
        default_factory=LoggingAdapter,
    )

Train dataclass

Train mode sub-dataclass for Adapters.

Source code in lighter/utils/types/containers.py
@dataclass
class Train:
    """
    Train-mode adapters, used as the ``train`` entry of ``Adapters``.

    Defaults: the batch adapter uses ``input_accessor=0`` and ``target_accessor=1``;
    both the criterion and metrics adapters use ``pred_argument=0`` and
    ``target_argument=1`` (presumably positional indices — confirm against the
    adapter classes).
    """

    batch: BatchAdapter = field(
        default_factory=lambda: BatchAdapter(
            input_accessor=0,
            target_accessor=1,
        ),
    )
    criterion: CriterionAdapter = field(
        default_factory=lambda: CriterionAdapter(
            pred_argument=0,
            target_argument=1,
        ),
    )
    metrics: MetricsAdapter = field(
        default_factory=lambda: MetricsAdapter(
            pred_argument=0,
            target_argument=1,
        ),
    )
    logging: LoggingAdapter = field(
        default_factory=LoggingAdapter,
    )

Val dataclass

Val mode sub-dataclass for Adapters.

Source code in lighter/utils/types/containers.py
@dataclass
class Val:
    """
    Validation-mode adapters, used as the ``val`` entry of ``Adapters``.

    Defaults: the batch adapter uses ``input_accessor=0`` and ``target_accessor=1``;
    both the criterion and metrics adapters use ``pred_argument=0`` and
    ``target_argument=1`` (presumably positional indices — confirm against the
    adapter classes).
    """

    batch: BatchAdapter = field(
        default_factory=lambda: BatchAdapter(
            input_accessor=0,
            target_accessor=1,
        ),
    )
    criterion: CriterionAdapter = field(
        default_factory=lambda: CriterionAdapter(
            pred_argument=0,
            target_argument=1,
        ),
    )
    metrics: MetricsAdapter = field(
        default_factory=lambda: MetricsAdapter(
            pred_argument=0,
            target_argument=1,
        ),
    )
    logging: LoggingAdapter = field(
        default_factory=LoggingAdapter,
    )

nested(cls)

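Decorator that lets a dataclass accept mappings (e.g. dicts) as keyword arguments for fields whose type is itself a dataclass. Each such mapping is converted into an instance of the field's type before the dataclass `__init__` runs. Apply it on top of `@dataclass`. Example:

```python
@nested
@dataclass
class Example:
    ...
```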

Source code in lighter/utils/types/containers.py
def nested(cls):
    """
    Decorator that lets a dataclass accept mappings for its dataclass-typed fields.

    The class's ``__init__`` is wrapped so that, for every field whose declared
    type is itself a dataclass, a keyword argument supplied for that field is
    converted with ``FieldType(**value)`` before the original ``__init__`` runs.
    A value that is already an instance of the field's type is passed through
    unchanged. Positional arguments are never converted.

    Apply it on top of ``@dataclass`` so that ``fields(cls)`` is available.

    Example:
        ```
        @nested
        @dataclass
        class Example:
            ...
        ```

    Args:
        cls: The dataclass to decorate.

    Returns:
        ``cls`` itself, with ``__init__`` replaced by the converting wrapper.

    Raises:
        TypeError: At decoration time if ``cls`` is not a dataclass. At
            construction time if a nested keyword value is neither an instance
            of the field's type nor unpackable with ``**``.
    """
    import functools

    original_init = cls.__init__
    # The field list is fixed once @dataclass has run, so compute it once here
    # rather than on every instantiation.
    nested_fields = [(f.name, f.type) for f in fields(cls) if is_dataclass(f.type)]

    # wraps() keeps __name__/__doc__ and sets __wrapped__, so inspect.signature(cls)
    # still reports the real dataclass parameters instead of (*args, **kwargs).
    @functools.wraps(original_init)
    def __init__(self, *args, **kwargs):
        for name, field_type in nested_fields:
            if name not in kwargs:
                continue
            value = kwargs[name]
            # Already-built instances (e.g. ``train=Train()``) must not be unpacked:
            # ``field_type(**instance)`` raises "argument after ** must be a mapping".
            if isinstance(value, field_type):
                continue
            kwargs[name] = field_type(**value)
        original_init(self, *args, **kwargs)

    cls.__init__ = __init__
    return cls