# Source code for arlmet.xarray._vertical

"""xarray-aware vertical coordinate helpers for ARL datasets."""

from __future__ import annotations

from collections.abc import Hashable
from typing import Any

import numpy as np
import xarray as xr

from arlmet.vertical import (
    HybridAxis,
    PressureAxis,
    SigmaAxis,
    TerrainAxis,
    hypsometric_z_agl,
)

# Public names exported by ``from ... import *``; the ``_``-prefixed helpers
# defined in this module are not exported.
__all__ = ["pressure", "z_agl", "z_msl"]


def _hypsometric_z_agl(
    p_levels: xr.DataArray,
    p_surface: xr.DataArray,
    temp: xr.DataArray,
) -> xr.DataArray:
    """
    Height AGL (m) at each level via the hypsometric equation.

    Thin xarray wrapper around :func:`arlmet.vertical.hypsometric_z_agl`.

    Parameters
    ----------
    p_levels : xr.DataArray
        Pressure at each level, ordered from high to low pressure (surface to
        top). For flag=2 this is 1D ``(level,)`` and is broadcast to match
        ``temp``; for flags 1/4 it has the same dimensions as ``temp``.
    p_surface : xr.DataArray
        Surface pressure.
    temp : xr.DataArray
        Temperature on the levels; must have a ``level`` dimension.

    Returns
    -------
    xr.DataArray
        float32 heights with ``temp``'s dims and coords, and
        height-above-ground-level attrs.

    Raises
    ------
    ValueError
        If ``temp`` has no ``level`` dimension.
    """
    if "level" not in temp.dims:
        raise ValueError("'TEMP' must have a 'level' dimension.")

    # ``.values`` drops dimension names, so put the inputs in ``temp``'s axis
    # order before handing raw arrays to the numpy routine. Without this, a
    # variable stored with a different dim order would be silently
    # misaligned with ``temp``. Only reorder when the dim sets match, so 1D
    # ``(level,)`` pressures are passed through untouched.
    if set(p_levels.dims) == set(temp.dims):
        p_levels = p_levels.transpose(*temp.dims)
    column_dims = [dim for dim in temp.dims if dim != "level"]
    if set(p_surface.dims) == set(column_dims):
        p_surface = p_surface.transpose(*column_dims)

    level_ax = temp.dims.index("level")
    z_vals = hypsometric_z_agl(
        p_levels.values,
        p_surface.values,
        temp.values,
        level_axis=level_ax,
    )

    return xr.DataArray(
        z_vals.astype(np.float32),
        dims=temp.dims,
        coords=temp.coords,
        attrs={
            "units": "m",
            "long_name": "height above ground level",
            "standard_name": "height",
        },
    )


def _sigma_hybrid_pressure(
    ds: xr.Dataset, axis: SigmaAxis | HybridAxis
) -> xr.DataArray:
    """
    Spatially varying pressure (hPa) on the dataset's sigma/hybrid levels.

    Surface pressure (``PRSS``) is flattened to one value per column and
    passed to ``axis.to_pressure``. Only the output columns selected by the
    dataset's ``level`` coordinate are kept.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset carrying ``PRSS`` and a ``level`` coordinate.
    axis : SigmaAxis | HybridAxis
        Vertical axis that converts surface pressure to level pressures.

    Returns
    -------
    xr.DataArray
        float32 array. Its dims are ``PRSS``'s dims with ``level`` inserted
        just before the trailing horizontal-grid dims (as counted by
        ``ds.arl.grid.dims``). It carries the dataset's coords for those dims,
        plus ``lat``/``lon`` when present.

    Raises
    ------
    ValueError
        If ``PRSS`` or the ``level`` coordinate is missing.
    """
    if "PRSS" not in ds:
        raise ValueError(
            "Sigma/hybrid pressure conversion requires 'PRSS' (surface pressure) in dataset."
        )
    if "level" not in ds.coords:
        raise ValueError("Dataset has no 'level' coordinate.")

    surface = ds["PRSS"]
    # NOTE(review): level values are used directly as column indices into the
    # ``to_pressure`` output, which presumes they are 0-based positions in the
    # axis's full level list. Confirm against the reader.
    level_idx = ds.coords["level"].values.astype(int).tolist()
    n_grid = len(ds.arl.grid.dims)

    surface_vals = surface.values
    surface_shape = surface_vals.shape
    # (n_columns, n_axis_levels) -> keep only the dataset's levels.
    per_column = axis.to_pressure(surface_pressure=surface_vals.reshape(-1))
    per_column = per_column[:, level_idx]

    # Restore the original layout with levels last, then slide the level axis
    # in front of the horizontal-grid axes.
    insert_at = len(surface_shape) - n_grid
    field = per_column.reshape(*surface_shape, len(level_idx))
    field = np.moveaxis(field, -1, insert_at)
    out_dims = [*surface.dims[:insert_at], "level", *surface.dims[insert_at:]]

    coords: dict[Hashable, Any] = {}
    for dim in out_dims:
        if dim in ds.coords:
            coords[dim] = ds.coords[dim]
    for name in ("lat", "lon"):
        if name not in coords and name in ds.coords:
            coords[name] = ds.coords[name]

    return xr.DataArray(
        field.astype(np.float32),
        dims=out_dims,
        coords=coords,
        attrs={
            "units": "hPa",
            "long_name": "air pressure",
            "standard_name": "air_pressure",
        },
    )


