# Source code for tidal.solver.state

"""State vector layout and flat ↔ field-dict conversions for TIDAL solvers.

IDA and leapfrog operate on flat numpy arrays.  This module provides:

- ``StateLayout``: describes which slots map to which fields/velocities
- ``state_to_flat`` / ``flat_to_fields``: convert between flat vectors and
  named field dictionaries
"""

from __future__ import annotations

from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING

import numpy as np

from tidal.solver._defaults import SECOND_ORDER

if TYPE_CHECKING:
    from tidal.symbolic.json_loader import EquationSystem


# ---------------------------------------------------------------------------
# Slot descriptor
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class SlotInfo:
    """Describe one slot of the flat state vector.

    A slot is a contiguous run of ``grid.num_points`` values stored in
    C-order.  Field, velocity and constraint slots all share this record.

    Attributes
    ----------
    name : str
        Slot name, e.g. ``"phi_0"`` for a field or ``"v_phi_0"`` for its
        velocity.
    field_name : str
        Name of the physical field that owns the slot.  For a velocity slot
        this is the underlying field's name, not the ``v_``-prefixed one.
    kind : str
        ``"field"``, ``"velocity"`` or ``"constraint"``.
    time_order : int
        Time-derivative order on the equation's LHS (0, 1 or 2).
    dynamical_index : int | None
        Position within the dynamical-field subset (used for K matrix
        lookup); ``None`` for constraints and first-order fields.
    """

    # Slot identifier used as the key in field dictionaries.
    name: str
    # Owning physical field (same as ``name`` for non-velocity slots).
    field_name: str
    # Slot category: "field", "velocity" or "constraint".
    kind: str
    # LHS time-derivative order of the owning equation.
    time_order: int
    # Index into the dynamical fields; stays None when not applicable.
    dynamical_index: int | None = None
