Get stage-specific configuration by filtering unused components.
Source code in lighter/engine/resolver.py
def get_stage_config(self, stage: str) -> Config:
    """Get stage-specific configuration by filtering unused components.

    Works on an independent copy of the configuration, so the config held
    by this object is left untouched and the method can be called
    repeatedly (for the same or for different stages).

    Args:
        stage: Name of the stage to build the configuration for. Must be
            a key of ``self.STAGE_MODES``.

    Returns:
        A new ``Config`` (constructed with ``validate=False``) in which:

        * ``system.dataloaders`` and ``system.metrics`` keep only the
          modes listed in ``self.STAGE_MODES[stage]``;
        * ``system.optimizer`` and ``system.scheduler`` are removed for
          the validate, test and predict stages;
        * ``system.criterion`` is removed for the test and predict
          stages (it is kept for validate);
        * ``args``, when present, keeps only the entry for ``stage``
          (an empty dict if that entry is missing).

    Raises:
        ValueError: If ``stage`` is not a key of ``self.STAGE_MODES``.
    """
    # Local import: the file's top-level import block is not visible from here.
    import copy

    if stage not in self.STAGE_MODES:
        raise ValueError(f"Invalid stage: {stage}. Allowed stages are {list(self.STAGE_MODES)}")

    # A shallow ``.copy()`` would share the nested "system" / "dataloaders" /
    # "metrics" mappings with ``self.config``, so the ``pop`` calls below
    # would silently strip sections from the original configuration too.
    # NOTE(review): assumes ``self.config.get()`` returns plain declarative
    # data (dicts / lists / scalars) that is safe to deep-copy - confirm.
    stage_config = copy.deepcopy(self.config.get())
    system_config = stage_config.get("system", {})
    allowed_modes = set(self.STAGE_MODES[stage])

    # Remove dataloaders and metrics not relevant to the current stage
    for section in ("dataloaders", "metrics"):
        section_config = system_config.get(section, {})
        for mode in set(section_config) - allowed_modes:
            section_config.pop(mode, None)

    # Remove optimizer, scheduler, and criterion if not relevant to the current stage
    if stage in [Stage.VALIDATE, Stage.TEST, Stage.PREDICT]:
        # Validation is the only one of these stages that keeps the criterion.
        if stage != Stage.VALIDATE:
            system_config.pop("criterion", None)
        system_config.pop("optimizer", None)
        system_config.pop("scheduler", None)

    # Retain only relevant args for the current stage
    if "args" in stage_config:
        stage_config["args"] = {stage: stage_config["args"].get(stage, {})}

    return Config(stage_config, validate=False)