from abc import ABC
from collections.abc import Callable
from pathlib import Path
from typing import Any, ClassVar, Literal

import pandas as pd
import yaml
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    field_validator,
    model_validator,
)
from typing_extensions import Self

from stilt.receptors import Receptor
class Resolution(BaseModel):
    """Grid spacing along the x and y axes."""

    xres: float
    yres: float

    def __str__(self) -> str:
        """Render the resolution as ``"<xres>x<yres>"``."""
        return "{}x{}".format(self.xres, self.yres)
class SystemParams(BaseModel):
    """Working-directory and library locations for a STILT run."""

    stilt_wd: Path
    # When omitted, filled in as ``stilt_wd / "out"`` by the validator below.
    output_wd: Path | None = None
    lib_loc: Path | int | None = None

    @model_validator(mode="after")
    def _set_system_defaults(self) -> Self:
        """Derive ``output_wd`` from ``stilt_wd`` if it was not supplied."""
        if self.output_wd is not None:
            return self
        self.output_wd = self.stilt_wd / "out"
        return self
class MetParams(BaseModel):
    """Meteorological input settings."""

    # Required: where the met files live and how they are named/spaced.
    met_path: Path
    met_file_format: str
    met_file_tres: str
    # Sub-grid options; BaseConfig requires grid bounds when
    # ``met_subgrid_enable`` is True.
    met_subgrid_buffer: float = 0.1
    met_subgrid_enable: bool = False
    met_subgrid_levels: int | None = None
    n_met_min: int = 1
class ModelParams(BaseModel):
    """
    Core STILT run settings.

    At least one of ``run_trajec`` or ``run_foot`` must be enabled.
    """

    n_hours: int = -24
    numpar: int = 1000
    rm_dat: bool = True
    run_foot: bool = True
    run_trajec: bool = True
    simulation_id: str | list[str] | None = None
    timeout: int = 3600
    # Output variables to request; only the names listed in the Literal are
    # accepted. The default is a subset of them.
    varsiwant: list[
        Literal[
            "time", "indx", "long", "lati", "zagl", "sigw", "tlgr", "zsfc",
            "icdx", "temp", "samt", "foot", "shtf", "tcld", "dmas", "dens",
            "rhfr", "sphu", "lcld", "zloc", "dswf", "wout", "mlht", "rain",
            "crai", "pres", "whtf", "temz", "zfx1",
        ]
    ] = Field(
        default_factory=lambda: [
            "time", "indx", "long", "lati", "zagl", "foot",
            "mlht", "pres", "dens", "samt", "sigw", "tlgr",
        ]
    )

    @model_validator(mode="after")
    def _validate_run_flags(self) -> Self:
        """Reject a configuration in which both run flags are disabled."""
        if self.run_trajec or self.run_foot:
            return self
        raise ValueError("Nothing to do: set `run_trajec` or `run_foot` to True")
class TransportParams(BaseModel):
    """
    Transport/dispersion parameters.

    Names and defaults are passed through unchanged; they presumably mirror
    HYSPLIT ``SETUP.CFG`` namelist variables (TODO confirm). Fields are kept
    in alphabetical declaration order, which is also the key order produced
    by ``model_dump()``.
    """

    capemin: float = -1.0
    cmass: int = 0
    conage: int = 48
    cpack: int = 1
    delt: int = 1
    dxf: int = 1
    dyf: int = 1
    dzf: float = 0.01
    efile: str = ""
    emisshrs: float = 0.01
    frhmax: float = 3.0
    frhs: float = 1.0
    frme: float = 0.1
    frmr: float = 0.0
    frts: float = 0.1
    frvs: float = 0.1
    hscale: int = 10800
    ichem: int = 8
    idsp: int = 2
    initd: int = 0
    k10m: int = 1
    kagl: int = 1
    kbls: int = 1
    kblt: int = 5
    kdef: int = 0
    khinp: int = 0
    khmax: int = 9999
    kmix0: int = 250
    kmixd: int = 3
    kmsl: int = 0
    kpuff: int = 0
    krand: int = 4
    krnd: int = 6
    kspl: int = 1
    kwet: int = 1
    kzmix: int = 0
    maxdim: int = 1
    # None here; BaseConfig replaces it with ``numpar`` after validation.
    maxpar: int | None = None
    mgmin: int = 10
    mhrs: int = 9999
    nbptyp: int = 1
    ncycl: int = 0
    ndump: int = 0
    ninit: int = 1
    nstr: int = 0
    nturb: int = 0
    nver: int = 0
    outdt: int = 0
    p10f: int = 1
    pinbc: str = ""
    pinpf: str = ""
    poutf: str = ""
    qcycle: int = 0
    rhb: float = 80.0
    rht: float = 60.0
    splitf: int = 1
    tkerd: float = 0.18
    tkern: float = 0.18
    tlfrac: float = 0.1
    tout: float = 0.0
    tratio: float = 0.75
    tvmix: float = 1.0
    veght: float = 0.5
    vscale: int = 200
    vscaleu: int = 200
    vscales: int = -1
    w_option: int = 0
    wbbh: int = 0
    wbwf: int = 0
    wbwr: int = 0
    wvert: bool = False
    z_top: float = 25000.0
    zicontroltf: int = 0
    ziscale: int | list[int] = 0
