Skip to content

schema

Schema validation using dataclasses.

This module provides structured config validation using Python dataclasses. Define configuration schemas with type hints, then validate your YAML configs against them at runtime.

Example
from dataclasses import dataclass
from typing import Optional
from sparkwheel import Config
from sparkwheel.schema import validate

@dataclass
class OptimizerConfig:
    lr: float
    momentum: float = 0.9
    weight_decay: Optional[float] = None

@dataclass
class ModelConfig:
    hidden_size: int
    num_layers: int
    dropout: float
    optimizer: OptimizerConfig

# Load and validate config
config = Config.load("config.yaml")
validate(config.get(), ModelConfig)  # Raises error if invalid

# Or validate during load
config = Config.load("config.yaml", schema=ModelConfig)

ValidationError

Bases: BaseError

Raised when configuration validation fails.

Attributes:

Name Type Description
message

Error description

field_path

Dot-separated path to the invalid field (e.g., "model.optimizer.lr")

expected_type

The type that was expected

actual_value

The value that failed validation

source_location

Optional location in source file where error occurred

Source code in src/sparkwheel/schema.py
class ValidationError(BaseError):
    """Raised when configuration validation fails.

    Attributes:
        message: Error description
        field_path: Dot-separated path to the invalid field (e.g., "model.optimizer.lr")
        expected_type: The type that was expected
        actual_value: The value that failed validation
        source_location: Optional location in source file where error occurred
    """

    def __init__(
        self,
        message: str,
        field_path: str = "",
        expected_type: type | None = None,
        actual_value: Any = None,
        source_location: SourceLocation | None = None,
    ):
        """Initialize validation error.

        Args:
            message: Human-readable error message
            field_path: Dot-separated path to invalid field
            expected_type: Expected type for the field
            actual_value: The actual value that failed validation
            source_location: Source location where the invalid value was defined
        """
        self.field_path = field_path
        self.expected_type = expected_type
        self.actual_value = actual_value

        # Build detailed message
        full_message = message
        if field_path:
            full_message = f"Validation error at '{field_path}': {message}"
        if expected_type is not None:
            type_name = getattr(expected_type, "__name__", str(expected_type))
            full_message += f"\n  Expected type: {type_name}"
        if actual_value is not None:
            actual_type = type(actual_value).__name__
            full_message += f"\n  Actual type: {actual_type}"
            full_message += f"\n  Actual value: {actual_value!r}"

        super().__init__(full_message, source_location=source_location)

__init__(message, field_path='', expected_type=None, actual_value=None, source_location=None)

Initialize validation error.

Parameters:

Name Type Description Default
message str

Human-readable error message

required
field_path str

Dot-separated path to invalid field

''
expected_type type | None

Expected type for the field

None
actual_value Any

The actual value that failed validation

None
source_location SourceLocation | None

Source location where the invalid value was defined

None
Source code in src/sparkwheel/schema.py
def __init__(
    self,
    message: str,
    field_path: str = "",
    expected_type: type | None = None,
    actual_value: Any = None,
    source_location: SourceLocation | None = None,
):
    """Initialize validation error.

    Args:
        message: Human-readable error message
        field_path: Dot-separated path to invalid field
        expected_type: Expected type for the field
        actual_value: The actual value that failed validation
        source_location: Source location where the invalid value was defined
    """
    self.field_path = field_path
    self.expected_type = expected_type
    self.actual_value = actual_value

    # Build detailed message
    full_message = message
    if field_path:
        full_message = f"Validation error at '{field_path}': {message}"
    if expected_type is not None:
        type_name = getattr(expected_type, "__name__", str(expected_type))
        full_message += f"\n  Expected type: {type_name}"
    if actual_value is not None:
        actual_type = type(actual_value).__name__
        full_message += f"\n  Actual type: {actual_type}"
        full_message += f"\n  Actual value: {actual_value!r}"

    super().__init__(full_message, source_location=source_location)

_MissingSentinel

Sentinel for required-but-not-yet-set config values.

Source code in src/sparkwheel/schema.py
class _MissingSentinel:
    """Sentinel for required-but-not-yet-set config values."""

    def __repr__(self) -> str:
        return "MISSING"

    def __bool__(self) -> bool:
        return False

