Source code for arlmet.subset
"""Subset extraction helpers for ARL meteorology files."""
from __future__ import annotations
from collections import OrderedDict
from collections.abc import Iterable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
from arlmet.file import File
from arlmet.grid import Grid, GridWindow
from arlmet.header import Header, record_length_from_grid, split_grid_component
from arlmet.index import IndexRecord, LvlInfo, VarInfo, _derive_index_forecast
from arlmet.packing import unpack
from arlmet.record import DataRecord
from arlmet.vertical import VerticalAxis
if TYPE_CHECKING:
from arlmet.recordset import RecordSet
def normalize_levels(
vertical_axis: VerticalAxis, levels: Iterable[int] | None
) -> tuple[int, ...]:
"""
Normalize a level selection to sorted unique ARL level indices.
"""
if levels is None:
return tuple(range(len(vertical_axis.levels)))
normalized = tuple(sorted({int(level) for level in levels}))
if not normalized:
raise ValueError("levels must include at least one level index.")
max_index = len(vertical_axis.levels) - 1
if normalized[0] < 0 or normalized[-1] > max_index:
raise ValueError(f"levels must be between 0 and {max_index}, got {normalized}.")
return normalized
def resolve_window(
file: File, bbox: tuple[float, float, float, float] | None
) -> GridWindow:
"""
Resolve a bbox selection to a grid window.
"""
if bbox is None:
return file.grid.full_window()
return file.grid.window_from_bbox(bbox)
def select_records(
records: Sequence[DataRecord],
*,
levels: set[int] | None = None,
variables: set[str] | None = None,
) -> list[DataRecord]:
"""
Filter records by ARL level index and variable name.
"""
return [
record
for record in records
if (levels is None or record.level in levels)
and (variables is None or record.variable in variables)
]
def _build_subset_index_record(
recordset: RecordSet,
*,
subset_grid: Grid,
subset_axis: VerticalAxis,
selected_records: Sequence[DataRecord],
level_map: dict[int, int],
) -> IndexRecord:
"""Build the destination index record for one subsetted time step."""
forecast = _derive_index_forecast(
(record.forecast for record in selected_records),
recordset.forecast,
)
level_records: dict[int, OrderedDict[str, VarInfo]] = {
level: OrderedDict() for level in range(len(subset_axis.levels))
}
for record in selected_records:
level_records[level_map[record.level]][record.variable] = VarInfo(
checksum=record.checksum,
reserved=(record._reserved or "")[:1],
)
grid_x, nx = split_grid_component(subset_grid.nx)
grid_y, ny = split_grid_component(subset_grid.ny)
levels = [
LvlInfo(
level=level,
height=float(height),
variables=level_records[level],
)
for level, height in enumerate(subset_axis.levels)
]
projection = subset_grid.projection
time = recordset.time
return IndexRecord(
header=Header(
year=time.year,
month=time.month,
day=time.day,
hour=time.hour,
forecast=forecast,
level=0,
grid=(grid_x, grid_y),
variable="INDX",
exponent=0,
precision=0.0,
initial_value=0.0,
),
source=recordset.source,
forecast=forecast,
minutes=time.minute,
pole_lat=projection.pole_lat,
pole_lon=projection.pole_lon,
tangent_lat=projection.tangent_lat,
tangent_lon=projection.tangent_lon,
grid_size=projection.grid_size,
orientation=projection.orientation,
cone_angle=projection.cone_angle,
sync_x=projection.sync_x,
sync_y=projection.sync_y,
sync_lat=projection.sync_lat,
sync_lon=projection.sync_lon,
reserved=subset_axis.offset,
nx=nx,
ny=ny,
nz=len(levels),
vertical_flag=subset_axis.flag,
index_length=0,
levels=levels,
)
def validate_subset_record_length(
selected_recordsets: Sequence[tuple[RecordSet, Sequence[DataRecord]]],
*,
subset_grid: Grid,
subset_axis: VerticalAxis,
level_map: dict[int, int],
) -> None:
"""
Fail early when a cropped ARL grid cannot fit its index record.
"""
record_len = record_length_from_grid(grid=subset_grid)
for recordset, selected_records in selected_recordsets:
index = _build_subset_index_record(
recordset,
subset_grid=subset_grid,
subset_axis=subset_axis,
selected_records=selected_records,
level_map=level_map,
)
index_len = len(index.tobytes())
if index_len > record_len:
min_cells = index_len - Header.N_BYTES
raise ValueError(
"Subset grid is too small to encode an ARL index record: "
f"time {recordset.time} needs {index_len} bytes, but each record is "
f"only {record_len} bytes for grid {subset_grid.nx}x{subset_grid.ny}. "
f"The bbox must yield at least {min_cells} grid cells (nx*ny). "
"Expand the bbox or reduce levels/variables."
)
[docs]
def extract_subset(
source_path: str | Path,
destination_path: str | Path,
*,
bbox: tuple[float, float, float, float] | None = None,
levels: Iterable[int] | None = None,
variables: Iterable[str] | None = None,
) -> None:
"""
Extract a spatial/vertical subset from an ARL file into a new ARL file.
Parameters
----------
source_path, destination_path : path-like
Input and output ARL file paths.
bbox : tuple[float, float, float, float], optional
Geographic bounding box ``(west, south, east, north)`` in degrees.
levels : iterable of int, optional
ARL level indices to keep. Output levels are compacted and renumbered
from zero while preserving the selected level heights.
variables : iterable of str, optional
Variable names to keep. All variables are included by default.
Examples
--------
>>> import arlmet
>>> arlmet.extract_subset(
... "met.arl",
... "subset.arl",
... bbox=(-114.0, 39.0, -110.0, 42.0),
... levels=[0, 1, 2],
... )
"""
variable_names = None if variables is None else set(variables)
with File(source_path) as source:
window = resolve_window(source, bbox)
selected_levels = normalize_levels(source.vertical_axis, levels)
selected_level_set = set(selected_levels)
level_map = {
old_level: new_level for new_level, old_level in enumerate(selected_levels)
}
subset_grid = source.grid.subset(window)
subset_axis = VerticalAxis(
flag=source.vertical_axis.flag,
levels=source.vertical_axis.levels[list(selected_levels)].tolist(),
offset=source.vertical_axis.offset,
)
selected_recordsets = []
for time in source.times:
src_recordset = source[time]
selected_records = select_records(
src_recordset.records,
levels=selected_level_set,
variables=variable_names,
)
if selected_records:
selected_recordsets.append((src_recordset, selected_records))
validate_subset_record_length(
selected_recordsets,
subset_grid=subset_grid,
subset_axis=subset_axis,
level_map=level_map,
)
with File(
destination_path,
mode="w",
source=source.source,
grid=subset_grid,
vertical_axis=subset_axis,
) as destination:
for src_recordset, selected_records in selected_recordsets:
dst_recordset = destination.create_recordset(
src_recordset.time,
forecast=src_recordset.forecast,
)
for record in selected_records:
if record.diff is None:
data = record.read(window=window)
dst_recordset.create_datarecord(
variable=record.variable,
level=level_map[record.level],
forecast=record.forecast,
data=data,
)
continue
parent_data = np.asarray(
unpack(
packed=record.bytes[Header.N_BYTES :],
nx=source.grid.nx,
ny=source.grid.ny,
precision=record.header.precision,
exponent=record.header.exponent,
initial_value=record.header.initial_value,
window=window,
driver=np,
),
dtype=np.float32,
)
diff_data = np.asarray(
record.diff.read(window=window), dtype=np.float32
)
destination.register_diff_binding(
diff_name=record.diff.variable,
parent_name=record.variable,
)
dst_parent = dst_recordset.create_datarecord(
variable=record.variable,
level=level_map[record.level],
forecast=record.forecast,
data=parent_data,
)
dst_diff = dst_parent._create_diff(
position=-1,
variable=record.diff.variable,
forecast=record.diff.forecast,
)
dst_diff[:] = diff_data