class ErrorParams(BaseModel):
    """
    Optional XY and ZI error parameters.

    Each group must be either fully set or fully unset, where "unset" means
    ``None`` or NaN as judged by ``pandas.isna``.
    """

    siguverr: float | None = None
    tluverr: float | None = None
    zcoruverr: float | None = None
    horcoruverr: float | None = None
    sigzierr: float | None = None
    tlzierr: float | None = None
    horcorzierr: float | None = None

    # Field names of each group. ClassVar keeps these out of the model fields.
    XYERR_PARAMS: ClassVar[tuple[str, ...]] = (
        "siguverr",
        "tluverr",
        "zcoruverr",
        "horcoruverr",
    )
    ZIERR_PARAMS: ClassVar[tuple[str, ...]] = ("sigzierr", "tlzierr", "horcorzierr")

    @model_validator(mode="after")
    def _validate_error_params(self) -> Self:
        """
        Validate error parameters to ensure they are either all set or all None.

        Raises
        ------
        ValueError
            If a group has some values set and others missing (None/NaN).
        """
        groups = [("XY", self.xyerr_params()), ("ZI", self.zierr_params())]
        for name, params in groups:
            is_na = [pd.isna(v) for v in params.values()]
            if any(is_na) and not all(is_na):
                raise ValueError(
                    f"Inconsistent {name} error parameters: all must be set or all must be None"
                )
        return self

    def xyerr_params(self) -> dict[str, float | None]:
        """
        Get the XY error parameters as a dictionary.
        """
        return {param: getattr(self, param) for param in self.XYERR_PARAMS}

    def zierr_params(self) -> dict[str, float | None]:
        """
        Get the ZI error parameters as a dictionary.
        """
        return {param: getattr(self, param) for param in self.ZIERR_PARAMS}

    @property
    def winderrtf(self) -> int:
        """
        Determine the winderrtf flag based on the presence of error parameters.

        A group counts as present when none of its values is None/NaN. This is
        the same criterion the validator uses. A legitimate ``0.0`` therefore
        counts as set, and NaN counts as unset.

        Returns
        -------
        int
            Wind error control flag.
            0 : No error parameters are set
            1 : ZI error parameters are set
            2 : XY error parameters are set
            3 : Both XY and ZI error parameters are set
        """
        # Use pd.isna rather than truthiness: all() on the raw values would
        # treat NaN as "set" and 0.0 as "unset".
        xyerr = not any(pd.isna(v) for v in self.xyerr_params().values())
        zierr = not any(pd.isna(v) for v in self.zierr_params().values())
        return 2 * int(xyerr) + int(zierr)
class UserFuncParams(BaseModel):
    """User-supplied hook functions."""

    # Either a Python callable, a path to an R script, or None.
    before_footprint: Callable | Path | None = None

    @field_validator("before_footprint", mode="before")
    @classmethod
    def _load_before_footprint(cls, v: Any) -> Any:
        """
        Accept a callable, None, or a path to an R script.

        String/Path inputs are classified by their file extension, compared
        case-insensitively and exactly:

        - ``.r``: returned unchanged (the R script path is passed through).
        - ``.py``: not supported yet.
        - anything else: rejected.

        Raises
        ------
        NotImplementedError
            For ``.py`` files.
        ValueError
            For any other file extension.
        """
        if not isinstance(v, (str, Path)):
            # Callables / None are left for the field's own validation.
            return v
        p = Path(v)
        # Exact comparison: endswith("r") would also accept ".tar", and
        # endswith("py") would also accept ".npy".
        suffix = p.suffix.lower()
        if suffix == ".r":
            # Pass the R path
            return v
        if suffix == ".py":
            raise NotImplementedError(
                "Loading Python functions from file is not implemented yet."
            )
        raise ValueError(f"Unsupported file type: {p.suffix}")
