Source code for arlmet.xarray.dataset

from __future__ import annotations

import os
import warnings
from collections import defaultdict
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
import xarray as xr
from xarray.core import indexing

from arlmet._time import ensure_timestamp
from arlmet.subset import normalize_levels, resolve_window, select_records
from arlmet.vertical import VerticalAxis

from ._backend import ArlVariableArray
from ._coords import (
    _extract_dataset_forecast_hours,
    _extract_dataset_vertical_axis,
    add_cf_spatial_attrs,
    grid_from_coord,
    grid_to_coord,
    physical_coord_for_vaxis,
)

if TYPE_CHECKING:
    from arlmet.file import File
    from arlmet.record import DataRecord


def _build_dataset_from_file(
    met: File,
    *,
    drop_variables: Sequence[str] | None = None,
    bbox: tuple[float, float, float, float] | None = None,
    levels: list[int] | tuple[int, ...] | None = None,
) -> xr.Dataset:
    drop_variables_set: set[str] = set(drop_variables or [])
    window = resolve_window(met, bbox)
    read_window = None if bbox is None else window
    selected_grid = met.grid if bbox is None else met.grid.subset(window)
    requested_levels = (
        None if levels is None else normalize_levels(met.vertical_axis, levels)
    )
    requested_level_set = None if requested_levels is None else set(requested_levels)

    selected_recordsets: list[tuple[pd.Timestamp, int | None, list[DataRecord]]] = []
    present_levels: dict[int, None] = {}

    for time in met.times:
        recordset = met[time]
        selected_records = [
            record
            for record in select_records(recordset.records, levels=requested_level_set)
            if record.variable not in drop_variables_set
        ]
        if not selected_records:
            continue

        for record in selected_records:
            present_levels.setdefault(record.level, None)
        selected_recordsets.append(
            (ensure_timestamp(time), recordset.forecast, selected_records)
        )

    selected_levels = (
        list(requested_levels) if requested_levels is not None else list(present_levels)
    )
    level_pos = {level: pos for pos, level in enumerate(selected_levels)}

    ds = xr.Dataset(coords=selected_grid.calculate_coords())
    ds = add_cf_spatial_attrs(ds)
    # Integer 1-N level coord; surface (index 0) is handled separately below
    ds = ds.assign_coords(
        level=("level", list(selected_levels), {"standard_name": "model_level_number"})
    )

    times: list[pd.Timestamp] = []
    index_forecasts: list[int] = []
    records_by_variable: dict[str, dict[tuple[int, int], Any]] = defaultdict(dict)

    for time_pos, (time, _index_forecast, selected_records) in enumerate(
        selected_recordsets
    ):
        times.append(time)
        index_forecasts.append(0 if _index_forecast is None else int(_index_forecast))
        for record in selected_records:
            records_by_variable[record.variable][
                (time_pos, level_pos[record.level])
            ] = record

    if times:
        ds = ds.assign_coords(
            time=xr.Variable(
                "time", times, attrs={"standard_name": "time", "axis": "T"}
            )
        )
        ds["forecast_hour"] = xr.Variable(
            "time",
            [int(value) for value in index_forecasts],
            attrs={
                "long_name": "ARL index record forecast hour",
                "description": (
                    "Forecast hour stored in the INDX record for each time step. "
                    "Individual variable record forecasts may differ and are not represented "
                    "in this Dataset API."
                ),
                "units": "hours",
            },
        )

    shape = (len(times), len(selected_levels), selected_grid.ny, selected_grid.nx)
    dims = ("time", "level", *selected_grid.dims)
    for name in sorted(records_by_variable):
        backend = ArlVariableArray(
            records=records_by_variable[name],
            shape=shape,
            window=read_window,
        )
        ds[name] = xr.Variable(dims, indexing.LazilyIndexedArray(backend))

    sfc_pos = level_pos.get(0)
    if sfc_pos is not None:
        sfc_only_vars = {
            name
            for name, recs in records_by_variable.items()
            if all(z == sfc_pos for (_, z) in recs)
        }
        for name in sfc_only_vars:
            ds[name] = ds[name].isel(level=sfc_pos).drop_vars("level")

        upper_level_indices = [i for i in range(len(selected_levels)) if i != sfc_pos]
        if upper_level_indices:
            ds = ds.isel(level=upper_level_indices)
        else:
            ds = ds.drop_dims("level")

    ds = ds.assign_coords(arl_grid=grid_to_coord(selected_grid))
    ds.attrs["source"] = met.source
    ds.attrs["vertical_flag"] = met.vertical_axis.flag
    if "level" in ds.coords:
        # Add physical non-dim coord for upper-air levels
        upper_level_ints = [int(v) for v in ds.coords["level"].values.tolist()]
        try:
            phys_name, phys_var = physical_coord_for_vaxis(
                met.vertical_axis, upper_level_ints
            )
            ds = ds.assign_coords({phys_name: phys_var})
        except NotImplementedError:
            warnings.warn(
                f"Physical coordinate for vertical_flag={met.vertical_axis.flag} is not implemented; "
                "the 'level' coordinate will contain raw ARL level indices only.",
                stacklevel=2,
            )
    return ds


