"""Uniform data abstraction for simulation output.
Provides ``SimulationData``, a frozen dataclass that stores the full time
history of field and velocity arrays. Can be constructed from:
- A solver result dict (IDA/leapfrog output)
- A snapshot directory written by :class:`~tidal.measurement.SnapshotWriter`
(memory-mapped, O(1) RAM)
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, cast
import numpy as np
from tidal.solver.state import StateLayout
if TYPE_CHECKING:
from numpy.typing import NDArray
from tidal.solver._types import SolverResult
from tidal.solver.grid import GridInfo
from tidal.symbolic.json_loader import EquationSystem
@dataclass(frozen=True)
class SimulationData:
    """Immutable container for the full time-history of a simulation.

    Every entry of *fields* and *velocities* is an array of shape
    ``(n_snapshots, *grid_shape)`` — one spatial snapshot per recorded
    time. Constraint fields (``time_derivative_order == 0``) carry no
    velocity entry.

    Attributes
    ----------
    times : ndarray, shape (n_snapshots,)
        Times at which snapshots were recorded.
    fields : dict[str, ndarray]
        ``field_name → (n_snapshots, *grid_shape)`` field arrays.
    velocities : dict[str, ndarray]
        ``field_name → (n_snapshots, *grid_shape)`` velocity arrays
        (v = dq/dt); present only for 2nd-order (wave) fields.
    grid_spacing : tuple[float, ...]
        Cell size along each spatial axis, e.g. ``(dx, dy)``.
    grid_bounds : tuple[tuple[float, float], ...]
        ``(lo, hi)`` domain bounds per spatial axis.
    periodic : tuple[bool, ...]
        Periodicity flag per spatial axis.
    spec : EquationSystem
        Equation specification (fields, equations, matrices).
    parameters : dict[str, float]
        Resolved parameter values used in the simulation.
    bc_types : tuple[str, ...] or None
        Per-axis boundary-condition type, e.g. ``("periodic", "neumann")``.
        ``None`` for legacy data that recorded no BC info. Consumed by the
        energy module for BC-aware gradient computation.
    dt : float or None
        Solver time-step size when recorded; ``None`` for legacy data.
    """

    times: NDArray[np.float64]
    fields: dict[str, NDArray[np.float64]]
    velocities: dict[str, NDArray[np.float64]]
    grid_spacing: tuple[float, ...]
    grid_bounds: tuple[tuple[float, float], ...]
    periodic: tuple[bool, ...]
    spec: EquationSystem
    parameters: dict[str, float]
    bc_types: tuple[str, ...] | None = None
    dt: float | None = None

    # ------------------------------------------------------------------
    # Derived helpers
    # ------------------------------------------------------------------
    @property
    def n_snapshots(self) -> int:
        """Number of recorded time snapshots."""
        return len(self.times)

    @property
    def volume_element(self) -> float:
        """Uniform cell volume, the product ``dx * dy * ...``."""
        cell = 1.0
        for step in self.grid_spacing:
            cell *= step
        return cell

    @property
    def dynamical_fields(self) -> tuple[str, ...]:
        """Names of fields with ``time_derivative_order >= 2`` (have velocities)."""
        names: list[str] = []
        for eq in self.spec.equations:
            if eq.time_derivative_order >= 2:  # noqa: PLR2004
                names.append(eq.field_name)
        return tuple(names)

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------
[docs]
@classmethod
def from_result(
cls,
result: SolverResult,
spec: EquationSystem,
grid_info: GridInfo,
parameters: dict[str, float] | None = None,
dt: float | None = None,
) -> SimulationData:
"""Build from a solver result dict (IDA/leapfrog output).
This is the native-path constructor — no py-pde types involved.
Directly slices the flat state vector using ``StateLayout``.
Parameters
----------
result : dict
Solver output with keys ``"t"`` (1D times array) and
``"y"`` (2D array of shape ``(n_snapshots, total_flat_size)``).
spec : EquationSystem
Equation specification.
grid_info : GridInfo
Spatial grid descriptor.
parameters : dict, optional
Resolved parameter values.
dt : float, optional
Time-step size used by the solver (for conservation diagnostics).
Raises
------
ValueError
If *result* has no snapshots or flat vector size doesn't match
the layout.
"""
if not result["success"]:
msg = f"Cannot build SimulationData from failed solver result: {result['message']}"
raise ValueError(msg)
times = np.asarray(result["t"], dtype=np.float64)
y_all = np.asarray(result["y"], dtype=np.float64)
if len(times) == 0:
msg = "Solver result contains no snapshots"
raise ValueError(msg)
layout = StateLayout.from_spec(spec, grid_info.num_points)
if y_all.ndim == 1:
y_all = y_all.reshape(1, -1)
if y_all.shape[1] != layout.total_size:
msg = (
f"Flat vector size {y_all.shape[1]} doesn't match "
f"layout.total_size {layout.total_size}"
)
raise ValueError(msg)
if y_all.shape[0] != len(times):
msg = (
f"Snapshot count mismatch: {len(times)} time points but "
f"{y_all.shape[0]} state vectors in result['y']"
)
raise ValueError(msg)
n_pts = grid_info.num_points
shape = grid_info.shape
fields: dict[str, NDArray[np.float64]] = {}
velocities: dict[str, NDArray[np.float64]] = {}
for i, slot in enumerate(layout.slots):
start = i * n_pts
end = start + n_pts
# Slice all snapshots at once: (n_snapshots, *grid_shape)
arr = y_all[:, start:end].reshape(-1, *shape)
if slot.kind == "velocity":
velocities[slot.field_name] = arr
else:
fields[slot.name] = arr
# Merge exact constraint velocities from modal solver (if available).
# "Constraint" is a solver concept (algebraic evolution), not a physics
# statement — constraint fields have physically meaningful velocities
# computed from v_recovery = recovery @ A_reduced.
cv = result.get("constraint_velocities")
if cv is not None:
for c_name, c_vel_arr in cv.items():
velocities[c_name] = np.asarray(c_vel_arr, dtype=np.float64)
return cls(
times=times,
fields=fields,
velocities=velocities,
grid_spacing=grid_info.dx,
grid_bounds=grid_info.bounds,
periodic=grid_info.periodic,
spec=spec,
parameters=parameters or {},
bc_types=grid_info.bc_types,
dt=dt,
)
    @classmethod
    def from_directory(  # noqa: C901, PLR0912, PLR0914, PLR0915
        cls,
        path: Path | str,
        spec: EquationSystem,
    ) -> SimulationData:
        """Load from a snapshot directory with memory-mapped arrays (O(1) RAM).

        The directory must contain ``metadata.json`` and per-field ``.npy``
        files written by :class:`~tidal.measurement.SnapshotWriter`. Arrays
        are opened as read-only memory maps — only the pages actually accessed
        by measurement functions are loaded into RAM.

        Parameters
        ----------
        path : Path or str
            Path to the snapshot directory.
        spec : EquationSystem
            JSON-derived equation specification.

        Returns
        -------
        SimulationData
            Assembled data; field/velocity arrays are read-only memory
            maps truncated to the recorded snapshot count.

        Raises
        ------
        FileNotFoundError
            If *path* does not exist or is not a directory.
        ValueError
            If ``metadata.json`` is missing or corrupt.
        """
        p = Path(path)
        if not p.is_dir():
            msg = f"Snapshot directory not found: {p}"
            raise FileNotFoundError(msg)
        metadata_path = p / "metadata.json"
        metadata: dict[str, object] = {}
        if metadata_path.exists():
            try:
                raw = metadata_path.read_text()
                metadata = json.loads(raw)
            except (json.JSONDecodeError, UnicodeDecodeError):
                # Corrupt metadata (e.g. truncated write) — fall through
                # to the crash-recovery path below.
                metadata = {}
        if not metadata:
            # Crash recovery: infer snapshot count from times.npy size
            times_path = p / "times.npy"
            if not times_path.exists():
                msg = f"Snapshot directory {p} missing both metadata.json and times.npy"
                raise ValueError(msg)
            times_arr = np.load(str(times_path), mmap_mode="r")
            # Find actual count: last non-zero time (or all if first is 0.0)
            n_recovered = _infer_snapshot_count(times_arr)
            # Rebuild the minimal metadata the code below needs, from the spec.
            metadata = {
                "n_snapshots": n_recovered,
                "fields": list(spec.component_names),
                "velocities": [
                    eq.field_name
                    for eq in spec.equations
                    if eq.time_derivative_order >= 2  # noqa: PLR2004
                ],
            }
        n = int(metadata["n_snapshots"])  # type: ignore[arg-type]
        # Load times
        times_npy = p / "times.npy"
        if not times_npy.exists():
            msg = f"Snapshot directory {p} missing times.npy"
            raise ValueError(msg)
        times_full = np.load(str(times_npy), mmap_mode="r")
        if n > len(times_full):
            msg = (
                f"Metadata claims {n} snapshots but times.npy "
                f"has only {len(times_full)} entries"
            )
            raise ValueError(msg)
        # times.npy may be pre-allocated longer than n — keep only what
        # was actually written.
        times = times_full[:n]
        # Load fields (memory-mapped). Missing .npy files are silently
        # skipped — fields that were never written simply don't appear.
        fields: dict[str, NDArray[np.float64]] = {}
        for name in spec.component_names:
            npy = p / f"{name}.npy"
            if npy.exists():
                fields[name] = np.load(str(npy), mmap_mode="r")[:n]
        # Load velocities (memory-mapped)
        velocities: dict[str, NDArray[np.float64]] = {}
        # Accept both "velocities" (new) and "momenta" (legacy) metadata keys
        velocity_names = cast(
            "list[str]",
            metadata.get("velocities", metadata.get("momenta", [])),
        )
        for name in velocity_names:
            npy = p / f"v_{name}.npy"
            if npy.exists():
                velocities[name] = np.load(str(npy), mmap_mode="r")[:n]
        # Load constraint velocities (saved by modal solver).
        # Constraint fields (time_order=0) have physical velocities
        # determined by coupling to dynamical fields. The modal solver
        # computes exact ∂_t via v_recovery = recovery @ A_reduced.
        for eq in spec.equations:
            if eq.time_derivative_order == 0 and eq.field_name not in velocities:
                npy = p / f"v_{eq.field_name}.npy"
                if npy.exists():
                    velocities[eq.field_name] = np.load(str(npy), mmap_mode="r")[:n]
        # Grid metadata — from metadata.json or spec defaults
        grid_spacing: tuple[float, ...]
        if "grid_spacing" in metadata:
            raw_spacing = cast("list[float]", metadata["grid_spacing"])
            grid_spacing = tuple(float(v) for v in raw_spacing)
        else:
            # Legacy data without spacing info: assume unit cells.
            grid_spacing = (1.0,) * spec.spatial_dimension
        grid_bounds: tuple[tuple[float, float], ...]
        if "grid_bounds" in metadata:
            raw_bounds = cast("list[list[float]]", metadata["grid_bounds"])
            grid_bounds = tuple((float(b[0]), float(b[1])) for b in raw_bounds)
        else:
            # Reconstruct bounds as [0, shape * spacing) per axis.
            raw_shape = cast(
                "list[int]",
                metadata.get("grid_shape", [1] * spec.spatial_dimension),
            )
            grid_bounds = tuple(
                (0.0, float(s) * grid_spacing[i]) for i, s in enumerate(raw_shape)
            )
        periodic: tuple[bool, ...]
        if "periodic" in metadata:
            raw_periodic = cast("list[bool]", metadata["periodic"])
            periodic = tuple(bool(v) for v in raw_periodic)
        else:
            periodic = (False,) * spec.spatial_dimension
        # Parameters
        raw_params = cast("dict[str, float]", metadata.get("parameters", {}))
        parameters: dict[str, float] = {str(k): float(v) for k, v in raw_params.items()}
        # BC types (version 2+; None for legacy data)
        bc_types: tuple[str, ...] | None = None
        if "bc_types" in metadata:
            raw_bc = cast("list[str]", metadata["bc_types"])
            bc_types = tuple(str(v) for v in raw_bc)
        # Solver time-step (for conservation diagnostics); None for legacy
        dt_val: float | None = None
        if "dt" in metadata:
            dt_val = float(metadata["dt"])  # type: ignore[arg-type]
        # Restore FD order and spectral flag so measurement operators
        # match the solver's spatial operators.
        # NOTE(review): this mutates module-global operator state as a
        # side effect of loading — callers loading multiple datasets with
        # different settings should be aware.
        if "fd_order" in metadata:
            from tidal.solver.operators import set_fd_order  # noqa: PLC0415
            set_fd_order(int(metadata["fd_order"]))  # type: ignore[arg-type]
        if metadata.get("spectral"):
            from tidal.solver.operators import set_spectral  # noqa: PLC0415
            set_spectral(True)
        return cls(
            times=times,
            fields=fields,
            velocities=velocities,
            grid_spacing=grid_spacing,
            grid_bounds=grid_bounds,
            periodic=periodic,
            spec=spec,
            parameters=parameters,
            bc_types=bc_types,
            dt=dt_val,
        )
[docs]
def save(self, path: Path | str) -> Path:
"""Save to a snapshot directory (metadata.json + .npy files).
This is the inverse of ``load()`` / ``from_directory()``.
Overwrites any existing files in the directory.
Parameters
----------
path : Path or str
Directory to write into (created if it does not exist).
Returns
-------
Path
The directory that was written.
"""
p = Path(path)
p.mkdir(parents=True, exist_ok=True)
np.save(str(p / "times.npy"), np.asarray(self.times))
for name, arr in self.fields.items():
np.save(str(p / f"{name}.npy"), np.asarray(arr))
for name, arr in self.velocities.items():
np.save(str(p / f"v_{name}.npy"), np.asarray(arr))
# Infer grid shape from the first field array
first_field = next(iter(self.fields.values()))
grid_shape = list(first_field.shape[1:]) # strip snapshot dim
metadata: dict[str, object] = {
"version": 1,
"n_snapshots": self.n_snapshots,
"grid_shape": grid_shape,
"grid_spacing": list(self.grid_spacing),
"grid_bounds": [list(b) for b in self.grid_bounds],
"periodic": list(self.periodic),
"parameters": self.parameters,
"fields": list(self.fields.keys()),
"velocities": list(self.velocities.keys()),
"momenta": list(self.velocities.keys()), # backward compat
"dtype": "float64",
}
if self.bc_types is not None:
metadata["bc_types"] = list(self.bc_types)
if self.dt is not None:
metadata["dt"] = self.dt
(p / "metadata.json").write_text(json.dumps(metadata, indent=2) + "\n")
return p
[docs]
@classmethod
def load(
cls,
path: Path | str,
spec: EquationSystem,
) -> SimulationData:
"""Load from a snapshot directory (memory-mapped, O(1) RAM).
Parameters
----------
path : Path or str
Path to snapshot directory.
spec : EquationSystem
JSON-derived equation specification.
Raises
------
ValueError
If *path* is not a directory.
"""
p = Path(path)
if not p.is_dir():
msg = (
f"Expected a snapshot directory, got file '{p}'. "
f"NPZ format is no longer supported — "
f"use 'tidal simulate --output <directory>'"
)
raise ValueError(msg)
return cls.from_directory(p, spec)
def _infer_snapshot_count(times_arr: NDArray[np.float64]) -> int:
"""Infer actual snapshot count from a times array (crash recovery).
When ``metadata.json`` is missing (writer wasn't closed), we look at
the pre-allocated ``times.npy`` and find how many entries were actually
written. Unwritten entries are zero (from memmap pre-allocation).
Strategy: Find the last index where ``times[i] > 0`` or ``i == 0``
(the initial time t=0 is legitimately zero).
"""
n = len(times_arr)
if n == 0:
return 0
# The first entry (t=0) is always valid. After that, unwritten
# entries are 0.0 from memmap pre-allocation. Find the last
# non-zero entry.
for i in range(n - 1, 0, -1):
t = float(times_arr[i])
if t != 0.0 and np.isfinite(t):
return i + 1
# Only t=0 was written
return 1