def pressure(ds: xr.Dataset) -> xr.DataArray:
    """
    Pressure (hPa) at each vertical level of an ARL dataset.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset from :func:`arlmet.open_dataset`.

    Returns
    -------
    xr.DataArray
        Depends on ``ds.arl.vertical_axis``:

        - **flag=2** (pressure levels): the stored 1D ``pressure`` non-dim
          coord ``(level,)``.
        - **flag=1/4** (sigma/hybrid): a spatially varying field computed
          from surface pressure (``PRSS``). Its dims are ``PRSS``'s with
          ``level`` inserted before the horizontal grid dims, e.g.
          ``(time, level, y, x)``.
        - **flag=3** (terrain-following): the ``PRES`` data variable (3D
          pressure field in hPa), which terrain-following ARL files typically
          store.

    Raises
    ------
    ValueError
        If the coordinate or variable required for the axis type is absent.
    NotImplementedError
        If the vertical axis is of any other type.
    """
    axis = ds.arl.vertical_axis

    if isinstance(axis, PressureAxis):
        if "pressure" in ds.coords:
            return ds.coords["pressure"]
        raise ValueError("Dataset has no 'pressure' coordinate.")

    if isinstance(axis, (SigmaAxis, HybridAxis)):
        return _sigma_hybrid_pressure(ds, axis)

    if isinstance(axis, TerrainAxis):
        if "PRES" in ds:
            return ds["PRES"]
        raise ValueError(
            "pressure() for terrain-following (flag=3) files requires a 'PRES' "
            "(3D pressure) variable in the dataset."
        )

    raise NotImplementedError(
        f"pressure() is not supported for {type(axis).__name__}."
    )
def z_agl(ds: xr.Dataset) -> xr.DataArray:
    """
    Height above ground level (m) at each level.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset from :func:`arlmet.open_dataset`.

    Returns
    -------
    xr.DataArray
        - **flag=2**: ``HGTS - SHGT`` (geopotential height minus terrain).
          Requires ``HGTS`` and ``SHGT`` in the dataset. The result carries
          height-above-ground-level attrs.
        - **flag=1/4**: hypsometric integration from surface pressure and
          temperature. Requires ``PRSS`` and ``TEMP``.
        - **flag=3**: the stored 1D ``height`` non-dim coord ``(level,)``;
          terrain-following levels are heights AGL.

    Raises
    ------
    ValueError
        If a variable or coordinate required for the file's vertical axis is
        missing.
    NotImplementedError
        If the vertical axis is of any other type.
    """
    vaxis = ds.arl.vertical_axis
    if isinstance(vaxis, PressureAxis):
        for name in ("HGTS", "SHGT"):
            if name not in ds:
                raise ValueError(
                    f"z_agl() for pressure-level (flag=2) files requires '{name}' in dataset."
                )
        result = ds["HGTS"] - ds["SHGT"]
        # Depending on xarray's ``keep_attrs`` setting, the difference has
        # either no attrs or HGTS's (which describe geopotential height).
        # Label it the same way the hypsometric branch does.
        result.attrs = {
            "units": "m",
            "long_name": "height above ground level",
            "standard_name": "height",
        }
        return result
    if isinstance(vaxis, (SigmaAxis, HybridAxis)):
        for name in ("PRSS", "TEMP"):
            if name not in ds:
                raise ValueError(
                    f"z_agl() for {vaxis.coord_system} (flag={vaxis.flag}) files "
                    f"requires '{name}' in dataset."
                )
        p_levels = pressure(ds)
        return _hypsometric_z_agl(p_levels, ds["PRSS"], ds["TEMP"])
    if isinstance(vaxis, TerrainAxis):
        if "height" not in ds.coords:
            # Only the stored ``height`` coord is consulted for flag=3, so the
            # message names only that.
            raise ValueError("Dataset has no 'height' coordinate.")
        return ds.coords["height"]
    raise NotImplementedError(f"z_agl() is not supported for {type(vaxis).__name__}.")
def z_msl(ds: xr.Dataset) -> xr.DataArray:
    """
    Height above mean sea level (m) at each level.

    Requires ``SHGT`` (surface terrain height in meters) in the dataset, plus
    whatever :func:`z_agl` needs for the file's vertical coordinate.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset from :func:`arlmet.open_dataset`.

    Returns
    -------
    xr.DataArray
        ``z_agl(ds) + ds["SHGT"]``, with its attrs replaced by
        height-above-mean-sea-level metadata (including the CF standard name
        ``altitude``).

    Raises
    ------
    ValueError
        If ``SHGT`` is missing, or as raised by :func:`z_agl`.
    NotImplementedError
        As raised by :func:`z_agl` for unsupported vertical axes.

    Notes
    -----
    xarray orders the dims of a broadcast result by the first operand. For
    terrain-following (flag=3) files ``z_agl`` is 1D ``(level,)``, so
    ``level`` is the leading dim of the result.
    """
    if "SHGT" not in ds:
        raise ValueError(
            "z_msl() requires 'SHGT' (terrain height in meters) in dataset."
        )
    result = z_agl(ds) + ds["SHGT"]
    # Replace (not update) attrs: anything inherited from the operands would
    # describe height AGL or terrain, not height above MSL.
    result.attrs = {
        "units": "m",
        "long_name": "height above mean sea level",
        "standard_name": "altitude",
    }
    return result