class BaseConfig(
    ABC,
    SystemParams,
    FootprintParams,
    MetParams,
    ModelParams,
    TransportParams,
    ErrorParams,
    UserFuncParams,
):
    """
    STILT Configuration

    This class consolidates all configuration parameters for the STILT model,
    including system settings, footprint parameters, meteorological data,
    model specifics, transport settings, error handling, and user-defined
    functions.
    """

    # Allows Pydantic to work with custom classes like Receptor. (Pydantic v2
    # style; the class-based ``Config`` is deprecated.)
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @staticmethod
    def _load_yaml_params(path: str | Path) -> dict[str, Any]:
        """
        Load a YAML config file and return its contents as a flat dictionary.

        Top-level entries whose value is a mapping (sections) are flattened
        one level. Their keys are promoted to the top level as strings and the
        section name is discarded. A later key overwrites an earlier one with
        the same name. An empty file yields ``{}``.

        Raises
        ------
        ValueError
            If the YAML document's top level is not a mapping.
        """
        with Path(path).open(encoding="utf-8") as f:
            config = yaml.safe_load(f)
        # An empty YAML document parses to None.
        if config is None:
            return {}
        if not isinstance(config, dict):
            raise ValueError(
                f"Expected a mapping at the top level of {path}, "
                f"got {type(config).__name__}"
            )
        params: dict[str, Any] = {}
        for key, value in config.items():
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    params[str(subkey)] = subvalue
            else:
                params[key] = value
        return params

    @classmethod
    def from_path(cls, path: str | Path) -> Self:
        """
        Load STILT configuration from a YAML file.
        """
        params = cls._load_yaml_params(path)
        return cls(**params)

    @model_validator(mode="after")
    def _validate_base_config(self) -> Self:
        """Perform validation that depends on multiple fields."""
        # Check if there's anything to run (ModelParams enforces this too).
        if not self.run_trajec and not self.run_foot:
            raise ValueError("Nothing to do: set run_trajec or run_foot to True")
        # Grid parameters are required for footprints or met sub-gridding.
        # They are presumably declared on FootprintParams (not visible here).
        if self.run_foot or self.met_subgrid_enable:
            required_grid_params = ("xmn", "xmx", "xres", "ymn", "ymx")
            if any(getattr(self, arg) is None for arg in required_grid_params):
                raise ValueError(
                    "xmn, xmx, xres, ymn, and ymx must be specified when "
                    "met_subgrid_enable or run_foot is True"
                )
        return self

    @model_validator(mode="after")
    def _set_config_defaults(self) -> Self:
        """Set default values for configuration parameters."""
        # maxpar defaults to numpar when not provided.
        if self.maxpar is None:
            self.maxpar = self.numpar
        return self

    def _group_params(self, group: type[BaseModel]) -> dict[str, Any]:
        """Return this config's current values for every field declared on ``group``."""
        return {attr: getattr(self, attr) for attr in group.model_fields}

    def system_params(self) -> dict[str, Any]:
        """Values of the SystemParams fields."""
        return self._group_params(SystemParams)

    def met_params(self) -> dict[str, Any]:
        """Values of the MetParams fields."""
        return self._group_params(MetParams)

    def model_params(self) -> dict[str, Any]:
        """Values of the ModelParams fields."""
        return self._group_params(ModelParams)

    def transport_params(self) -> dict[str, Any]:
        """Values of the TransportParams fields."""
        return self._group_params(TransportParams)

    def error_params(self) -> dict[str, Any]:
        """Values of the ErrorParams fields."""
        return self._group_params(ErrorParams)

    def user_funcs(self) -> dict[str, Any]:
        """Values of the UserFuncParams fields."""
        return self._group_params(UserFuncParams)
