Source code for stilt.model

"""
Stochastic Time-Inverted Lagrangian Transport (STILT) Model.

A python implementation of the R-STILT model framework.
"""

import logging
import os
import tempfile
from collections.abc import Iterable
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING

from stilt.collections import (
    FootprintCollection,
    ReceptorCollection,
    SimulationCollection,
    TrajectoryCollection,
)
from stilt.config import (
    ModelConfig,
    RuntimeSettings,
    foot_names,
    resolve_runtime_settings,
)
from stilt.config.model import _config_or_kwargs
from stilt.errors import ConfigValidationError
from stilt.execution import (
    Executor,
    JobHandle,
    LocalHandle,
    SlurmExecutor,
    get_executor,
    sigterm_as_interrupt,
)
from stilt.index import IndexCounts, SimulationIndex
from stilt.index.factory import resolve_index
from stilt.meteorology import MetStream
from stilt.receptors import Receptor
from stilt.simulation import SimID
from stilt.storage import (
    ProjectFiles,
    ProjectLayout,
    Storage,
    make_store,
    project_slug,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from stilt.visualization import ModelPlotAccessor


@dataclass
class _RebuildOnCompleteHandle:
    """
    JobHandle wrapper that rebuilds the output index after the first wait.

    The wrapped handle is waited on, then the index is rebuilt, and the
    wrapper remembers that it has done so: later ``wait()`` calls return
    immediately without touching the inner handle or the index again.
    """

    # Handle being decorated; ``job_id`` and ``wait`` delegate to it.
    _inner: JobHandle
    # Index whose ``rebuild()`` is invoked once the inner wait returns.
    _index: SimulationIndex
    # Set only after both the inner wait and the rebuild have returned.
    _completed: bool = field(default=False, init=False)

    @property
    def job_id(self) -> str:
        """Job identifier reported by the wrapped handle."""
        return self._inner.job_id

    def wait(self) -> None:
        """Wait on the wrapped handle, then rebuild the index (first call only)."""
        if not self._completed:
            self._inner.wait()
            self._index.rebuild()
            # Flag is flipped last, so an exception raised above leaves the
            # wrapper un-completed and a later wait() will try again.
            self._completed = True


def _wrap_wait_with_rebuild(handle: JobHandle, index: SimulationIndex) -> JobHandle:
    """
    Decorate *handle* so its first completed ``wait()`` also rebuilds *index*.

    Parameters
    ----------
    handle : JobHandle
        Handle to delegate ``job_id`` and ``wait`` to.
    index : SimulationIndex
        Index to rebuild after the wrapped wait returns.

    Returns
    -------
    JobHandle
        A ``_RebuildOnCompleteHandle`` standing in for the original handle.
    """
    wrapped = _RebuildOnCompleteHandle(_inner=handle, _index=index)
    # The wrapper is returned where a JobHandle is declared, hence the ignore.
    return wrapped  # type: ignore[return-value]


class Model:
    """
    Science-facing STILT project interface.

    ``Model`` is the primary Python entry point for configuring a STILT
    project, running one-off simulations, and loading simulation results. It
    also owns the output project index used by batch, HPC, and cloud
    execution.

    Pull-mode workers and streaming consumers operate against the model's
    output index directly when the configured index supports claims.
    Claim-worker control is exposed primarily through the CLI and
    :func:`stilt.execution.pull_simulations` for advanced Python use.

    Parameters
    ----------
    project : str or Path or None, optional
        Local project root or cloud URI used to identify the model.
    receptors : Receptor or iterable or str or Path or None, optional
        In-memory receptors or a path to a receptor CSV.
    config : ModelConfig or None, optional
        In-memory project config. When omitted, config is loaded lazily from
        storage.
    output_dir : str or Path or None, optional
        Output root. Defaults to ``project``.
    compute_root : str or Path or None, optional
        Local parent directory where worker simulation directories are
        created.
    runtime : RuntimeSettings or None, optional
        Runtime-only deployment settings such as cache roots and DB URLs.
    **kwargs
        Forwarded to :class:`~stilt.config.ModelConfig` when *config* is not
        provided. Any valid ``ModelConfig`` field name is accepted (e.g.
        ``n_hours=-48``, ``numpar=500``). Mutually exclusive with *config* —
        passing both raises ``TypeError``.

    Attributes
    ----------
    config : ModelConfig
        Lazily loaded project configuration.
    receptors : ReceptorCollection
        Science-facing receptor accessor.
    simulations : SimulationCollection
        Registered simulation handles backed by the output index.
    mets : dict[str, MetStream]
        Named meteorology streams.
    trajectories : TrajectoryCollection
        Cross-simulation trajectory accessor.
    footprints : FootprintCollection
        Cross-simulation footprint accessor namespace.
    plot : ModelPlotAccessor
        Plotting namespace for model summaries and outputs.
    index : SimulationIndex
        Output simulation registry used by the coordinator and output queries.
    storage : Storage
        Output bootstrap and output storage facade.
    runtime : RuntimeSettings
        Runtime-only deployment settings.
    layout : ProjectLayout
        Resolved project/output refs (``project_dir``, ``output_dir``,
        ``project_root``, ``output_root``, ``is_cloud_project``,
        ``is_cloud_output``).
    """

    def __init__(
        self,
        project: str | Path | None = None,
        receptors: Receptor | Iterable | str | Path | None = None,
        config: ModelConfig | None = None,
        output_dir: str | Path | None = None,
        compute_root: str | Path | None = None,
        runtime: RuntimeSettings | None = None,
        **kwargs,
    ):
        # Runtime settings are not persisted in the project config, but they
        # may be needed to resolve the output index backend, so resolve them
        # first.
        self.runtime = resolve_runtime_settings(runtime)
        # The project layout resolver also handles output_dir defaulting and
        # cloud URI parsing.
        self.layout = ProjectLayout.resolve(project, output_dir)
        self.project = self._project_name(self.layout.project_ref)
        self.storage = Storage(
            project_dir=self.layout.project_dir,
            output_dir=self.layout.output_dir,
            store=make_store(self.layout.output_root, cache_dir=self.runtime.cache_dir),
            is_cloud_project=self.layout.is_cloud_project,
        )
        # Compute root is not persisted in the project config, but it may be
        # needed to resolve the met archive and worker sim dirs, so resolve it
        # early.
        self.compute_root = self._resolve_compute_root(compute_root)
        self._config = _config_or_kwargs(config, kwargs, ModelConfig)

        # Lazy caches for expensive properties.
        self._mets: dict[str, MetStream] | None = None
        # Raw receptor input as given by the caller: None, a Receptor, an
        # iterable of receptors, or a str/Path to a receptor CSV. It is wrapped
        # in a ReceptorCollection lazily by the ``receptors`` accessor.
        self._receptors = receptors
        self._index: SimulationIndex | None = None
        self._simulations: SimulationCollection | None = None
        self._trajectories: TrajectoryCollection | None = None
        self._footprints: FootprintCollection | None = None
        self._plot: ModelPlotAccessor | None = None

    def __repr__(self) -> str:
        """Compact developer-facing model representation."""
        return (
            f"Model(project={self.project!r}, output_root={self.layout.output_root!r})"
        )

    @property
    def index(self) -> SimulationIndex:
        """Output simulation registry for this model."""
        if self._index is None:
            self._index = resolve_index(
                None,
                output_root=self.layout.output_root,
                runtime=self.runtime,
                # Built-in backend is chosen by where the outputs live.
                builtin_backend="postgres" if self.layout.is_cloud_output else "sqlite",
            )
        return self._index

    def _resolve_compute_root(self, compute_root: str | Path | None) -> Path:
        """
        Return the parent directory under which worker sim dirs are created.

        Precedence: the explicit *compute_root* argument, then
        ``runtime.compute_root``, then the output project's ``by_id_dir`` for
        local outputs, and finally ``$TMPDIR`` (or the system temp dir) plus
        ``pystilt`` for cloud outputs.
        """
        if compute_root is not None:
            raw = os.path.expandvars(os.path.expanduser(str(compute_root)))
            return Path(raw).resolve()
        if self.runtime.compute_root is not None:
            return self.runtime.compute_root.expanduser().resolve()
        if not self.layout.is_cloud_output:
            return ProjectFiles(self.layout.output_dir).by_id_dir
        tmp_root = os.environ.get("TMPDIR") or tempfile.gettempdir()
        return Path(tmp_root) / "pystilt"

    def _project_name(self, project_str: str) -> str:
        """Return a human-readable project name for local paths and cloud URIs."""
        if self.layout.is_cloud_project and project_str:
            return project_slug(project_str)
        return self.layout.project_dir.name

    @property
    def config(self) -> ModelConfig:
        """
        Project :class:`ModelConfig`, loaded from ``config.yaml`` if not
        provided at construction.

        Returns
        -------
        ModelConfig
        """
        if self._config is None:
            self._config = self.storage.load_config()
        return self._config

    @property
    def receptors(self) -> ReceptorCollection:
        """
        Sequence-like receptor accessor for this project.

        Access by position (``model.receptors[0]``) or by receptor identifier
        (``model.receptors[sim_id.receptor_id]``).

        Returns
        -------
        ReceptorCollection
        """
        # ``_receptors`` holds the raw constructor input (or None after a
        # registration) until first access; wrap it exactly once.
        if not isinstance(self._receptors, ReceptorCollection):
            self._receptors = ReceptorCollection(
                self._receptors,
                storage=self.storage,
            )
        return self._receptors

    def register_pending(
        self,
        receptors: Iterable[Receptor] | None = None,
        *,
        scene_id: str | None = None,
    ) -> list[str]:
        """
        Persist model inputs and register one batch of pending work.

        This is the output registration boundary shared by local runs,
        queue-backed workers, and the CLI ``stilt register`` command.

        Registration is idempotent: calling this multiple times with the same
        receptors does not create duplicate index entries. The underlying
        index backend uses an upsert (``ON CONFLICT DO UPDATE``) so existing
        rows are updated in place rather than duplicated.

        Parameters
        ----------
        receptors : iterable of Receptor, optional
            Receptors to register. When omitted, the model's current receptor
            inputs are used.
        scene_id : str, optional
            Optional grouping label stored on all registered simulations.

        Returns
        -------
        list[str]
            Registered simulation identifiers.
        """
        recs = tuple(receptors) if receptors is not None else tuple(self.receptors)
        # Only the model's own receptor inputs can be published by source
        # path; explicitly passed receptors are always published as objects.
        source_path = self.receptors.source_path if receptors is None else None
        self.storage.publish_config(self.config)
        self.storage.publish_receptors(
            None if source_path is not None else list(recs),
            source_path=source_path,
        )
        # One simulation per (met stream, receptor) combination.
        pairs = tuple(
            (str(SimID.from_parts(met_name, receptor)), receptor)
            for met_name in self.mets
            for receptor in recs
        )
        self.index.register(
            list(pairs),
            footprint_names=foot_names(dict(self.config.footprints)),
            scene_id=scene_id,
        )
        # Registration updates the output receptor/index truth, so drop
        # cached accessors — next access rebuilds from that output surface.
        self._receptors = None
        self._simulations = None
        return [sim_id for sim_id, _ in pairs]

    @property
    def simulations(self) -> SimulationCollection:
        """
        Lazy simulation collection for registered simulations.

        Returns
        -------
        SimulationCollection
        """
        if self._simulations is None:
            params = self.config.to_stilt_params()
            self._simulations = SimulationCollection(
                self.layout.output_dir,
                params,
                self.mets,
                self.receptors,
                list(self.config.footprints),
                self.index,
                self.storage.store,
            )
        return self._simulations

    @property
    def plot(self) -> "ModelPlotAccessor":
        """
        Plotting namespace for this model (e.g. ``model.plot.availability()``).

        Returns
        -------
        ModelPlotAccessor
        """
        if self._plot is None:
            # Imported on first use; at module level the name is only
            # imported under TYPE_CHECKING.
            from stilt.visualization import ModelPlotAccessor

            self._plot = ModelPlotAccessor(self)
        return self._plot

    @property
    def trajectories(self) -> TrajectoryCollection:
        """Trajectory accessor for querying and loading simulation outputs."""
        if self._trajectories is None:
            self._trajectories = TrajectoryCollection(self)
        return self._trajectories

    @property
    def footprints(self) -> FootprintCollection:
        """Footprint accessor namespace for named footprint outputs."""
        if self._footprints is None:
            self._footprints = FootprintCollection(self)
        return self._footprints

    @property
    def mets(self) -> dict[str, MetStream]:
        """Named met streams, resolved from config."""
        if self._mets is None:
            self._mets = {}
            for name, config in self.config.mets.items():
                self._mets[name] = MetStream(
                    name,
                    directory=config.directory,
                    file_format=config.file_format,
                    file_tres=config.file_tres,
                    n_min=config.n_min,
                    source_type=config.source,
                    source_kwargs=config.source_kwargs,
                    backend=config.backend,
                    subgrid_enable=config.subgrid_enable,
                    subgrid_bounds=config.subgrid_bounds,
                    subgrid_buffer=config.subgrid_buffer,
                    subgrid_levels=config.subgrid_levels,
                    subgrid_dir=config.subgrid_dir,
                )
        return self._mets

    # -- Queries --------------------------------------------------------------

    def status(self, scene_id: str | None = None) -> IndexCounts:
        """Return cheap aggregate counts for the current project registry."""
        return self.index.counts(scene_id=scene_id)

    def scene_counts(self) -> dict[str, IndexCounts]:
        """Return grouped aggregate counts for each registered scene."""
        return self.index.scene_counts()

    # -- Execution ------------------------------------------------------------

    def run(
        self,
        executor: Executor | None = None,
        skip_existing: bool | None = None,
        rebuild: bool | None = None,
        wait: bool = True,
    ) -> JobHandle:
        """
        Register pending work and start workers to drain it.

        When ``config.footprints`` contains one or more footprint
        configurations, workers auto-run HYSPLIT trajectories as needed and
        then compute footprints in a single pass. When no footprint configs
        are defined, only HYSPLIT trajectories are dispatched.

        Workers either drain a claim-capable project index via
        :func:`~stilt.execution.pull_simulations` or consume immutable chunk
        shards via :func:`~stilt.execution.push_simulations`; the coordinator
        registers simulations and starts workers, then optionally blocks.

        Parameters
        ----------
        executor : Executor, optional
            Override the executor resolved from ``config.execution``.
        skip_existing : bool or None, optional
            Skip simulations that already have output. ``None`` (default)
            reads the value from ``config.skip_existing``.
        rebuild : bool or None, optional
            Rebuild the output index from outputs before planning work.
            ``None`` (default) uses auto mode: rebuild when skip-existing is
            enabled, otherwise skip the pre-run rebuild.
        wait : bool, optional
            If ``True`` (default), block until all workers finish. For
            push-dispatch executors the output index is rebuilt from outputs
            once the wait completes. If ``False``, return the
            :class:`JobHandle` immediately — suitable for fire-and-forget
            Slurm runs.

        Returns
        -------
        JobHandle
            A handle you can call ``.wait()`` on later, or ignore.

        Raises
        ------
        ConfigValidationError
            If a Slurm executor is used with a cloud project or cloud output
            root.

        Notes
        -----
        ``Model.run()`` is the main science-facing execution path for local or
        notebook use. For claim-based service workflows, use
        :func:`stilt.execution.pull_simulations` against a model configured
        with a PostgreSQL-backed index, or use the CLI ``stilt pull-worker``
        and ``stilt serve`` commands.
        """
        index = self.index
        # Execution mutates output index state, so any cached simulation view
        # must be rebuilt after each coordinator run.
        self._simulations = None

        resolved_skip = (
            skip_existing if skip_existing is not None else self.config.skip_existing
        )
        resolved_rebuild = resolved_skip if rebuild is None else rebuild
        # Explicit None check (not truthiness) so an executor the caller
        # passed is always honoured, matching how the other optional
        # arguments are resolved above.
        resolved_executor = (
            executor
            if executor is not None
            else get_executor(self.config.execution or {})
        )
        if isinstance(resolved_executor, SlurmExecutor) and (
            self.layout.is_cloud_project or self.layout.is_cloud_output
        ):
            raise ConfigValidationError(
                "Slurm execution currently requires both project and output roots "
                "to be local paths."
            )
        dispatch = resolved_executor.dispatch
        foot_configs = dict(self.config.footprints)

        sim_ids = self.register_pending()
        if not sim_ids:
            logger.info("run: no receptors configured — nothing to do")
            return LocalHandle()

        if resolved_rebuild:
            logger.info("run: rebuilding output index before planning")
            index.rebuild()
        # When not skipping existing work, existing outputs are cleared too.
        index.reset_to_pending(sim_ids, clear_outputs=not resolved_skip)

        pending = index.pending_trajectories()
        if not pending:
            logger.info("run: all simulations already complete — nothing to do")
            return LocalHandle()

        names = list(foot_configs.keys())
        logger.info(
            "run(%s): starting %s workers for %d simulations",
            ", ".join(names) if names else "trajectories",
            dispatch,
            len(pending),
        )
        handle = resolved_executor.start(
            pending,
            project=self.layout.output_root,
            output_dir=str(self.layout.output_dir),
            compute_root=str(self.compute_root),
            skip_existing=resolved_skip,
        )
        if dispatch == "push":
            # Only push-mode handles are wrapped so that the first completed
            # wait() also rebuilds the output index.
            handle = _wrap_wait_with_rebuild(handle, index)

        if wait:
            logger.info("run: waiting for workers to finish...")
            try:
                with sigterm_as_interrupt():
                    handle.wait()
            except KeyboardInterrupt:
                logger.warning("run interrupted")
                raise
        return handle