# --------------------------------------------------------------------------- # Layout # ---------------------------------------------------------------------------
@dataclass(frozen=True)
class StateLayout:
    """Describes the mapping between flat state vector and named fields.

    Built from an ``EquationSystem`` spec.  The flat vector is partitioned
    into contiguous blocks of ``num_points`` elements, one per slot.

    Attributes
    ----------
    slots : tuple[SlotInfo, ...]
        Ordered slot descriptors.
    num_points : int
        Grid points per slot (``grid.num_points``).
    field_slot_map : dict[str, int]
        Maps field name → slot index.
    velocity_slot_map : dict[str, int]
        Maps field name → velocity slot index (second-order fields only).
    dynamical_fields : tuple[str, ...]
        Names of dynamical fields (time_order >= 2), in order.

    Raises
    ------
    ValueError
        On construction, if two slots share the same ``name``.  Slot names
        are the keys of ``slot_name_to_idx`` and of the dictionaries handled
        by ``state_to_flat`` / ``flat_to_fields``, so duplicates would
        silently overwrite each other there.

    Notes
    -----
    The ``cached_property`` members work on this frozen dataclass because
    ``functools.cached_property`` stores the computed value directly in the
    instance ``__dict__`` instead of going through ``__setattr__``.

    Three fields are dicts, so ``hash(layout)`` raises ``TypeError``.  The
    generated ``__hash__`` of a frozen dataclass hashes every field.
    """

    slots: tuple[SlotInfo, ...]
    num_points: int
    field_slot_map: dict[str, int]
    velocity_slot_map: dict[str, int]
    dynamical_fields: tuple[str, ...]

    def __post_init__(self) -> None:
        """Reject layouts whose slot names are not unique."""
        seen: set[str] = set()
        duplicates: list[str] = []
        for slot in self.slots:
            if slot.name in seen and slot.name not in duplicates:
                duplicates.append(slot.name)
            seen.add(slot.name)
        if duplicates:
            msg = f"StateLayout slot names must be unique; duplicated: {duplicates}"
            raise ValueError(msg)

    @property
    def momentum_slot_map(self) -> dict[str, int]:
        """Alias for ``velocity_slot_map`` (transition aid)."""
        return self.velocity_slot_map

    @classmethod
    def from_spec(cls, spec: EquationSystem, num_points: int) -> StateLayout:
        """Build layout from an equation system specification.

        Follows the same ordering as py-pde FieldCollection: for each
        equation, emit a field slot; if second-order, also emit a velocity
        slot (named ``v_<field>``) immediately after.

        Raises
        ------
        ValueError
            If the generated slot names collide, e.g. two equations for the
            same field, or a field named ``v_phi`` alongside a second-order
            field ``phi``.
        """
        slots: list[SlotInfo] = []
        field_slot_map: dict[str, int] = {}
        velocity_slot_map: dict[str, int] = {}
        dynamical_fields: list[str] = []
        dyn_idx = 0

        for eq in spec.equations:
            name = eq.field_name
            order = eq.time_derivative_order
            kind = "constraint" if order == 0 else "field"

            # Only second-(or higher-)order fields get a dynamical index;
            # the field slot and its velocity slot share the same index.
            d_idx = None
            if order >= SECOND_ORDER:
                d_idx = dyn_idx
                dynamical_fields.append(name)
                dyn_idx += 1

            field_slot_map[name] = len(slots)
            slots.append(
                SlotInfo(
                    name=name,
                    field_name=name,
                    kind=kind,
                    time_order=order,
                    dynamical_index=d_idx,
                )
            )

            if order >= SECOND_ORDER:
                vel_name = f"v_{name}"
                velocity_slot_map[name] = len(slots)
                slots.append(
                    SlotInfo(
                        name=vel_name,
                        field_name=name,
                        kind="velocity",
                        time_order=order,
                        dynamical_index=d_idx,
                    )
                )

        # Construction runs __post_init__, which rejects colliding names.
        return cls(
            slots=tuple(slots),
            num_points=num_points,
            field_slot_map=field_slot_map,
            velocity_slot_map=velocity_slot_map,
            dynamical_fields=tuple(dynamical_fields),
        )

    @cached_property
    def slot_name_to_idx(self) -> dict[str, int]:
        """Map from slot name to slot index. Cached on frozen dataclass."""
        return {slot.name: i for i, slot in enumerate(self.slots)}

    @cached_property
    def total_size(self) -> int:
        """Total flat vector length (num_slots * num_points)."""
        return len(self.slots) * self.num_points

    @cached_property
    def num_slots(self) -> int:
        """Number of slots in the state vector."""
        return len(self.slots)

    def slot_slice(self, slot_idx: int) -> slice:
        """Return the flat-array slice for a given slot index."""
        n = self.num_points
        return slice(slot_idx * n, (slot_idx + 1) * n)

    # ---- Pre-computed slot groups (branch-free hot-path iteration) ----

    @cached_property
    def velocity_slot_groups(self) -> tuple[tuple[int, slice, str], ...]:
        """Pre-computed ``(slot_idx, flat_slice, field_name)`` for velocity slots."""
        return tuple(
            (i, self.slot_slice(i), s.field_name)
            for i, s in enumerate(self.slots)
            if s.kind == "velocity"
        )

    @cached_property
    def dynamical_field_slot_groups(self) -> tuple[tuple[int, slice, int], ...]:
        """Pre-computed ``(slot_idx, flat_slice, vel_slot_idx)`` for 2nd-order fields."""
        return tuple(
            (i, self.slot_slice(i), self.velocity_slot_map[s.field_name])
            for i, s in enumerate(self.slots)
            if s.kind == "field" and s.time_order >= SECOND_ORDER
        )

    @cached_property
    def drift_slot_pairs(self) -> tuple[tuple[slice, slice], ...]:
        """Pre-computed ``(field_slice, vel_slice)`` for zero-copy drift.

        Allows ``y[field_slice] += dt * y[vel_slice]`` without copying
        velocity data into a separate buffer.
        """
        return tuple(
            (self.slot_slice(i), self.slot_slice(self.velocity_slot_map[s.field_name]))
            for i, s in enumerate(self.slots)
            if s.kind == "field" and s.time_order >= SECOND_ORDER
        )

    @cached_property
    def first_order_slot_groups(self) -> tuple[tuple[int, slice, str], ...]:
        """Pre-computed ``(slot_idx, flat_slice, field_name)`` for 1st-order fields."""
        return tuple(
            (i, self.slot_slice(i), s.field_name)
            for i, s in enumerate(self.slots)
            if s.kind == "field" and s.time_order == 1
        )

    @cached_property
    def constraint_slot_groups(self) -> tuple[tuple[int, slice, str], ...]:
        """Pre-computed ``(slot_idx, flat_slice, field_name)`` for constraint slots."""
        return tuple(
            (i, self.slot_slice(i), s.field_name)
            for i, s in enumerate(self.slots)
            if s.time_order == 0
        )

    @cached_property
    def algebraic_indices(self) -> list[int]:
        """Flat indices of algebraic (constraint) variables for IDA.

        IDA needs to know which entries in the state vector are algebraic
        (not differential) so it can handle them appropriately.

        The list is cached, so every access returns the same object.
        Callers must not mutate it in place.
        """
        indices: list[int] = []
        n = self.num_points
        for i, slot in enumerate(self.slots):
            if slot.time_order == 0:
                indices.extend(range(i * n, (i + 1) * n))
        return indices