_find_discriminator(union_types)

Find discriminator field in a Union of dataclasses.

A discriminator is a field that: - Exists in all dataclass types in the Union - Has Literal type annotation - Has unique values per type

Parameters:

Name Type Description Default
union_types tuple[Any, ...]

Types in the Union

required

Returns:

Type Description
tuple[bool, str | None]

(has_discriminator, field_name)

Source code in src/sparkwheel/schema.py
def _find_discriminator(union_types: tuple[Any, ...]) -> tuple[bool, str | None]:
    """Find discriminator field in a Union of dataclasses.

    A discriminator is a field that:
    - Exists in all dataclass types in the Union
    - Has Literal type annotation
    - Has unique values per type

    Args:
        union_types: Types in the Union

    Returns:
        (has_discriminator, field_name)
    """
    from typing import Literal

    # Filter to dataclasses only
    dataclass_types = [t for t in union_types if dataclasses.is_dataclass(t)]
    if len(dataclass_types) < 2:
        return False, None

    # Find fields that exist in all types with Literal annotation
    all_fields: dict[str, list[Any]] = {}
    for dc_type in dataclass_types:
        for f in dataclasses.fields(dc_type):
            if get_origin(f.type) is Literal:
                if f.name not in all_fields:
                    all_fields[f.name] = []
                literal_values = get_args(f.type)
                all_fields[f.name].append({"type": dc_type, "values": literal_values})

    # Find a field present in all types with unique values
    for field_name, type_infos in all_fields.items():
        if len(type_infos) != len(dataclass_types):
            continue  # Not in all types

        # Check values are unique across types
        all_values = set()
        is_unique = True
        for info in type_infos:
            for val in info["values"]:
                if val in all_values:
                    is_unique = False
                    break
                all_values.add(val)
            if not is_unique:
                break

        if is_unique:
            return True, field_name

    return False, None

_format_union_type(types_tuple)

Format a tuple of types as Union[...] for error messages.

Source code in src/sparkwheel/schema.py
def _format_union_type(types_tuple: tuple[Any, ...]) -> str:
    """Format a tuple of types as Union[...] for error messages."""
    type_names = []
    for t in types_tuple:
        if hasattr(t, "__name__"):
            type_names.append(t.__name__)
        else:
            type_names.append(str(t))
    return f"Union[{', '.join(type_names)}]"

_get_source_location(metadata, field_path)

Get source location from metadata registry.

Parameters:

Name Type Description Default
metadata Any

MetadataRegistry instance

required
field_path str

Dot-separated field path to look up

required

Returns:

Type Description
SourceLocation | None

SourceLocation if found, None otherwise

Source code in src/sparkwheel/schema.py
def _get_source_location(metadata: Any, field_path: str) -> SourceLocation | None:
    """Get source location from metadata registry.

    Args:
        metadata: MetadataRegistry instance
        field_path: Dot-separated field path to look up

    Returns:
        SourceLocation if found, None otherwise
    """
    if metadata is None:
        return None

    try:
        # Convert dot notation to :: notation used by sparkwheel
        id_path = field_path.replace(".", "::")
        result = metadata.get(id_path)
        return result if result is None or isinstance(result, SourceLocation) else None
    except Exception:
        return None

_get_validators(schema_type)

Get all validator methods from a dataclass.

Source code in src/sparkwheel/schema.py
def _get_validators(schema_type: type) -> list[Any]:
    """Get all validator methods from a dataclass."""
    validators = []
    for attr_name in dir(schema_type):
        if attr_name.startswith("_"):
            continue
        try:
            attr = getattr(schema_type, attr_name)
            if callable(attr) and getattr(attr, "__is_validator__", False):
                validators.append(attr)
        except AttributeError:
            continue
    return validators

_is_union_type(origin)

Check if origin is a Union type (handles both typing.Union and types.UnionType).

Source code in src/sparkwheel/schema.py
def _is_union_type(origin: Any) -> bool:
    """Check if origin is a Union type (handles both typing.Union and types.UnionType)."""
    if origin is Union:
        return True
    # Python 3.10+ uses types.UnionType for X | Y syntax
    if hasattr(types, "UnionType") and origin is types.UnionType:
        return True
    return False

