# Source code for arlmet.recordset

"""RecordSet: the collection of data records for one valid time in an ARL file."""

from __future__ import annotations

import io
from collections import OrderedDict
from collections.abc import Iterator
from typing import TYPE_CHECKING, Literal

if not TYPE_CHECKING:
    # At runtime ``typing_extensions`` may not be installed, so fall back to
    # an identity decorator: the decorated object is returned untouched.

    def override(method: object) -> object:
        """Return *method* unchanged (runtime stand-in for ``@override``)."""
        return method

else:
    # Static type checkers see the real decorator.
    from typing_extensions import override


import numpy.typing as npt
import pandas as pd

from arlmet.collection import VariableAccessor
from arlmet.grid import Grid
from arlmet.header import Header, split_grid_component
from arlmet.index import IndexRecord, LvlInfo, VarInfo, _derive_index_forecast
from arlmet.record import DataRecord, _require_mode
from arlmet.vertical import VerticalAxis

if TYPE_CHECKING:
    from arlmet.file import File


class RecordSet:
    """
    Records for one valid time within an ARL file.

    Parameters
    ----------
    file : File
        Parent ARL file handle.
    position : int
        Byte offset of the index record on disk, or ``-1`` for a new
        writable record set.
    time : pandas.Timestamp
        Valid time represented by the record set.
    forecast : int, optional
        Forecast hour stored in the index record header.

    Attributes
    ----------
    file : File
        Parent ARL file.
    position : int
        Byte offset of the index record. It stays ``-1`` until a writable
        record set has been flushed.
    time : pandas.Timestamp
        Valid time of the record set.
    forecast : int or None
        Forecast hour associated with the index record.
    records : list[DataRecord]
        Data records stored at this time.
    variables : VariableAccessor
        Variable accessor bound to this record set.

    Methods
    -------
    __getitem__(key)
        Get a DataRecord by (level, variable) key.
    __iter__()
        Iterate over DataRecords in this record set.
    __len__()
        Get the number of records in this record set.
    __contains__(key)
        Test membership by variable name or by (level, variable) key.
    create_datarecord(variable, level, forecast, data=None, diff=None)
        Create a writable data record for this time.
    """

    def __init__(
        self,
        file: File,
        position: int,
        time: pd.Timestamp,
        *,
        forecast: int | None = None,
    ):
        self.file = file
        self.position = position
        self.time = time
        self.forecast = forecast
        # Keyed by (time, level, variable). The time component is always
        # ``self.time`` for records registered through this record set.
        self._datarecords: OrderedDict[tuple[pd.Timestamp, int, str], DataRecord] = (
            OrderedDict()
        )
        self.variables = VariableAccessor(self)

    @property
    def mode(self) -> Literal["r", "w"]:
        """
        Access mode of the record set, inferred from the position.

        ``"w"`` while ``position`` is ``-1`` (not yet written to disk),
        otherwise ``"r"``.
        """
        return "w" if self.position == -1 else "r"

    @property
    def source(self) -> str:
        """Source identifier of the parent file."""
        return self.file.source

    @property
    def grid(self) -> Grid:
        """Grid of the parent file."""
        return self.file.grid

    @property
    def vertical_axis(self) -> VerticalAxis:
        """Vertical axis of the parent file."""
        return self.file.vertical_axis

    @property
    def record_length(self) -> int:
        """Record length in bytes, taken from the parent file."""
        return self.file.record_length

    @property
    def records(self) -> list[DataRecord]:
        """List of DataRecords in this record set, in insertion order."""
        return list(self._datarecords.values())

    def _create_datarecord(
        self,
        position: int,
        variable: str,
        level: int,
        forecast: int | None = None,
        checksum: int | None = None,
        reserved: str | None = None,
    ) -> DataRecord:
        """
        Create a DataRecord and register it in this record set.

        Used both for records that already exist on disk (a real byte
        ``position``) and, via ``create_datarecord``, for new writable
        records (``position=-1``). A record with the same
        ``(level, variable)`` key replaces the previous one.
        """
        dr = DataRecord(
            recordset=self,
            position=position,
            variable=variable,
            level=level,
            forecast=forecast,
            checksum=checksum,
            reserved=reserved,
        )
        # Store the record in the record set's internal mapping
        self._datarecords[(self.time, level, variable)] = dr
        return dr

    def create_datarecord(
        self,
        variable: str,
        level: int,
        forecast: int,
        data: npt.ArrayLike | None = None,
        diff: str | None = None,
    ) -> DataRecord:
        """
        Create a writable DataRecord attached to this time step.

        Parameters
        ----------
        variable : str
            Four-character ARL variable name.
        level : int
            ARL level index for the record.
        forecast : int
            Forecast hour to write into the record header. Missing data
            should use a value of -1.
        data : numpy.ndarray, optional
            Initial ``(ny, nx)`` field values to assign.
        diff : str, optional
            Name of a trailing DIF record to derive from the parent field.

        Returns
        -------
        DataRecord
            Writable data record for the requested variable and level.

        Raises
        ------
        ValueError
            If ``variable`` starts with ``"DIF"``. DIF records must be
            created through their parent with ``diff=...``.
        """
        _require_mode(self, "w")
        if variable.startswith("DIF"):
            raise ValueError(
                "Create DIF records through the parent record using diff='DIF...'."
            )
        if diff is not None:
            self.file.register_diff_binding(diff_name=diff, parent_name=variable)
        dr = self._create_datarecord(
            position=-1, variable=variable, level=level, forecast=forecast
        )
        if diff is not None:
            dr._create_diff(position=-1, variable=diff, forecast=forecast)
            # Ask the parent to derive the DIF field when it is packed.
            dr._derive_diff_on_pack = True
        if data is not None:
            dr[:] = data
        return dr

    def _build_index_record(self) -> IndexRecord:
        """
        Build the index record for this time step from the writable records.

        Each data record is packed (``DataRecord._pack``) as a side effect so
        that its checksum can be stored in the index. Attached DIF records are
        listed right after their parent at the same level.

        Returns
        -------
        IndexRecord
            Index record describing every level and variable at this time.

        Raises
        ------
        ValueError
            If the record set is empty, the source identifier or any variable
            name (including DIF names) exceeds 4 characters, or the vertical
            axis has no levels. Also raised if a record has no data, lies
            outside the vertical axis, or lacks a forecast hour.
        """
        if not self._datarecords:
            raise ValueError("Cannot flush an empty RecordSet.")
        if len(self.source) > 4:
            raise ValueError("ARL source identifiers must be 4 characters or fewer.")

        vaxis = self.file.vertical_axis
        heights = vaxis.levels.tolist()
        if not heights:
            raise ValueError("Vertical axis must contain at least one level.")

        forecast_hours: set[int] = set()
        level_records: dict[int, OrderedDict[str, DataRecord]] = {
            level: OrderedDict() for level in range(len(heights))
        }

        def record_forecast(record: DataRecord) -> int:
            # Writable records may still hold raw header fields in a dict.
            # Read forecast directly from that state so index assembly does not
            # materialize a full Header object for every record.
            header_state = record._header
            if isinstance(header_state, dict):
                forecast = header_state.get("forecast")
                if forecast is None:
                    raise ValueError(
                        f"Writable DataRecord '{record.variable}' at level {record.level} is missing a forecast hour."
                    )
                return int(forecast)
            return record.forecast

        for dr in self.records:
            if len(dr.variable) > 4:
                raise ValueError(
                    f"Variable names must be 4 characters or fewer, got '{dr.variable}'."
                )
            # DIF names go into the index as well, so they obey the same limit.
            if dr.diff is not None and len(dr.diff.variable) > 4:
                raise ValueError(
                    f"Variable names must be 4 characters or fewer, got '{dr.diff.variable}'."
                )
            if dr._unpacked is None:
                raise ValueError(
                    f"Writable DataRecord '{dr.variable}' at level {dr.level} has no data."
                )
            if dr.level < 0 or dr.level >= len(heights):
                raise ValueError(
                    f"DataRecord level {dr.level} is outside the configured vertical axis."
                )
            forecast_hours.add(record_forecast(dr))
            level_records[dr.level][dr.variable] = dr
            dr._pack()
            if dr.diff is not None:
                # Read the DIF forecast only after the parent has been packed.
                forecast_hours.add(record_forecast(dr.diff))
                level_records[dr.level][dr.diff.variable] = dr.diff

        # Derive the index record forecast hour from the data records, ensuring consistency
        forecast = _derive_index_forecast(
            record_forecasts=forecast_hours, explicit_forecast=self.forecast
        )

        grid_x, nx = split_grid_component(self.grid.nx)
        grid_y, ny = split_grid_component(self.grid.ny)

        levels = [
            LvlInfo(
                level=level,
                height=float(height),
                variables=OrderedDict(
                    (
                        name,
                        VarInfo(
                            checksum=dr.checksum,
                            reserved=(dr._reserved or "")[:1],
                        ),
                    )
                    for name, dr in level_records[level].items()
                ),
            )
            for level, height in enumerate(heights)
        ]

        projection = self.grid.projection
        header = Header(
            year=self.time.year,
            month=self.time.month,
            day=self.time.day,
            hour=self.time.hour,
            forecast=forecast,
            level=0,
            grid=(grid_x, grid_y),
            variable="INDX",
            exponent=0,
            precision=0.0,
            initial_value=0.0,
        )
        return IndexRecord(
            header=header,
            source=self.source,
            forecast=forecast,
            minutes=self.time.minute,
            pole_lat=projection.pole_lat,
            pole_lon=projection.pole_lon,
            tangent_lat=projection.tangent_lat,
            tangent_lon=projection.tangent_lon,
            grid_size=projection.grid_size,
            orientation=projection.orientation,
            cone_angle=projection.cone_angle,
            sync_x=projection.sync_x,
            sync_y=projection.sync_y,
            sync_lat=projection.sync_lat,
            sync_lon=projection.sync_lon,
            reserved=vaxis.offset,
            nx=nx,
            ny=ny,
            nz=len(levels),
            vertical_flag=vaxis.flag,
            index_length=0,
            levels=levels,
        )

    def _flush(self):
        """
        Write the index record and all pending data records to disk.

        When ``position`` is ``-1`` the index record is appended at the end of
        the file and ``position`` is updated to that offset. Otherwise it is
        written at ``position``. The data records are then flushed in the
        order they appear in the index.
        """
        _require_mode(self, "w")
        index = self._build_index_record()
        fh = self.file.handle
        if self.position == -1:
            fh.seek(0, io.SEEK_END)
            self.position = fh.tell()
        else:
            fh.seek(self.position)
        fh.write(index.to_record_bytes(self.record_length))
        for level in index.levels:
            for name in level.variables:
                self._lookup_flush_record(level.level, name)._flush()

    def _lookup_flush_record(self, level: int, variable: str) -> DataRecord:
        """
        Find the record to flush for ``(level, variable)``.

        A directly registered record is returned first. Otherwise the DIF
        record attached to a parent at the same level is searched for.

        Raises
        ------
        KeyError
            If neither a registered record nor an attached DIF record matches.
        """
        key = (self.time, level, variable)
        record = self._datarecords.get(key)
        if record is not None:
            return record
        for parent in self.records:
            if (
                parent.level == level
                and parent.diff is not None
                and parent.diff.variable == variable
            ):
                return parent.diff
        raise KeyError(f"No writable record found for ({level}, {variable}).")

    def __getitem__(self, key: tuple[int, str]) -> DataRecord:
        """
        Get a DataRecord by ``(level, variable)`` key.

        Raises
        ------
        KeyError
            If ``key`` is not a 2-tuple or no such record is registered.
        """
        if not isinstance(key, tuple) or len(key) != 2:
            raise KeyError("Key must be a tuple of (level, variable).")
        return self._datarecords[(self.time, *key)]

    def __iter__(self) -> Iterator[DataRecord]:
        """Iterate over DataRecords in this record set, in insertion order."""
        return iter(self.records)

    def __len__(self) -> int:
        """Number of registered DataRecords (attached DIF records excluded)."""
        return len(self._datarecords)

    def __contains__(self, key: object) -> bool:
        """
        Test membership by variable name or by ``(level, variable)`` key.

        A ``str`` key is present if any record at any level has that variable
        name. A 2-tuple key is present exactly when ``self[key]`` would
        succeed. Any other key is never present.
        """
        if isinstance(key, str):
            return any(r.variable == key for r in self._datarecords.values())
        if isinstance(key, tuple) and len(key) == 2:
            try:
                return (self.time, *key) in self._datarecords
            except TypeError:
                # Unhashable tuple members can never match a stored key.
                return False
        return False

    @override
    def __repr__(self) -> str:
        t = self.time.strftime("%Y-%m-%d %H:%M")
        return f"RecordSet(time={t}, forecast={self.forecast}, n={len(self)})"