# --------------------------------------------------------------------------- # Flat ↔ field-dict conversions # ---------------------------------------------------------------------------
def state_to_flat(
    fields: dict[str, np.ndarray],
    layout: StateLayout,
) -> np.ndarray:
    """Pack named field arrays into a single flat vector.

    Parameters
    ----------
    fields : dict[str, np.ndarray]
        Mapping from slot name (e.g. ``"phi_0"``, ``"v_phi_0"``) to
        grid-shaped arrays (any array-like is accepted).  Every slot of
        ``layout`` must be present; extra keys are ignored.
    layout : StateLayout
        State layout descriptor.

    Returns
    -------
    np.ndarray
        Newly allocated float64 vector of length ``layout.total_size``.  Slot
        ``i`` occupies ``layout.slot_slice(i)`` and holds that slot's array
        flattened in C-order.

    Raises
    ------
    KeyError
        If a slot name from ``layout`` is missing from ``fields``.
    ValueError
        If an array does not contain exactly ``layout.num_points`` values.
    """
    y = np.zeros(layout.total_size)
    n = layout.num_points
    for i, slot in enumerate(layout.slots):
        # np.ravel flattens in C-order, matching the slot storage order.
        values = np.ravel(fields[slot.name])
        # Check explicitly: a size-1 array would otherwise broadcast silently
        # across the entire slot instead of raising.
        if values.size != n:
            msg = (
                f"slot {slot.name!r} has {values.size} values; "
                f"expected {n} (layout.num_points)"
            )
            raise ValueError(msg)
        y[layout.slot_slice(i)] = values
    return y
def flat_to_fields(
    y: np.ndarray,
    layout: StateLayout,
    shape: tuple[int, ...],
) -> dict[str, np.ndarray]:
    """Split a flat state vector back into per-slot grid arrays.

    Parameters
    ----------
    y : np.ndarray
        Flat state vector laid out as described by ``layout``.
    layout : StateLayout
        State layout descriptor.
    shape : tuple[int, ...]
        Grid shape that each slot's block of ``y`` is reshaped into.

    Returns
    -------
    dict[str, np.ndarray]
        Slot name → grid-shaped array, in slot order.  Each array is a
        reshaped slice of ``y``.  For an ordinary 1-D ndarray ``y`` it is
        therefore a view, and in-place writes to it change ``y``.
    """
    width = layout.num_points
    return {
        slot.name: y[k * width : (k + 1) * width].reshape(shape)
        for k, slot in enumerate(layout.slots)
    }