Skip to content

resolver

Resolver

Resolves stage-specific configurations from the main configuration.

Source code in lighter/engine/resolver.py
class Resolver:
    """
    Resolves stage-specific configurations from the main configuration.

    For a given stage, dataloaders and metrics belonging to modes the stage does
    not run are dropped, together with system components (optimizer, scheduler,
    criterion) that the stage does not use. The main configuration held by the
    resolver is never modified, so it can be resolved repeatedly for different
    stages.
    """

    # Modes (keys of ``system.dataloaders`` / ``system.metrics``) kept for each stage.
    STAGE_MODES = {
        Stage.FIT: [Mode.TRAIN, Mode.VAL],
        Stage.VALIDATE: [Mode.VAL],
        Stage.TEST: [Mode.TEST],
        Stage.PREDICT: [Mode.PREDICT],
        Stage.LR_FIND: [Mode.TRAIN, Mode.VAL],
        Stage.SCALE_BATCH_SIZE: [Mode.TRAIN, Mode.VAL],
    }

    def __init__(self, config: Config):
        # Full, unfiltered configuration; only read by ``get_stage_config``.
        self.config = config

    def get_stage_config(self, stage: str) -> Config:
        """Get stage-specific configuration by filtering unused components.

        Args:
            stage: The stage to resolve; must be a key of ``STAGE_MODES``.

        Returns:
            A new ``Config`` (created with ``validate=False``) containing only the
            dataloaders, metrics, system components and ``args`` entry relevant
            to ``stage``. ``self.config`` is left unchanged.

        Raises:
            ValueError: If ``stage`` is not a key of ``STAGE_MODES``.
        """
        if stage not in self.STAGE_MODES:
            raise ValueError(f"Invalid stage: {stage}. Allowed stages are {list(self.STAGE_MODES)}")

        allowed_modes = set(self.STAGE_MODES[stage])

        # Copy every dict level that is modified below. A single shallow copy of
        # the top level would share the nested "system" / "dataloaders" /
        # "metrics" dicts with self.config, so filtering for one stage would
        # permanently delete components needed by later stages.
        stage_config = dict(self.config.get())
        system_config = dict(stage_config.get("system", {}))
        if "system" in stage_config:
            stage_config["system"] = system_config

        # Remove dataloaders and metrics whose mode is not run by this stage.
        for section in ("dataloaders", "metrics"):
            per_mode = system_config.get(section)
            if per_mode:
                system_config[section] = {
                    mode: value for mode, value in per_mode.items() if mode in allowed_modes
                }

        # Remove optimizer, scheduler, and criterion if not relevant to the current stage.
        # The criterion is kept for VALIDATE (presumably to compute the validation
        # loss — confirm against the system implementation).
        if stage in [Stage.VALIDATE, Stage.TEST, Stage.PREDICT]:
            if stage != Stage.VALIDATE:
                system_config.pop("criterion", None)
            system_config.pop("optimizer", None)
            system_config.pop("scheduler", None)

        # Retain only the args for the current stage (assigning a new dict on the
        # copied top level does not touch self.config).
        if "args" in stage_config:
            stage_config["args"] = {stage: stage_config["args"].get(stage, {})}

        return Config(stage_config, validate=False)

get_stage_config(stage)

Get stage-specific configuration by filtering unused components.

Source code in lighter/engine/resolver.py
def get_stage_config(self, stage: str) -> Config:
    """Get stage-specific configuration by filtering unused components.

    Args:
        stage: The stage to resolve; must be a key of ``self.STAGE_MODES``.

    Returns:
        A new ``Config`` (created with ``validate=False``) containing only the
        dataloaders, metrics, system components and ``args`` entry relevant to
        ``stage``. ``self.config`` is left unchanged.

    Raises:
        ValueError: If ``stage`` is not a key of ``self.STAGE_MODES``.
    """
    if stage not in self.STAGE_MODES:
        raise ValueError(f"Invalid stage: {stage}. Allowed stages are {list(self.STAGE_MODES)}")

    allowed_modes = set(self.STAGE_MODES[stage])

    # Copy every dict level that is modified below. A single shallow copy of the
    # top level would share the nested "system" / "dataloaders" / "metrics"
    # dicts with self.config, so filtering for one stage would permanently
    # delete components needed by later stages.
    stage_config = dict(self.config.get())
    system_config = dict(stage_config.get("system", {}))
    if "system" in stage_config:
        stage_config["system"] = system_config

    # Remove dataloaders and metrics whose mode is not run by this stage.
    for section in ("dataloaders", "metrics"):
        per_mode = system_config.get(section)
        if per_mode:
            system_config[section] = {
                mode: value for mode, value in per_mode.items() if mode in allowed_modes
            }

    # Remove optimizer, scheduler, and criterion if not relevant to the current stage.
    # The criterion is kept for VALIDATE (presumably to compute the validation
    # loss — confirm against the system implementation).
    if stage in [Stage.VALIDATE, Stage.TEST, Stage.PREDICT]:
        if stage != Stage.VALIDATE:
            system_config.pop("criterion", None)
        system_config.pop("optimizer", None)
        system_config.pop("scheduler", None)

    # Retain only the args for the current stage (assigning a new dict on the
    # copied top level does not touch self.config).
    if "args" in stage_config:
        stage_config["args"] = {stage: stage_config["args"].get(stage, {})}

    return Config(stage_config, validate=False)