class SimulationConfig(BaseConfig):
    """Configuration for a single STILT simulation (one receptor)."""

    receptor: Receptor

    @classmethod
    def from_path(cls, path: str | Path) -> Self:
        """
        Load a simulation config from a YAML file.

        The file is parsed as a ModelConfig, and its first receptor becomes
        this simulation's receptor.

        Raises
        ------
        ValueError
            If the file defines no receptors.
        """
        # Open simulation config like a model config
        config = ModelConfig.from_path(path)
        if not config.receptors:
            raise ValueError(f"No receptors found in configuration: {path}")
        params = config.model_dump()
        # SimulationConfig has a single `receptor`, not `receptors`.
        params.pop("receptors", None)
        return cls(receptor=config.receptors[0], **params)

    @model_validator(mode="after")
    def _validate_simulation_id(self) -> Self:
        """
        Default ``simulation_id`` to the receptor's id and require a string.

        This runs as a model (not field) validator for two reasons:
        - It needs the validated ``receptor``, which is declared after
          ``simulation_id``.
        - It must also run when ``simulation_id`` is left at its default.

        Raises
        ------
        TypeError
            If a non-empty, non-string ``simulation_id`` was given.
        """
        if not self.simulation_id:
            self.simulation_id = self.receptor.id
        elif not isinstance(self.simulation_id, str):
            raise TypeError("simulation_id must be a string")
        return self

    def to_model_config(self) -> "ModelConfig":
        """Wrap this simulation into a ModelConfig with a single receptor."""
        config = self.model_dump()
        receptor = config.pop("receptor")
        return ModelConfig(receptors=[receptor], **config)
class ModelConfig(BaseConfig):
    """Configuration for a STILT model run over one or more receptors."""

    receptors: list[Receptor]

    @classmethod
    def from_path(cls, path: str | Path) -> Self:
        """
        Load a model config from a YAML file.

        ``stilt_wd`` defaults to the directory containing the file.
        """
        params = cls._load_yaml_params(path)
        params.setdefault("stilt_wd", Path(path).parent)
        return cls(**params)

    @model_validator(mode="before")
    @classmethod
    def _load_receptors(cls, data: Any) -> Any:
        """
        Validates and loads receptors. If a path is provided, it loads
        receptors from the corresponding CSV file.

        A relative path is resolved against ``stilt_wd``. The input mapping
        is copied rather than mutated.

        Raises
        ------
        ValueError
            If the receptors path is relative and ``stilt_wd`` is not given.
        """
        if not isinstance(data, dict):
            # e.g. an already-built model instance; leave it to pydantic.
            return data
        receptors = data.get("receptors")
        if not isinstance(receptors, (str, Path)):
            return data
        # If the input is a path, load from the file.
        receptor_path = Path(receptors)
        if not receptor_path.is_absolute():
            stilt_wd = data.get("stilt_wd")
            if stilt_wd is None:
                raise ValueError(
                    "stilt_wd must be specified to resolve a relative receptors path"
                )
            receptor_path = Path(stilt_wd) / receptor_path
        return {**data, "receptors": Receptor.load_receptors_from_csv(receptor_path)}

    @model_validator(mode="after")
    def _validate_model_config(self) -> Self:
        """
        Validate ``simulation_id`` against the number of receptors.

        A single string is only allowed with at most one receptor. A list must
        provide exactly one id per receptor.
        """
        message = "Simulation ID must be specified for each receptor or be left blank."
        n_receptors = len(self.receptors)
        if isinstance(self.simulation_id, str) and n_receptors > 1:
            raise ValueError(message)
        if isinstance(self.simulation_id, list) and len(self.simulation_id) != n_receptors:
            raise ValueError(message)
        return self

    def to_file(self):
        """Write this configuration to disk. Not implemented yet."""
        # Planned: write receptor information to CSV, then write the config.
        raise NotImplementedError

    def build_simulation_configs(self) -> list[SimulationConfig]:
        """
        Build a list of SimulationConfig objects, one for each receptor.

        A list ``simulation_id`` is matched to receptors by position. A single
        id (or None) is applied to every receptor. When the id is None, it is
        left unset so SimulationConfig can apply its own default.
        """
        shared = self.model_dump()
        shared.pop("receptors")
        simulation_id = shared.pop("simulation_id")
        if isinstance(simulation_id, list):
            ids: list[str | None] = list(simulation_id)
        else:
            ids = [simulation_id] * len(self.receptors)
        configs = []
        # strict=True guards against a length mismatch slipping through.
        for receptor, sim_id in zip(self.receptors, ids, strict=True):
            overrides = {} if sim_id is None else {"simulation_id": sim_id}
            configs.append(SimulationConfig(receptor=receptor, **shared, **overrides))
        return configs