Source code for stilt.config.model

"""Project-level config models and YAML/doc helpers."""

from __future__ import annotations

from collections.abc import Iterator, Mapping
from pathlib import Path
from typing import Any

import yaml
from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self

from .fields import T, _field_meta, cfg_field
from .footprint import FootprintConfig
from .meteorology import MetConfig
from .params import ErrorParams, ModelParams, STILTParams, TransportParams
from .spatial import Bounds, Grid

# Keys that must ALL be present on a footprint for the inline grid shorthand
# (grid bounds/resolution written directly on the footprint, no ``grid:`` key).
_REQUIRED_GRID_KEYS = frozenset({"xmin", "xmax", "ymin", "ymax", "xres", "yres"})
# Every key moved off a footprint into its nested grid; ``projection`` is the
# only optional one.
_GRID_KEYS = _REQUIRED_GRID_KEYS | {"projection"}


class ModelConfig(STILTParams):
    """Project-level config: STILT params plus met and footprint definitions.

    Adds to :class:`STILTParams` the named meteorology streams, named grids and
    named footprint products a project uses, backend ``execution`` settings,
    and the default ``skip_existing`` behaviour. Unknown keys are rejected
    (``extra="forbid"``).

    A footprint's ``grid`` may be given as the name of an entry in ``grids``,
    as a full grid definition, or -- when ``grid`` is absent -- through inline
    shorthand keys (``xmin``, ``xmax``, ``ymin``, ``ymax``, ``xres``, ``yres``
    and optionally ``projection``) placed directly on the footprint.
    """

    model_config = ConfigDict(extra="forbid")

    footprints: dict[str, FootprintConfig] = cfg_field(
        default_factory=dict,
        description="Named footprint products available for this model configuration.",
    )
    grids: dict[str, Grid] = cfg_field(
        default_factory=dict,
        description="Named grids referenced by footprint definitions.",
    )
    # Defaults to an empty dict, but ``_validate_mets`` rejects an empty
    # mapping, so every instance must supply at least one stream.
    mets: dict[str, MetConfig] = cfg_field(
        default_factory=dict,
        description="Named meteorology streams available to the model.",
    )
    execution: dict[str, Any] = cfg_field(
        default_factory=dict,
        description="Execution backend settings such as local, Slurm, or Kubernetes options.",
        visibility="advanced",
    )
    skip_existing: bool = cfg_field(
        True,
        description=(
            "Skip simulations that already have output. "
            "Set False to force re-run all simulations. "
            "Can be overridden at call time via model.run(skip_existing=...)."
        ),
    )

    @classmethod
    def basic(
        cls,
        *,
        mets: dict[str, MetConfig],
        n_hours: int = -24,
        numpar: int = 200,
        footprints: dict[str, FootprintConfig] | None = None,
        skip_existing: bool = True,
        **kwargs: Any,
    ) -> Self:
        """Build a science-facing config with the most common controls.

        Parameters
        ----------
        mets
            Named meteorology streams (must be non-empty, alphanumeric keys).
        n_hours
            Value for the ``n_hours`` field (default ``-24``).
        numpar
            Value for the ``numpar`` field (default ``200``).
        footprints
            Named footprint products; ``None`` means no footprints.
        skip_existing
            Default for skipping simulations that already have output.
        **kwargs
            Any other :class:`ModelConfig` field, passed through unchanged.
        """
        return cls(
            mets=mets,
            n_hours=n_hours,
            numpar=numpar,
            footprints=footprints or {},
            skip_existing=skip_existing,
            **kwargs,
        )

    @model_validator(mode="before")
    @classmethod
    def _resolve_nested_configs(cls, data: Any) -> Any:
        """Expand named and inline grid references in footprint configs.

        Runs on the raw input before field validation. Each footprint given as
        a mapping is copied (the caller's input is never mutated) and its
        ``grid`` entry is resolved:

        * a string is looked up in the raw ``grids`` mapping;
        * an absent/``None`` grid is assembled from the inline shorthand keys,
          which are removed from the footprint;
        * anything else (a grid mapping or ``Grid`` instance) is left as-is.

        Non-dict input, or malformed (non-mapping) ``footprints``/``grids``
        sections, are returned untouched so pydantic's field validation
        reports the type error.

        Raises
        ------
        ValueError
            If a footprint names an unknown grid, or has no ``grid`` and an
            incomplete inline shorthand (surfaced by pydantic as a
            ``ValidationError``).
        """
        if not isinstance(data, dict):
            return data
        grids_raw = data.get("grids") or {}
        fp_raw = data.get("footprints") or {}
        if not fp_raw:
            return data
        if not isinstance(fp_raw, Mapping) or not isinstance(grids_raw, Mapping):
            # Skip resolution instead of raising a bare AttributeError /
            # TypeError from here; field validation gives a clear message.
            return data
        resolved = {
            name: (
                cls._resolve_footprint_grid(name, cfg, grids_raw)
                if isinstance(cfg, Mapping)
                else cfg
            )
            for name, cfg in fp_raw.items()
        }
        return {**data, "footprints": resolved}

    @staticmethod
    def _resolve_footprint_grid(
        name: str,
        cfg: Mapping[str, Any],
        grids_raw: Mapping[str, Any],
    ) -> dict[str, Any]:
        """Return a copy of one raw footprint mapping with its ``grid`` resolved.

        Raises
        ------
        ValueError
            If ``grid`` names an entry missing from ``grids_raw``, or ``grid``
            is absent/``None`` and the inline shorthand is incomplete.
        """
        resolved = dict(cfg)  # copy so the caller's input is never mutated
        grid_ref = resolved.get("grid")
        if isinstance(grid_ref, str):
            if grid_ref not in grids_raw:
                raise ValueError(
                    f"Footprint '{name}' references unknown grid '{grid_ref}'"
                )
            resolved["grid"] = grids_raw[grid_ref]
        elif grid_ref is None:
            missing = _REQUIRED_GRID_KEYS - set(resolved)
            if missing:
                message = f"Footprint '{name}' is missing a 'grid' key."
                if missing != _REQUIRED_GRID_KEYS:
                    # Some shorthand keys were given: name the absent ones so
                    # the user can complete the shorthand directly.
                    message += (
                        " Inline grid shorthand is incomplete; "
                        f"missing: {sorted(missing)}."
                    )
                raise ValueError(message)
            # Move the shorthand keys (plus optional ``projection``) off the
            # footprint and into a nested grid definition.
            resolved["grid"] = {
                key: resolved.pop(key) for key in _GRID_KEYS if key in resolved
            }
        return resolved

    @model_validator(mode="after")
    def _validate_mets(self) -> Self:
        """Require at least one met stream and alphanumeric met keys.

        Raises
        ------
        ValueError
            If :attr:`mets` is empty, or any key fails :meth:`str.isalnum`
            (e.g. contains underscores, dashes or spaces).
        """
        if not self.mets:
            raise ValueError(
                "ModelConfig.mets must contain at least one meteorology configuration"
            )
        # NOTE(review): str.isalnum also accepts non-ASCII letters/digits;
        # confirm whether downstream consumers of met keys need ASCII-only.
        bad_keys = [key for key in self.mets if not key.isalnum()]
        if bad_keys:
            raise ValueError(
                f"Met keys must be alphanumeric (no underscores or special chars), got: {bad_keys}"
            )
        return self

    def to_yaml(self, path: str | Path) -> None:
        """Write the model config to a YAML file.

        Parent directories are created as needed. Values are dumped in
        pydantic JSON mode so the YAML holds only plain scalars, lists and
        mappings; empty ``mets``/``grids``/``footprints``/``execution``
        sections are omitted, and ``sort_keys=False`` keeps ``model_dump``'s
        key order. The file is written as UTF-8 regardless of locale.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        data = self.model_dump(mode="json")
        for key in ("mets", "grids", "footprints", "execution"):
            if not data.get(key):
                data.pop(key, None)  # tolerate an absent key, unlike ``del``
        with path.open("w", encoding="utf-8") as f:
            yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False)

    def to_stilt_params(self) -> STILTParams:
        """Project this model config onto the pure STILT run-parameter surface.

        Drops the project-level sections (``footprints``, ``grids``, ``mets``,
        ``execution``, ``skip_existing``) and re-validates the remaining
        fields as a fresh :class:`STILTParams`.
        """
        data = self.model_dump(
            exclude={"footprints", "grids", "mets", "execution", "skip_existing"}
        )
        return STILTParams(**data)

    @classmethod
    def from_yaml(cls, path: str | Path) -> Self:
        """Load a model config from a YAML file.

        The file is opened in binary mode so PyYAML detects the stream
        encoding itself (UTF-8, or UTF-16 with a BOM) instead of depending on
        the platform locale. An empty document is treated as an empty mapping.

        Raises
        ------
        yaml.YAMLError
            If the file is not valid YAML.
        pydantic.ValidationError
            If the document does not describe a valid :class:`ModelConfig`.
        """
        path = Path(path)
        with path.open("rb") as f:
            raw: Any = yaml.safe_load(f) or {}
        return cls.model_validate(raw)
# Models whose fields ``iter_documented_config_fields`` walks by default, in
# the order they are visited.
CONFIG_DOC_MODELS: tuple[type[BaseModel], ...] = (
    Bounds,
    Grid,
    MetConfig,
    FootprintConfig,
    ModelParams,
    TransportParams,
    ErrorParams,
    ModelConfig,
)


def iter_documented_config_fields(
    *models: type[BaseModel],
    include_internal: bool = False,
) -> Iterator[tuple[type[BaseModel], str, Any]]:
    """Yield ``(model, field_name, field_info)`` for docs or UI generation.

    Models are visited in the order given (``CONFIG_DOC_MODELS`` when none are
    passed) and fields in each model's declaration order. Fields whose
    resolved visibility is ``"internal"`` are skipped unless
    ``include_internal`` is true.

    NOTE: pydantic's ``model_fields`` includes inherited fields, so a field
    shared through a common base is yielded once under each listed model.
    """
    for model in models or CONFIG_DOC_MODELS:
        for field_name, field_info in model.model_fields.items():
            meta = _resolved_field_meta(model, field_name)
            is_internal = meta["visibility"] == "internal"
            if include_internal or not is_internal:
                yield model, field_name, field_info


def _resolved_field_meta(model: type[BaseModel], name: str) -> dict[str, Any]:
    """Return a field's ``_field_meta`` metadata with routing defaults filled in.

    Keys already present in the metadata win (even when their value is
    ``None``). Otherwise ``target`` falls back to the model's
    ``DEFAULT_TARGET`` class attribute (or ``None``), ``visibility`` to
    ``"public"``, and ``namelist`` to the field name itself.
    """
    declared = _field_meta(model.model_fields[name])
    class_target = getattr(model, "DEFAULT_TARGET", None)
    resolved = dict(declared)
    resolved["target"] = declared.get("target", class_target)
    resolved["visibility"] = declared.get("visibility", "public")
    resolved["namelist"] = declared.get("namelist", name)
    return resolved


def _collect_target_entries(
    params: BaseModel,
    model: type[BaseModel],
    *,
    target: str,
) -> dict[str, Any]:
    """Map namelist names to values for ``model``'s fields routed to ``target``.

    Values are read from ``params`` by field name; fields whose value is
    ``None`` are left out.
    """
    selected: dict[str, Any] = {}
    for field_name in model.model_fields:
        meta = _resolved_field_meta(model, field_name)
        if meta["target"] == target:
            value = getattr(params, field_name)
            if value is not None:
                selected[meta["namelist"]] = value
    return selected


def build_setup_entries(params: STILTParams) -> dict[str, Any]:
    """Collect fields that belong in HYSPLIT ``SETUP.CFG``.

    Combines ``ModelParams`` and ``TransportParams`` fields routed to
    ``"setup"``; on a namelist-name clash the ``TransportParams`` entry wins.
    """
    return {
        **_collect_target_entries(params, ModelParams, target="setup"),
        **_collect_target_entries(params, TransportParams, target="setup"),
    }


def build_control_entries(params: STILTParams) -> dict[str, Any]:
    """Collect ``TransportParams`` fields that belong in HYSPLIT ``CONTROL``."""
    return _collect_target_entries(params, TransportParams, target="control")


def _config_or_kwargs(
    config: T | None,
    kwargs: dict,
    cls: type[T],
) -> T | None:
    """Return ``config``, or a ``cls`` built from ``kwargs`` when any are given.

    Raises
    ------
    TypeError
        If both a ``config`` instance and non-empty ``kwargs`` are supplied.
    """
    if not kwargs:
        return config
    if config is not None:
        raise TypeError(
            f"Cannot pass both a {cls.__name__} instance and keyword arguments."
        )
    return cls(**kwargs)


__all__ = [
    "CONFIG_DOC_MODELS",
    "ModelConfig",
    "build_control_entries",
    "build_setup_entries",
    "iter_documented_config_fields",
]