_run_validators(config, schema, field_path='', metadata=None)

Run all @validator methods on a dataclass.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dict

required
schema type

Dataclass type

required
field_path str

Path to this config

''
metadata Any

Optional metadata

None

Raises:

Type Description
ValidationError

If validation fails

Source code in src/sparkwheel/schema.py
def _run_validators(
    config: dict[str, Any],
    schema: type,
    field_path: str = "",
    metadata: Any = None,
) -> None:
    """Run all @validator methods on a dataclass.

    Args:
        config: Configuration dict
        schema: Dataclass type
        field_path: Path to this config
        metadata: Optional metadata

    Raises:
        ValidationError: If validation fails
    """
    validators = _get_validators(schema)
    if not validators:
        return

    # Skip validation for configs with references/expressions/macros
    # They'll be validated after resolution
    for value in config.values():
        if isinstance(value, str) and value.startswith(("@", "$", "%")):
            # Has unresolved references - skip custom validation
            return

    # Create instance to call validators on
    try:
        instance = schema(**config)
    except Exception:
        # Can't create instance - skip validation
        return

    source_loc = _get_source_location(metadata, field_path) if metadata else None

    for validator_method in validators:
        try:
            validator_method(instance)
        except ValueError as e:
            raise ValidationError(
                str(e),
                field_path=field_path,
                source_location=source_loc,
            ) from e
        except Exception as e:
            raise ValidationError(
                f"Validator '{validator_method.__name__}' raised {type(e).__name__}: {e}",
                field_path=field_path,
                source_location=source_loc,
            ) from e

_validate_discriminated_union(value, union_types, discriminator_field, field_path, metadata=None)

Validate a discriminated union by checking the discriminator.

Parameters:

Name Type Description Default
value Any

Value to validate (must be dict)

required
union_types tuple[Any, ...]

Types in the Union

required
discriminator_field str

Name of discriminator field

required
field_path str

Path to field

required
metadata Any

Optional metadata

None

Raises:

Type Description
ValidationError

If validation fails

Source code in src/sparkwheel/schema.py
def _validate_discriminated_union(
    value: Any,
    union_types: tuple[Any, ...],
    discriminator_field: str,
    field_path: str,
    metadata: Any = None,
) -> None:
    """Validate a discriminated union by checking the discriminator.

    Args:
        value: Value to validate (must be dict)
        union_types: Types in the Union
        discriminator_field: Name of discriminator field
        field_path: Path to field
        metadata: Optional metadata

    Raises:
        ValidationError: If validation fails
    """
    source_loc = _get_source_location(metadata, field_path) if metadata else None

    if not isinstance(value, dict):
        raise ValidationError(
            f"Discriminated union requires dict, got {type(value).__name__}",
            field_path=field_path,
            actual_value=value,
            source_location=source_loc,
        )

    # Check discriminator field exists
    if discriminator_field not in value:
        dataclass_types = [t for t in union_types if dataclasses.is_dataclass(t)]
        type_names = ", ".join(t.__name__ if isinstance(t, type) else type(t).__name__ for t in dataclass_types)
        raise ValidationError(
            f"Missing discriminator field '{discriminator_field}' (required for union of {type_names})",
            field_path=field_path,
            actual_value=value,
            source_location=source_loc,
        )

    discriminator_value = value[discriminator_field]

    # Find matching type
    dataclass_types = [t for t in union_types if dataclasses.is_dataclass(t)]
    matching_type = None

    for dc_type in dataclass_types:
        for f in dataclasses.fields(dc_type):
            if f.name == discriminator_field:
                literal_values = get_args(f.type)
                if discriminator_value in literal_values:
                    matching_type = dc_type
                    break
        if matching_type:
            break

    if matching_type is None:
        # Build helpful error with valid values
        valid_values = []
        for dc_type in dataclass_types:
            for f in dataclasses.fields(dc_type):
                if f.name == discriminator_field:
                    literal_values = get_args(f.type)
                    for val in literal_values:
                        type_name = dc_type.__name__ if isinstance(dc_type, type) else type(dc_type).__name__
                        valid_values.append(f"'{val}' ({type_name})")

        valid_str = ", ".join(valid_values)
        raise ValidationError(
            f"Invalid discriminator value '{discriminator_value}'. Valid: {valid_str}",
            field_path=field_path,
            actual_value=value,
            source_location=source_loc,
        )

    # Validate against the selected type
    assert isinstance(matching_type, type)
    validate(value, matching_type, field_path, metadata, allow_missing=False, strict=True)

