# Source code for stilt.index.protocol

"""
Output simulation index types for PYSTILT.

The index layer tracks one row per simulation and sits between:

- queue registration APIs such as ``Model.register_pending()``,
- worker result recording,
- read paths such as ``model.status()`` and collection-level queries.

`OutputSummary` is the light output presence summary for one simulation.
`IndexCounts` is the cheap aggregate view over a whole index or one scene.
"""

from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol, runtime_checkable

from stilt.receptors import Receptor
from stilt.simulation import SimID

if TYPE_CHECKING:
    from stilt.execution import SimulationResult

# Footprint status strings that ``OutputSummary.footprint_complete`` accepts
# as finished.
COMPLETE_FOOTPRINT_STATUSES = frozenset(("complete", "complete-empty"))


def _normalize_footprint_names(
    footprint_names: list[str] | None = None,
) -> list[str]:
    """Return sorted unique footprint names for output index storage."""
    return sorted(set(footprint_names or []))


def _normalize_registration_pairs(
    pairs: str | tuple[str, Receptor] | list[tuple[str, Receptor]],
    receptor: Receptor | None = None,
) -> list[tuple[str, Receptor]]:
    """Normalize one-or-many register inputs to a list of pairs."""
    if isinstance(pairs, str):
        if receptor is None:
            raise TypeError("register() requires a receptor when called with sim_id")
        return [(pairs, receptor)]
    if isinstance(pairs, tuple):
        if receptor is not None:
            raise TypeError(
                "register() accepts either (sim_id, receptor) or a list of pairs"
            )
        sim_id, receptor = pairs
        return [(sim_id, receptor)]
    if receptor is not None:
        raise TypeError(
            "register() accepts either sim_id plus receptor or a list of pairs"
        )
    return list(pairs)


[docs] @dataclass(frozen=True, slots=True) class OutputSummary: """Lightweight output presence summary for one simulation.""" traj_present: bool = False error_traj_present: bool = False log_present: bool = False footprints: dict[str, str] = field(default_factory=dict)
[docs] def footprint_complete(self, name: str) -> bool: """Return whether one named footprint has reached a terminal state.""" return self.footprints.get(name) in COMPLETE_FOOTPRINT_STATUSES
[docs] def footprints_complete(self, names: Iterable[str]) -> bool: """Return whether all requested footprints have reached terminal states.""" return all(self.footprint_complete(name) for name in names)
[docs] def outputs_complete(self, footprint_names: Iterable[str]) -> bool: """Return whether trajectory and all requested footprints are complete.""" return self.traj_present and self.footprints_complete(footprint_names)
[docs] def pending_footprints(self, names: Iterable[str]) -> list[str]: """Return configured footprint names that still need work.""" return [name for name in names if not self.footprint_complete(name)]
[docs] def needs_work( self, footprint_names: Iterable[str], *, skip_existing: bool, ) -> bool: """Return whether this simulation still requires worker execution.""" if not skip_existing: return True targets = list(footprint_names) if targets: return not self.outputs_complete(targets) return not self.traj_present
[docs] @dataclass(frozen=True, slots=True) class IndexCounts: """Cheap aggregate counts for one output simulation index view.""" total: int = 0 completed: int = 0 running: int = 0 pending: int = 0 failed: int = 0
@runtime_checkable
class SimulationIndex(Protocol):
    """Structural interface of the output simulation registry.

    This is the surface the model, the CLI, and workers program against.
    Every method here is a stub; concrete index backends supply the behavior.
    """

    def record(self, result: SimulationResult) -> None:
        """Persist one finished worker result in the durable index."""
        ...

    def register(
        self,
        pairs: str | tuple[str, Receptor] | list[tuple[str, Receptor]],
        receptor: Receptor | None = None,
        footprint_names: list[str] | None = None,
        scene_id: str | None = None,
    ) -> None:
        """Add one simulation, or a batch of them, as known index rows."""
        ...

    def sim_ids(self) -> list[str]:
        """List every registered simulation id in a stable order."""
        ...

    def has(self, sim_id: SimID | str) -> bool:
        """Tell whether ``sim_id`` has already been registered."""
        ...

    def count(self) -> int:
        """Give the total number of registered simulation rows."""
        ...

    def counts(self, scene_id: str | None = None) -> IndexCounts:
        """Give aggregate queue counts for the whole index or one scene."""
        ...

    def scene_counts(self) -> dict[str, IndexCounts]:
        """Give aggregate counts keyed by each non-null scene id."""
        ...

    def receptors_for(self, sim_ids: list[str]) -> dict[str, Receptor]:
        """Look up the receptors of the requested rows, keyed by sim id."""
        ...

    def reset_to_pending(
        self,
        sim_ids: list[str],
        *,
        clear_outputs: bool = False,
    ) -> None:
        """Put the matching rows that are not running back into pending."""
        ...

    def pending_trajectories(self) -> list[str]:
        """List the simulation ids that still have trajectory work pending."""
        ...

    def summaries(
        self,
        sim_ids: list[str] | None = None,
    ) -> dict[str, OutputSummary]:
        """Give output summaries for every row or for a requested subset."""
        ...

    def rebuild(self) -> None:
        """Recreate the output index rows by rescanning the outputs."""
        ...


# Names exported by a star-import of this module.
__all__ = [
    "OutputSummary",
    "COMPLETE_FOOTPRINT_STATUSES",
    "IndexCounts",
    "SimulationIndex",
]