Source code for tidal.solver.fields

"""Typed field container for TIDAL solvers.

``FieldSet`` owns named field data on a grid, backed by a single contiguous
flat numpy array.  Named access returns zero-copy views (numpy slices).

This consolidates the state packing/unpacking logic previously duplicated
across ida.py, leapfrog.py, _simulate.py, _io.py, and _writer.py.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from typing_extensions import override

if TYPE_CHECKING:
    from tidal.solver.state import StateLayout


[docs] class FieldSet: # noqa: PLR0904 """Typed container owning named field data on a grid. Backed by a single contiguous flat ``np.ndarray``. Named access returns zero-copy views (numpy slices) into this array. Parameters ---------- layout : StateLayout Describes the slot structure (field/velocity names and ordering). grid_shape : tuple[int, ...] Shape of the spatial grid (e.g. ``(64,)`` or ``(32, 32)``). data : np.ndarray, optional Flat data array of length ``layout.total_size``. If ``None``, initializes to zeros. """ __slots__ = ( "_aux", "_data", "_field_names", "_grid_shape", "_layout", "_n", "_name_to_idx", "_name_to_range", "_slot_names", "_velocity_names", ) def __init__( self, layout: StateLayout, grid_shape: tuple[int, ...], data: np.ndarray | None = None, ) -> None: self._layout = layout self._grid_shape = grid_shape self._n = layout.num_points if data is not None: if data.shape != (layout.total_size,): msg = ( f"Expected flat array of length {layout.total_size}, " f"got shape {data.shape}" ) raise ValueError(msg) self._data = data else: self._data = np.zeros(layout.total_size) # Reuse the cached name → slot index mapping from StateLayout self._name_to_idx = layout.slot_name_to_idx # Pre-computed (start, end) byte offsets for zero-multiply field access n = layout.num_points self._name_to_range: dict[str, tuple[int, int]] = { name: (idx * n, (idx + 1) * n) for name, idx in layout.slot_name_to_idx.items() } # Auxiliary fields not backed by the flat state array (e.g. # constraint velocities injected from IDA's yp vector). self._aux: dict[str, np.ndarray] = {} # Pre-compute name tuples (layout is immutable, names never change) self._field_names = tuple( slot.name for slot in layout.slots if slot.kind != "velocity" ) self._velocity_names = tuple( slot.name for slot in layout.slots if slot.kind == "velocity" ) self._slot_names = tuple(slot.name for slot in layout.slots) # ---- Named access (zero-copy views) ---- def __getitem__(self, name: str) -> np.ndarray: """Return grid-shaped view of named slot or auxiliary field. Raises ------ KeyError If *name* is not a valid slot or auxiliary name. """ r = self._name_to_range.get(name) if r is not None: return self._data[r[0] : r[1]].reshape(self._grid_shape) if name in self._aux: return self._aux[name] valid = sorted(set(self._name_to_idx) | set(self._aux)) msg = f"Unknown slot '{name}'. Valid: {valid}" raise KeyError(msg) def __setitem__(self, name: str, value: np.ndarray) -> None: """Write grid-shaped data into named slot. Raises ------ KeyError If *name* is not a valid slot name. """ r = self._name_to_range.get(name) if r is None: msg = f"Unknown slot '{name}'. Valid slots: {sorted(self._name_to_idx)}" raise KeyError(msg) self._data[r[0] : r[1]] = np.asarray(value).ravel() def __contains__(self, name: object) -> bool: """Check if *name* is a valid slot or auxiliary name.""" return name in self._name_to_idx or name in self._aux def __len__(self) -> int: """Return the number of slots.""" return len(self._name_to_idx) @override def __repr__(self) -> str: slots = ", ".join(self._name_to_idx) return f"FieldSet(slots=[{slots}], grid_shape={self._grid_shape})" # ---- Auxiliary fields ----
[docs] def set_aux(self, name: str, data: np.ndarray) -> None: """Register an auxiliary named field (not backed by flat array). Used to inject constraint velocities from IDA's ``yp`` vector. """ self._aux[name] = data
# ---- Flat array access ---- @property def flat(self) -> np.ndarray: """Underlying flat vector (no copy). Used by IDA/leapfrog.""" return self._data # ---- Metadata ---- @property def layout(self) -> StateLayout: """The state layout this FieldSet wraps.""" return self._layout @property def grid_shape(self) -> tuple[int, ...]: """Spatial grid shape.""" return self._grid_shape @property def field_names(self) -> tuple[str, ...]: """Ordered field slot names (e.g. ``("phi_0", "A_0")``). Excludes velocity slots. """ return self._field_names @property def velocity_names(self) -> tuple[str, ...]: """Ordered velocity slot names (e.g. ``("v_phi_0",)``). Only present for second-order equations. """ return self._velocity_names @property def momentum_names(self) -> tuple[str, ...]: """Alias for ``velocity_names`` (transition aid).""" return self._velocity_names @property def slot_names(self) -> tuple[str, ...]: """All slot names in order.""" return self._slot_names # ---- Dict-like views (zero-copy) ----
[docs] def fields_dict(self) -> dict[str, np.ndarray]: """Return dict of field name → grid-shaped view (zero-copy).""" return {name: self[name] for name in self.field_names}
[docs] def velocities_dict(self) -> dict[str, np.ndarray]: """Return dict of velocity name → grid-shaped view (zero-copy).""" return {name: self[name] for name in self.velocity_names}
[docs] def momenta_dict(self) -> dict[str, np.ndarray]: """Alias for ``velocities_dict`` (transition aid).""" return self.velocities_dict()
[docs] def as_dict(self) -> dict[str, np.ndarray]: """Return dict of all slot names + aux → grid-shaped views.""" d = {name: self[name] for name in self.slot_names} d.update(self._aux) return d
# ---- Constructors ----
[docs] def copy(self) -> FieldSet: """Deep copy (new flat array, copied aux).""" fs = FieldSet(self._layout, self._grid_shape, self._data.copy()) for k, v in self._aux.items(): fs._aux[k] = v.copy() return fs
[docs] @classmethod def zeros(cls, layout: StateLayout, grid_shape: tuple[int, ...]) -> FieldSet: """Create a FieldSet initialized to zero.""" return cls(layout, grid_shape)
[docs] @classmethod def from_flat( cls, layout: StateLayout, grid_shape: tuple[int, ...], flat: np.ndarray, ) -> FieldSet: """Wrap an existing flat array (no copy). The caller must ensure the array is not unexpectedly mutated elsewhere. """ return cls(layout, grid_shape, flat)
[docs] def rebind(self, flat: np.ndarray) -> None: """Replace the underlying flat array reference (zero-copy, zero-alloc). Used by solver loops to reuse a single FieldSet instead of allocating a new one each timestep. Auxiliary fields are cleared. """ self._data = flat self._aux.clear()
[docs] @classmethod def from_dict( cls, layout: StateLayout, grid_shape: tuple[int, ...], slot_data: dict[str, np.ndarray], ) -> FieldSet: """Pack named arrays into a FieldSet. Missing slots default to zero. Extra keys not in the layout are silently ignored. Parameters ---------- layout : StateLayout State layout descriptor. grid_shape : tuple[int, ...] Spatial grid shape. slot_data : dict[str, np.ndarray] Mapping from slot name → grid-shaped array. """ fs = cls(layout, grid_shape) for name, arr in slot_data.items(): if name in fs: fs[name] = arr return fs
# ---- Diagnostics ----
[docs] def max_norm(self) -> float: """Maximum absolute value across all slots.""" return float(np.max(np.abs(self._data)))
[docs] def check_finite(self) -> bool: """Return True iff all values are finite (no NaN or Inf).""" return bool(np.isfinite(self._data).all())