_validate_field(value, expected_type, field_path, metadata=None, allow_missing=False)

Validate a single field value against its expected type.

Parameters:

Name Type Description Default
value Any

The value to validate

required
expected_type type

The expected type (may be generic like list[int])

required
field_path str

Dot-separated path to this field

required
metadata Any

Optional metadata registry for source locations

None
allow_missing bool

If True, allow MISSING sentinel values for partial configs

False

Raises:

Type Description
ValidationError

If validation fails

Source code in src/sparkwheel/schema.py
def _validate_field(
    value: Any,
    expected_type: type,
    field_path: str,
    metadata: Any = None,
    allow_missing: bool = False,
) -> None:
    """Validate a single field value against its expected type.

    Args:
        value: The value to validate
        expected_type: The expected type (may be generic like list[int])
        field_path: Dot-separated path to this field
        metadata: Optional metadata registry for source locations
        allow_missing: If True, allow MISSING sentinel values for partial configs

    Raises:
        ValidationError: If validation fails
    """
    source_loc = _get_source_location(metadata, field_path) if metadata else None

    # Handle MISSING values
    if isinstance(value, _MissingSentinel):
        if allow_missing:
            return  # OK for partial configs
        else:
            raise ValidationError(
                "Field has MISSING value but MISSING not allowed",
                field_path=field_path,
                expected_type=expected_type,
                actual_value=value,
                source_location=source_loc,
            )

    # Handle None values
    origin = get_origin(expected_type)
    args = get_args(expected_type)

    # Handle Optional[T] (which is Union[T, None])
    if _is_union_type(origin):
        # Check for discriminated union first
        has_discriminator, discriminator_field = _find_discriminator(args)
        if has_discriminator and discriminator_field:
            _validate_discriminated_union(value, args, discriminator_field, field_path, metadata)
            return

        # Check if None is allowed
        if type(None) in args:
            if value is None:
                return  # None is valid
            # Remove None from the union and validate against remaining types
            non_none_types = [t for t in args if t is not type(None)]
            if len(non_none_types) == 1:
                # Simple Optional[T] case - recursively validate with the single type
                _validate_field(value, non_none_types[0], field_path, metadata, allow_missing)
                return
            else:
                # Union with multiple non-None types - try each and collect errors
                errors = []
                for union_type in non_none_types:
                    try:
                        _validate_field(value, union_type, field_path, metadata, allow_missing)
                        return  # Validation succeeded
                    except ValidationError as e:
                        type_name = getattr(union_type, "__name__", str(union_type))
                        # Extract just the error message without field path prefix
                        error_msg = str(e).split("\n")[0]
                        if f"Validation error at '{field_path}': " in error_msg:
                            error_msg = error_msg.replace(f"Validation error at '{field_path}': ", "")
                        errors.append(f"  Tried {type_name}: {error_msg}")

                # All failed - build comprehensive error message
                union_str = _format_union_type(tuple(non_none_types))
                error_details = "\n".join(errors)
                raise ValidationError(
                    f"Value doesn't match any type in {union_str}\n{error_details}",
                    field_path=field_path,
                    expected_type=expected_type,
                    actual_value=value,
                    source_location=source_loc,
                )
        else:
            # Non-Optional Union - try each type and collect errors
            errors = []
            for union_type in args:
                try:
                    _validate_field(value, union_type, field_path, metadata, allow_missing)
                    return  # Validation succeeded
                except ValidationError as e:
                    type_name = getattr(union_type, "__name__", str(union_type))
                    # Extract just the error message without field path prefix
                    error_msg = str(e).split("\n")[0]
                    if f"Validation error at '{field_path}': " in error_msg:
                        error_msg = error_msg.replace(f"Validation error at '{field_path}': ", "")
                    errors.append(f"  Tried {type_name}: {error_msg}")

            # All failed - build comprehensive error message
            union_str = _format_union_type(args)
            error_details = "\n".join(errors)
            raise ValidationError(
                f"Value doesn't match any type in {union_str}\n{error_details}",
                field_path=field_path,
                expected_type=expected_type,
                actual_value=value,
                source_location=source_loc,
            )

    # Handle list[T]
    if origin is list:
        if not isinstance(value, list):
            raise ValidationError(
                "Expected list",
                field_path=field_path,
                expected_type=list,
                actual_value=value,
                source_location=source_loc,
            )
        if args:
            item_type = args[0]
            # Skip validation for List[Any] - accept any item types
            if item_type is not Any:
                for i, item in enumerate(value):
                    _validate_field(
                        item,
                        item_type,
                        f"{field_path}[{i}]",
                        metadata,
                        allow_missing,
                    )
        return

    # Handle dict[K, V]
    if origin is dict:
        if not isinstance(value, dict):
            raise ValidationError(
                "Expected dict",
                field_path=field_path,
                expected_type=dict,
                actual_value=value,
                source_location=source_loc,
            )
        if args and len(args) == 2:
            key_type, value_type = args
            # For Dict[K, Any], only validate keys and allow arbitrary values
            if value_type is Any:
                for k in value.keys():
                    if not isinstance(k, key_type):
                        raise ValidationError(
                            "Dict key has wrong type",
                            field_path=f"{field_path}[{k!r}]",
                            expected_type=key_type,
                            actual_value=k,
                            source_location=source_loc,
                        )
                return
            # Otherwise validate both keys and values
            for k, v in value.items():
                # Validate key type
                if not isinstance(k, key_type):
                    raise ValidationError(
                        "Dict key has wrong type",
                        field_path=f"{field_path}[{k!r}]",
                        expected_type=key_type,
                        actual_value=k,
                        source_location=source_loc,
                    )
                # Validate value type
                _validate_field(
                    v,
                    value_type,
                    f"{field_path}[{k!r}]",
                    metadata,
                    allow_missing,
                )
        return

    # Handle nested dataclasses
    if dataclasses.is_dataclass(expected_type):
        validate(value, expected_type, field_path, metadata, allow_missing, strict=True)
        return

    # Handle Literal types
    from typing import Literal

    if origin is Literal:
        if value not in args:
            valid_values = ", ".join(repr(v) for v in args)
            raise ValidationError(
                f"Value must be one of {valid_values}, got {value!r}",
                field_path=field_path,
                expected_type=expected_type,
                actual_value=value,
                source_location=source_loc,
            )
        return

    # Handle Any type - accept any value
    if expected_type == Any:
        return

    # Handle basic types (int, str, float, bool, etc.)
    if not isinstance(value, expected_type):
        # Special case: accept resolved references (@), raw references (%), and expressions ($) as strings
        # since they'll be resolved/expanded later
        if isinstance(value, str) and (value.startswith("@") or value.startswith("$") or value.startswith("%")):
            # This is a resolved reference/raw reference/expression that will be processed later
            # We can't validate its type until resolution
            return

        # Special case: allow int for float
        if expected_type is float and isinstance(value, int):
            return

        raise ValidationError(
            "Type mismatch",
            field_path=field_path,
            expected_type=expected_type,
            actual_value=value,
            source_location=source_loc,
        )