[docs] def open_dataset( filename_or_obj: str | os.PathLike[str], drop_variables: Sequence[str] | None = None, bbox: tuple[float, float, float, float] | None = None, levels: list[int] | tuple[int, ...] | None = None, ) -> xr.Dataset: """ Open an ARL meteorology file as an xarray Dataset. Surface-only variables (e.g. ``SHGT``, ``T02M``, ``PRSS``) have dimensions ``(time, lat, lon)`` with no ``level`` dimension. Upper-air variables (e.g. ``UWND``, ``VWND``, ``TEMP``) have dimensions ``(time, level, lat, lon)``. There is no NaN padding. ARL metadata is accessible via the ``.arl`` accessor:: ds = arlmet.open_dataset("met.arl") ds.arl.grid # Grid object (survives isel/sel) ds.arl.vertical_axis # VerticalAxis ds.isel(level=0) # selects level 0 for upper vars; sfc vars unchanged Parameters ---------- filename_or_obj : path-like Path to the ARL file. drop_variables : iterable of str, optional Variable names to omit from the resulting dataset. bbox : tuple[float, float, float, float], optional Geographic bounding box ``(west, south, east, north)`` in degrees. levels : list[int] or tuple[int, ...], optional ARL level indices to keep. Returns ------- xarray.Dataset """ from arlmet.file import File with File(filename_or_obj) as met: return met.to_dataset( drop_variables=drop_variables, bbox=bbox, levels=levels, )
[docs] def write_dataset( ds: xr.Dataset, filename_or_obj: str | os.PathLike[str], *, vertical_axis: VerticalAxis | None = None, ) -> None: """ Write a simple flat Dataset representation to an ARL file. This writer supports the common-case Dataset contract returned by ``open_dataset()``: - surface variables use dims ``(time, y, x)`` or ``(time, lat, lon)`` - upper-air variables use dims ``(time, level, y, x)`` or ``(time, level, lat, lon)`` - all upper-air variables share the same ``level`` coordinate Per-variable forecast heterogeneity is intentionally not represented in the flat Dataset API. ``forecast_hour(time)`` supplies only the ARL index-record forecast written for each time step. """ from arlmet.file import File if not isinstance(ds, xr.Dataset): raise TypeError("write_dataset() requires an xarray.Dataset.") if "time" not in ds.coords: raise ValueError("Dataset must define a 'time' coordinate.") times = pd.to_datetime(ds.coords["time"].values) source = ds.attrs.get("source") if not isinstance(source, str) or not source: raise ValueError("Dataset is missing attrs['source'].") if "arl_grid" not in ds.coords: raise ValueError("Dataset is missing an 'arl_grid' coordinate.") grid = grid_from_coord(ds.coords["arl_grid"]) h_dims = grid.dims resolved_vertical_axis = _extract_dataset_vertical_axis(ds, vertical_axis) forecast_hours = _extract_dataset_forecast_hours(ds, times) data_vars = [name for name in ds.data_vars if name != "forecast_hour"] if not data_vars: raise ValueError("Dataset contains no data variables to write.") diff_names: dict[str, str] = {} for name in data_vars: da = ds[str(name)] diff_name = da.attrs.get("diff") if diff_name is None: continue if not isinstance(diff_name, str) or not diff_name: raise ValueError( f"Variable '{name}' attrs['diff'] must be a non-empty string when provided." ) if not diff_name.startswith("DIF"): raise ValueError( f"Variable '{name}' attrs['diff'] must start with 'DIF', got '{diff_name}'." ) diff_names[str(name)] = diff_name with File( filename_or_obj, mode="w", source=source, grid=grid, vertical_axis=resolved_vertical_axis, ) as arl: for time_index, time in enumerate(times): recordset = arl.create_recordset( pd.Timestamp(time), forecast=forecast_hours[time_index], ) for var_name in data_vars: var_name = str(var_name) if len(var_name) > 4: raise ValueError( f"Variable names must be 4 characters or fewer, got '{var_name}'." ) if var_name.startswith("DIF"): raise ValueError( "Dataset DIF generation is parent-led; set attrs['diff'] on the parent variable instead of writing a DIF* data variable." ) da = ds[var_name] for dim, size in zip(h_dims, (grid.ny, grid.nx), strict=True): if dim not in da.dims: raise ValueError( f"Variable '{var_name}' is missing horizontal dimension '{dim}'." ) if da.sizes[dim] != size: raise ValueError( f"Variable '{var_name}' dimension '{dim}' has size {da.sizes[dim]}, expected {size}." ) if "time" not in da.dims: raise ValueError( f"Variable '{var_name}' data must include the 'time' dimension." ) data_times = pd.to_datetime(da.coords["time"].values) if not data_times.equals(times): raise ValueError( f"Variable '{var_name}' time coordinate must align exactly with the dataset time coordinate." ) if "level" in da.dims: if "level" not in ds.coords: raise ValueError( f"Variable '{var_name}' uses a 'level' dimension but the dataset has no shared level coordinate." ) level_ints = np.atleast_1d( np.asarray(ds.coords["level"].values, dtype=int) ).tolist() transposed = da.transpose( "time", "level", *h_dims, missing_dims="raise" ) write_iter = enumerate(level_ints) else: transposed = da.transpose("time", *h_dims, missing_dims="raise") write_iter = [(None, 0)] for slot, level_index in write_iter: if slot is None: data = np.asarray( transposed.isel(time=time_index).values, dtype=np.float32, ) else: data = np.asarray( transposed.isel(time=time_index, level=slot).values, dtype=np.float32, ) if data.shape != (grid.ny, grid.nx): raise ValueError( f"Variable '{var_name}' slice has shape {data.shape}, expected {(grid.ny, grid.nx)}." ) if np.isnan(data).any(): raise ValueError( f"Variable '{var_name}' contains missing values at time {pd.Timestamp(time)} " f"level {level_index}; write_dataset() requires complete slices." ) if not np.isfinite(data).all(): raise ValueError( f"Variable '{var_name}' contains non-finite values that cannot be packed." ) recordset.create_datarecord( variable=var_name, level=int(level_index), forecast=forecast_hours[time_index], data=data, diff=diff_names.get(var_name), )