# Module: stilt.observations.weighting

"""Generic particle-weighting interfaces for observation workflows."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, cast

import numpy as np
import pandas as pd

from .operators import VerticalOperator

if TYPE_CHECKING:
    from .observation import Observation


# Modes that require scaling by n_particles.
#
# Consumed by _apply_vertical_operator_impl below. The footprint calculator
# divides the aggregated particle influence by n_particles, which means that
# fractional pressure-weight profiles (PWF sums to 1 across the column) must
# be rescaled to preserve the column-integrated signal. AK-only modes do not
# need this correction because AK_norm is already dimensionless and
# approximately normalized without a per-particle factor.
_PWF_MODES: frozenset[str] = frozenset({"pwf", "ak_pwf", "integration", "tccon"})


@dataclass(frozen=True, slots=True)
class WeightingContext:
    """
    Context passed into a particle-weighting model.

    This is intentionally small for the first pass. It gives the weighting interface
    a place to carry observation/operator metadata today while leaving room for
    later chemistry/lifetime extensions without coupling them to transport code.

    NOTE(review): frozen dataclasses auto-generate ``__hash__`` over all
    fields, and ``metadata`` is a dict — hashing an instance would fail at
    runtime. Confirm instances are never used as dict keys / set members.
    """

    # Particle column used as the vertical coordinate ('xhgt' release
    # heights by default; 'pres' for pressure-based interpolation).
    coordinate: str = "xhgt"
    # Originating observation, if any (import is annotation-only above).
    observation: Observation | None = None
    # Fallback operator consulted by VerticalOperatorWeighting when none was
    # supplied at construction time.
    operator: VerticalOperator | None = None
    # Free-form extras; default_factory avoids a shared mutable default.
    metadata: dict[str, Any] = field(default_factory=dict)


class WeightingModel(Protocol):
    """Behavioral interface for models that reweight particle sensitivities.

    This is a structural (duck-typed) protocol: any object exposing a
    matching ``apply`` method satisfies it. Implementations in this module
    return a new/copied DataFrame rather than mutating the input in place.
    """

    def apply(
        self,
        particles: pd.DataFrame,
        *,
        context: WeightingContext | None = None,
    ) -> pd.DataFrame:
        """Return particles after applying a weighting model.

        Parameters
        ----------
        particles:
            Particle DataFrame to reweight.
        context:
            Optional observation/operator metadata for the model.
        """
        ...


class NoOpWeighting:
    """Identity weighting model.

    Hands back a defensive copy of the input particles, leaving every value
    untouched. Useful as a default wherever a :class:`WeightingModel` is
    expected but no reweighting should occur.
    """

    def apply(
        self,
        particles: pd.DataFrame,
        *,
        context: WeightingContext | None = None,
    ) -> pd.DataFrame:
        """Return a copy of ``particles`` with no values modified."""
        # Copy so callers can freely mutate the result without aliasing the
        # input — matching the no-mutation convention of this module.
        unchanged = particles.copy()
        return unchanged


class VerticalOperatorWeighting:
    """Apply a :class:`VerticalOperator` to particle ``foot`` values."""

    def __init__(self, operator: VerticalOperator | None = None) -> None:
        # The operator may instead be supplied per-call via
        # WeightingContext.operator; see apply().
        self._operator = operator

    def apply(
        self,
        particles: pd.DataFrame,
        *,
        context: WeightingContext | None = None,
    ) -> pd.DataFrame:
        """Apply a vertical operator to the particle ``foot`` values.

        The operator given at construction time takes precedence; otherwise
        ``context.operator`` is used. The interpolation coordinate comes from
        ``context.coordinate`` (default ``'xhgt'``).

        Raises
        ------
        ValueError
            If no operator is available from either source.
        """
        # Explicit `is None` checks rather than `or`: an operator object that
        # happens to evaluate as falsy must still be honored, not silently
        # replaced by the context fallback.
        operator = self._operator
        if operator is None and context is not None:
            operator = context.operator
        if operator is None:
            raise ValueError(
                "VerticalOperatorWeighting requires a VerticalOperator either "
                "at construction time or in WeightingContext.operator."
            )
        coordinate = context.coordinate if context is not None else "xhgt"
        return _apply_vertical_operator_impl(
            particles,
            operator,
            coordinate=coordinate,
        )
def apply_weighting(
    particles: pd.DataFrame,
    weighting: WeightingModel,
    *,
    context: WeightingContext | None = None,
) -> pd.DataFrame:
    """Apply a weighting model to a particle DataFrame.

    Thin functional entry point that dispatches to ``weighting.apply``.
    """
    return weighting.apply(particles, context=context)


def _apply_vertical_operator_impl(
    particles: pd.DataFrame,
    operator: VerticalOperator,
    *,
    coordinate: str = "xhgt",
) -> pd.DataFrame:
    """Internal implementation used by the vertical-operator weighting interface.

    Parameters
    ----------
    particles:
        Particle DataFrame. Must carry a ``foot`` column, the ``coordinate``
        column, and — for pressure-weighted modes — an ``indx`` column.
    operator:
        Vertical operator providing ``mode``, ``levels`` and ``values``.
    coordinate:
        Column used as the interpolation coordinate (``'xhgt'`` release
        heights by default; ``'pres'`` for pressure-based interpolation).

    Returns
    -------
    pd.DataFrame
        Copy of ``particles`` with ``foot`` multiplied by the interpolated
        weights; the pre-weighting values are kept in ``foot_before_weight``.

    Raises
    ------
    ValueError
        On missing required columns or an inconsistent operator definition.
    """
    # NOTE(review): mode == "none" returns the input unmodified and uncopied;
    # in particular an earlier weighting (foot_before_weight) is NOT undone
    # here — confirm that is the intended semantics for "none".
    if operator.mode == "none":
        return particles

    p = particles.copy()

    # Re-weighting guard: restore the original foot before applying new weights.
    if "foot_before_weight" in p.columns:
        p["foot"] = p["foot_before_weight"]
        p = p.drop(columns=["foot_before_weight"])

    # Uniform weighting is a no-op beyond the restore above.
    if operator.mode == "uniform":
        return p

    if coordinate not in p.columns:
        raise ValueError(
            f"Particle DataFrame has no column {coordinate!r}. "
            "Assign release heights ('xhgt') before applying a vertical operator, "
            "or pass coordinate='pres' for pressure-based interpolation."
        )

    if not operator.levels or not operator.values:
        raise ValueError(
            "VerticalOperator.levels and .values must both be non-empty "
            f"for mode={operator.mode!r}."
        )

    levels = np.asarray(operator.levels, dtype=float)
    values = np.asarray(operator.values, dtype=float)
    if len(levels) != len(values):
        raise ValueError(
            f"VerticalOperator.levels ({len(levels)}) and .values "
            f"({len(values)}) must have the same length."
        )

    # np.interp requires ascending sample points; sort both arrays together.
    sort_idx = np.argsort(levels)
    levels = levels[sort_idx]
    values = values[sort_idx]

    coords = p[coordinate].to_numpy(dtype=float)
    # Particles outside the profile are clamped to the endpoint values.
    weights = np.interp(coords, levels, values, left=values[0], right=values[-1])

    if operator.mode in _PWF_MODES:
        # Fractional PWF profiles must be rescaled by the particle count so
        # the footprint calculator's division by n_particles preserves the
        # column-integrated signal (see _PWF_MODES).
        if "indx" not in p.columns:
            raise ValueError(
                "Particle DataFrame has no 'indx' column; pressure-weighted "
                f"modes (mode={operator.mode!r}) need it to count particles."
            )
        n_particles = int(cast(pd.Series, p["indx"]).nunique())
        weights = weights * n_particles

    if "foot" not in p.columns:
        raise ValueError(
            "Particle DataFrame has no 'foot' column; cannot apply a "
            "vertical operator."
        )
    p["foot_before_weight"] = p["foot"]
    p["foot"] = p["foot"] * weights
    return p


__all__ = [
    "NoOpWeighting",
    "VerticalOperatorWeighting",
    "WeightingContext",
    "WeightingModel",
    "apply_weighting",
]