validate(config, schema, field_path='', metadata=None, allow_missing=False, strict=True)

Validate configuration against a dataclass schema.

Performs recursive type checking to ensure the configuration matches the structure and types defined in the dataclass schema.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary to validate

required
schema type

Dataclass type defining the expected structure

required
field_path str

Internal parameter for tracking nested field paths

''
metadata Any

Optional metadata registry for source locations

None
allow_missing bool

If True, allow MISSING sentinel values for partial configs

False
strict bool

If True, reject unexpected fields. If False, ignore them.

True

Raises:

Type Description
ValidationError

If validation fails

TypeError

If schema is not a dataclass

Example
from dataclasses import dataclass
from sparkwheel import Config
from sparkwheel.schema import validate

@dataclass
class AppConfig:
    name: str
    port: int
    debug: bool = False

config = Config().update("app.yaml")
validate(config.get(), AppConfig)
Source code in src/sparkwheel/schema.py
def validate(
    config: dict[str, Any],
    schema: type,
    field_path: str = "",
    metadata: Any = None,
    allow_missing: bool = False,
    strict: bool = True,
) -> None:
    """Validate configuration against a dataclass schema.

    Performs recursive type checking to ensure the configuration matches
    the structure and types defined in the dataclass schema.

    Args:
        config: Configuration dictionary to validate
        schema: Dataclass type defining the expected structure
        field_path: Internal parameter for tracking nested field paths
        metadata: Optional metadata registry for source locations
        allow_missing: If True, allow MISSING sentinel values for partial configs
        strict: If True, reject unexpected fields. If False, ignore them.

    Raises:
        ValidationError: If validation fails
        TypeError: If schema is not a dataclass

    Example:
        ```python
        from dataclasses import dataclass
        from sparkwheel import Config
        from sparkwheel.schema import validate

        @dataclass
        class AppConfig:
            name: str
            port: int
            debug: bool = False

        config = Config().update("app.yaml")
        validate(config.get(), AppConfig)
        ```
    """
    if not dataclasses.is_dataclass(schema):
        raise TypeError(f"Schema must be a dataclass, got {type(schema).__name__}")

    if not isinstance(config, dict):
        source_loc = _get_source_location(metadata, field_path) if metadata else None  # type: ignore[unreachable]
        raise ValidationError(
            f"Expected dict for dataclass {schema.__name__}",
            field_path=field_path,
            expected_type=dict,
            actual_value=config,
            source_location=source_loc,
        )

    # Get all fields from the dataclass
    schema_fields = {f.name: f for f in dataclasses.fields(schema)}

    # Check for required fields
    for field_name, field_info in schema_fields.items():
        current_path = f"{field_path}.{field_name}" if field_path else field_name

        # Check if field is missing
        if field_name not in config:
            # Field has default or default_factory -> optional
            if field_info.default is not dataclasses.MISSING or field_info.default_factory is not dataclasses.MISSING:
                continue
            # No default -> required
            source_loc = _get_source_location(metadata, field_path) if metadata else None
            raise ValidationError(
                f"Missing required field '{field_name}'",
                field_path=current_path,
                expected_type=field_info.type,  # type: ignore[arg-type]
                source_location=source_loc,
            )

        # Validate the field value
        _validate_field(
            config[field_name],
            field_info.type,  # type: ignore[arg-type]
            current_path,
            metadata,
            allow_missing=allow_missing,
        )

    # Check for unexpected fields - only if strict mode
    if strict:
        unexpected_fields = set(config.keys()) - set(schema_fields.keys())
        # Filter out sparkwheel special keys
        special_keys = {"_target_", "_disabled_", "_requires_", "_mode_"}
        unexpected_fields = unexpected_fields - special_keys

        if unexpected_fields:
            first_unexpected = sorted(unexpected_fields)[0]
            current_path = f"{field_path}.{first_unexpected}" if field_path else first_unexpected
            source_loc = _get_source_location(metadata, current_path) if metadata else None
            raise ValidationError(
                f"Unexpected field '{first_unexpected}' not in schema {schema.__name__}",
                field_path=current_path,
                source_location=source_loc,
            )

    # Run custom validators
    _run_validators(config, schema, field_path, metadata)

validator(func)

Decorator to mark a method as a validator.

Validators run after type checking and can validate single fields or relationships between fields. Raise ValueError on failure.

Example

@dataclass class Config: lr: float start: int end: int

@validator
def check_lr(self):
    if not (0 < self.lr < 1):
        raise ValueError("lr must be between 0 and 1")

@validator
def check_range(self):
    if self.end <= self.start:
        raise ValueError("end must be > start")
Source code in src/sparkwheel/schema.py
def validator(func):
    """Decorator to mark a method as a validator.

    Validators run after type checking and can validate single fields
    or relationships between fields. Raise ValueError on failure.

    Example:
        @dataclass
        class Config:
            lr: float
            start: int
            end: int

            @validator
            def check_lr(self):
                if not (0 < self.lr < 1):
                    raise ValueError("lr must be between 0 and 1")

            @validator
            def check_range(self):
                if self.end <= self.start:
                    raise ValueError("end must be > start")
    """
    func.__is_validator__ = True
    return func