# Source code for tidal.measurement._writer

"""Streaming snapshot writer for disk-backed simulation storage.

Writes simulation snapshots incrementally to a directory of ``.npy`` files
using memory-mapped I/O, consuming O(1) memory regardless of snapshot count.

The directory layout is framework-agnostic — any solver (py-pde, Julia, etc.)
can write the same format, and :meth:`SimulationData.from_directory` reads it
with lazy memory-mapped access.

Directory structure::

    output_dir/
      metadata.json          # Grid info, parameters, field list, snapshot count
      times.npy              # shape (n_snapshots,), float64
      phi_0.npy              # shape (n_snapshots, *grid_shape), float64
      v_phi_0.npy            # shape (n_snapshots, *grid_shape), float64
      ...
"""

from __future__ import annotations

import contextlib
import json
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self

import numpy as np

if TYPE_CHECKING:
    from collections.abc import Callable

    from numpy.typing import NDArray

    from tidal.symbolic.json_loader import EquationSystem

# Current format version.  Increment when the metadata schema changes.
# Written into ``metadata.json`` by ``SnapshotWriter.close()``.
_FORMAT_VERSION: int = 1


def _environment_fingerprint() -> dict[str, str]:
    """Collect software versions for reproducibility tracking."""
    import sys as _sys  # noqa: PLC0415
    from importlib.metadata import PackageNotFoundError  # noqa: PLC0415
    from importlib.metadata import version as _pkg_version  # noqa: PLC0415

    env: dict[str, str] = {"python": _sys.version.split()[0]}
    for pkg in ("tidal", "numpy", "scipy", "scikit-sundae", "matplotlib"):
        with contextlib.suppress(PackageNotFoundError):
            env[pkg] = _pkg_version(pkg)
    return env


# Default dtype for every on-disk array (times, fields, velocities);
# ``metadata.json`` records the matching "float64" string.
_DEFAULT_DTYPE: np.dtype[np.float64] = np.dtype(np.float64)

# Flush mmaps every N snapshots for crash resilience.
# Higher values improve I/O performance; lower values reduce data loss on crash.
_DEFAULT_FLUSH_INTERVAL: int = 10


def _open_memmap(
    path: Path,
    shape: tuple[int, ...],
    dtype: np.dtype[np.float64] = _DEFAULT_DTYPE,
) -> np.memmap[tuple[int, ...], np.dtype[np.float64]]:
    """Create a writable memory-mapped ``.npy`` file with the given shape."""
    # Unlike a raw np.memmap, open_memmap writes a proper .npy header, so the
    # file can later be loaded with np.load(..., mmap_mode="r").
    mapped = np.lib.format.open_memmap(
        str(path),
        mode="w+",
        shape=shape,
        dtype=dtype,
    )
    return mapped  # type: ignore[no-any-return]


class SnapshotWriter:
    """Stream simulation snapshots to disk with O(1) memory.

    Pre-allocates one ``.npy`` file per field/velocity (plus ``times.npy``)
    using ``numpy.memmap``, then writes each snapshot in-place.  The exact
    number of snapshots must be known at construction time — compute it as
    ``int(t_end / snapshot_interval) + 1``.

    Use as a context manager for automatic ``close()``::

        with SnapshotWriter(output_dir, ...) as writer:
            for t, fields, velocities in simulation:
                writer.append(t, fields, velocities)

    Parameters
    ----------
    output_dir : Path
        Directory to write into.  Created if missing; stale ``.npy`` and
        ``metadata.json`` files from a previous run are removed first.
    field_names : list[str]
        Names of field arrays to store (e.g. ``["phi_0", "chi_0"]``).
    velocity_names : list[str]
        Names of velocity arrays to store (e.g. ``["phi_0", "chi_0"]``).
        Stored as ``v_{name}.npy``.
    grid_shape : tuple[int, ...]
        Spatial grid shape (e.g. ``(96, 96)``).
    n_snapshots : int
        Exact number of snapshots to write.
    grid_spacing : tuple[float, ...]
        Cell size per spatial axis.
    grid_bounds : tuple[tuple[float, float], ...]
        Domain bounds per spatial axis.
    periodic : tuple[bool, ...]
        Whether each spatial axis is periodic.
    parameters : dict[str, float] or None
        Resolved parameter values.
    spec_path : Path or None
        Path to the JSON spec file (for auto-discovery by ``tidal measure``).
    flush_interval : int
        Flush mmaps to disk every this-many snapshots (clamped to >= 1).
    bc_types : tuple[str, ...] or None
        Boundary-condition type per axis; stored in metadata when given.
    dt : float or None
        Solver time step; stored in metadata when given.
    """

    def __init__(  # noqa: PLR0913, PLR0917
        self,
        output_dir: Path,
        field_names: list[str],
        velocity_names: list[str],
        grid_shape: tuple[int, ...],
        n_snapshots: int,
        grid_spacing: tuple[float, ...],
        grid_bounds: tuple[tuple[float, float], ...],
        periodic: tuple[bool, ...],
        parameters: dict[str, float] | None = None,
        spec_path: Path | None = None,
        flush_interval: int = _DEFAULT_FLUSH_INTERVAL,
        bc_types: tuple[str, ...] | None = None,
        dt: float | None = None,
    ) -> None:
        if n_snapshots < 1:
            msg = f"n_snapshots must be >= 1, got {n_snapshots}"
            raise ValueError(msg)
        if not field_names:
            msg = "field_names must be non-empty"
            raise ValueError(msg)

        self._output_dir = Path(output_dir)
        if self._output_dir.exists() and not self._output_dir.is_dir():
            msg = f"Output path exists but is not a directory: {self._output_dir}"
            raise ValueError(msg)
        self._output_dir.mkdir(parents=True, exist_ok=True)

        # Remove stale .npy / metadata files from a previous run so a rerun
        # never mixes old and new snapshot files.
        for stale_npy in self._output_dir.glob("*.npy"):
            stale_npy.unlink()
        stale_meta = self._output_dir / "metadata.json"
        if stale_meta.exists():
            stale_meta.unlink()

        self._n_snapshots = n_snapshots
        self._grid_shape = grid_shape
        self._field_names = list(field_names)
        self._velocity_names = list(velocity_names)
        self._grid_spacing = grid_spacing
        self._grid_bounds = grid_bounds
        self._periodic = periodic
        self._parameters = parameters or {}
        self._spec_path = spec_path
        self._bc_types = bc_types
        self._dt = dt
        self._count = 0
        self._closed = False
        self._flush_interval = max(flush_interval, 1)

        # Pre-allocate times
        self._times_mmap = _open_memmap(
            self._output_dir / "times.npy",
            shape=(n_snapshots,),
        )

        # Pre-allocate per-field .npy files
        snapshot_shape = (n_snapshots, *grid_shape)
        self._field_mmaps: dict[
            str, np.memmap[tuple[int, ...], np.dtype[np.float64]]
        ] = {}
        for name in field_names:
            self._field_mmaps[name] = _open_memmap(
                self._output_dir / f"{name}.npy",
                shape=snapshot_shape,
            )

        # Pre-allocate per-velocity .npy files
        self._velocity_mmaps: dict[
            str, np.memmap[tuple[int, ...], np.dtype[np.float64]]
        ] = {}
        for name in velocity_names:
            self._velocity_mmaps[name] = _open_memmap(
                self._output_dir / f"v_{name}.npy",
                shape=snapshot_shape,
            )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def append(  # noqa: C901
        self,
        t: float,
        fields: dict[str, NDArray[np.float64]],
        velocities: dict[str, NDArray[np.float64]],
    ) -> None:
        """Write one snapshot at the next time index.

        Parameters
        ----------
        t : float
            Simulation time for this snapshot.
        fields : dict[str, ndarray]
            Mapping ``field_name → spatial_array`` for this snapshot.
        velocities : dict[str, ndarray]
            Mapping ``field_name → spatial_array`` for velocities.

        Raises
        ------
        ValueError
            If the writer is closed, the snapshot count is exceeded, the
            time is non-finite or non-monotonic, or a field/velocity array
            has the wrong shape.
        """
        if self._closed:
            msg = "SnapshotWriter is already closed"
            raise ValueError(msg)
        if self._count >= self._n_snapshots:
            msg = (
                f"Cannot write snapshot {self._count}: "
                f"pre-allocated for {self._n_snapshots} snapshots"
            )
            raise ValueError(msg)

        # Validate time: finite, and non-decreasing w.r.t. the previous one.
        if not np.isfinite(t):
            msg = f"Time must be finite, got {t}"
            raise ValueError(msg)
        if self._count > 0:
            prev_t = float(self._times_mmap[self._count - 1])
            if t < prev_t:
                msg = (
                    f"Times must be non-decreasing: "
                    f"snapshot {self._count - 1} at t={prev_t}, "
                    f"snapshot {self._count} at t={t}"
                )
                raise ValueError(msg)

        idx = self._count
        self._times_mmap[idx] = t

        for name in self._field_names:
            if name not in fields:
                msg = f"Missing field '{name}' in snapshot {idx}"
                raise ValueError(msg)
            arr = fields[name]
            if arr.shape != self._grid_shape:
                msg = (
                    f"Field '{name}' has shape {arr.shape}, expected {self._grid_shape}"
                )
                raise ValueError(msg)
            # Assigning into the memmap writes the data in place on disk.
            self._field_mmaps[name][idx] = arr

        for name in self._velocity_names:
            if name not in velocities:
                msg = f"Missing velocity '{name}' in snapshot {idx}"
                raise ValueError(msg)
            arr = velocities[name]
            if arr.shape != self._grid_shape:
                msg = (
                    f"Velocity '{name}' has shape {arr.shape}, "
                    f"expected {self._grid_shape}"
                )
                raise ValueError(msg)
            self._velocity_mmaps[name][idx] = arr

        self._count += 1

        # Flush periodically for crash resilience (every _flush_interval snapshots)
        if self._count % self._flush_interval == 0:
            self._flush_mmaps()

    @property
    def count(self) -> int:
        """Number of snapshots written so far."""
        return self._count

    @property
    def n_snapshots(self) -> int:
        """Total number of snapshots pre-allocated."""
        return self._n_snapshots

    @property
    def output_dir(self) -> Path:
        """Directory where snapshot files are written."""
        return self._output_dir

    def close(self) -> None:
        """Flush all mmaps and write ``metadata.json``.

        It is safe to call ``close()`` multiple times.
        """
        if self._closed:
            return
        if self._count != self._n_snapshots:
            warnings.warn(
                f"SnapshotWriter closed with {self._count}/{self._n_snapshots} "
                f"snapshots written",
                stacklevel=2,
            )

        # Flush and release mmaps.  Clearing the dicts drops the last
        # references to the per-field/velocity mmaps (the previous per-item
        # ``del`` inside a loop only unbound the loop variable and released
        # nothing).
        self._flush_mmaps()
        del self._times_mmap
        self._field_mmaps.clear()
        self._velocity_mmaps.clear()

        # Write metadata.  ``n_snapshots`` records how many snapshots were
        # actually written, which may be fewer than pre-allocated.
        metadata: dict[str, Any] = {
            "version": _FORMAT_VERSION,
            "n_snapshots": self._count,
            "grid_spacing": list(self._grid_spacing),
            "grid_bounds": [list(b) for b in self._grid_bounds],
            "grid_shape": list(self._grid_shape),
            "periodic": list(self._periodic),
            "parameters": self._parameters,
            "fields": self._field_names,
            "velocities": self._velocity_names,
            "momenta": self._velocity_names,  # backward compat
            "dtype": "float64",
        }
        if self._bc_types is not None:
            metadata["bc_types"] = list(self._bc_types)
        if self._dt is not None:
            metadata["dt"] = self._dt
        if self._spec_path is not None:
            metadata["spec_path"] = str(self._spec_path)

        # Save FD order and spectral flag so that measurement tools can
        # restore the correct operators for energy computation (must match
        # the solver's spatial operators).
        from tidal.solver.operators import get_fd_order, get_spectral  # noqa: PLC0415

        fd_order = get_fd_order()
        if fd_order != 2:  # noqa: PLR2004
            metadata["fd_order"] = fd_order
        if get_spectral():
            metadata["spectral"] = True

        # Environment fingerprint for reproducibility
        metadata["environment"] = _environment_fingerprint()

        metadata_path = self._output_dir / "metadata.json"
        # Atomic write: temp file + rename to avoid corrupt JSON on crash
        temp_path = self._output_dir / "metadata.json.tmp"
        temp_path.write_text(json.dumps(metadata, indent=2) + "\n")
        temp_path.replace(metadata_path)

        self._closed = True

    # ------------------------------------------------------------------
    # Context manager
    # ------------------------------------------------------------------

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc: object) -> None:
        self.close()

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _flush_mmaps(self) -> None:
        """Flush all memory-mapped files to disk."""
        self._times_mmap.flush()
        for mmap in self._field_mmaps.values():
            mmap.flush()
        for mmap in self._velocity_mmaps.values():
            mmap.flush()
def compute_snapshot_count(t_end: float, snapshot_interval: float) -> int:
    """Compute the exact number of snapshots for a simulation.

    Parameters
    ----------
    t_end : float
        Total simulation time.
    snapshot_interval : float
        Time between snapshots.

    Returns
    -------
    int
        Exact snapshot count (includes the initial state at t=0).

    Raises
    ------
    ValueError
        If *t_end* or *snapshot_interval* are non-positive.
    """
    if t_end <= 0:
        msg = f"t_end must be positive, got {t_end}"
        raise ValueError(msg)
    if snapshot_interval <= 0:
        msg = f"snapshot_interval must be positive, got {snapshot_interval}"
        raise ValueError(msg)
    ratio = t_end / snapshot_interval
    # Guard against floating-point truncation: e.g. 0.3 / 0.1 evaluates to
    # 2.999...996, which plain ``int(ratio) + 1`` would turn into 3 instead
    # of the correct 4 — under-allocating the writer by one snapshot.  Snap
    # to the nearest integer when the ratio is within a tiny relative
    # tolerance of it; otherwise floor as before.
    nearest = round(ratio)
    if abs(ratio - nearest) <= 1e-9 * max(abs(ratio), 1.0):
        return int(nearest) + 1
    return int(ratio) + 1
def _field_names_from_spec( spec: EquationSystem, *, exclude_constraints: bool = False, ) -> tuple[list[str], list[str]]: """Extract field and velocity names from an equation spec. Parameters ---------- spec : EquationSystem The equation specification. exclude_constraints : bool If True, omit fields with ``time_derivative_order == 0``. Returns ------- field_names : list[str] Field names to store. velocity_names : list[str] Velocity names to store (fields with ``time_derivative_order >= 2``). """ field_names: list[str] = [] velocity_names: list[str] = [] for eq in spec.equations: if exclude_constraints and eq.time_derivative_order == 0: continue field_names.append(eq.field_name) if eq.time_derivative_order >= 2: # noqa: PLR2004 velocity_names.append(eq.field_name) return field_names, velocity_names
def create_snapshot_callback(  # noqa: PLR0913, PLR0917
    output_dir: Path | str,
    spec: EquationSystem,
    grid: Any,  # noqa: ANN401
    t_end: float,
    snapshot_interval: float,
    parameters: dict[str, float] | None = None,
    spec_path: Path | None = None,
) -> tuple[SnapshotWriter, Callable[[Any, float], None]]:
    """Create a SnapshotWriter + callback for streaming snapshots to disk.

    The returned *callback* accepts ``(state, time)`` where *state* is a
    ``pde.FieldCollection``.  Wrap it in ``CallbackTracker(callback, ...)``
    and pass to ``pde.solve()``.  Call ``writer.close()`` after the solve.

    Parameters
    ----------
    output_dir : Path or str
        Directory to write snapshot files into.
    spec : EquationSystem
        Equation system (provides field/velocity names and state layout).
    grid : CartesianGrid
        py-pde grid (provides shape, spacing, bounds, periodicity).
    t_end : float
        Total simulation time.
    snapshot_interval : float
        Time between snapshots.
    parameters : dict or None
        Resolved parameter values to store in metadata.
    spec_path : Path or None
        Path to the JSON spec file (for auto-discovery by ``tidal measure``).

    Returns
    -------
    writer : SnapshotWriter
        Call ``writer.close()`` after the solve finishes.
    callback : callable
        Pass to ``CallbackTracker(callback, interrupts=snapshot_interval)``.
    """
    target_dir = Path(output_dir)
    field_names, velocity_names = _field_names_from_spec(spec)

    # Derive geometry from the grid: spacing = extent / cell count per axis.
    raw_bounds = grid.bounds
    cell_sizes = []
    for bound, n_cells in zip(raw_bounds, grid.shape, strict=True):
        cell_sizes.append(float((bound[1] - bound[0]) / n_cells))

    writer = SnapshotWriter(
        output_dir=target_dir,
        field_names=field_names,
        velocity_names=velocity_names,
        grid_shape=tuple(grid.shape),
        n_snapshots=compute_snapshot_count(t_end, snapshot_interval),
        grid_spacing=tuple(cell_sizes),
        grid_bounds=tuple((float(b[0]), float(b[1])) for b in raw_bounds),
        periodic=tuple(bool(flag) for flag in grid.periodic),
        parameters=parameters,
        spec_path=Path(spec_path) if spec_path else None,
    )

    # Map each stored name to its slot index in the flat state layout,
    # keeping only names the writer actually records.
    wanted_fields = set(field_names)
    wanted_velocities = set(velocity_names)
    field_slots: dict[str, int] = {}
    velocity_slots: dict[str, int] = {}
    for slot, (name, slot_type) in enumerate(spec.state_layout):
        if slot_type == "field":
            if name in wanted_fields:
                field_slots[name] = slot
        elif slot_type == "velocity" and name in wanted_velocities:
            velocity_slots[name] = slot

    def _on_snapshot(state_view: Any, time: float) -> None:  # noqa: ANN401
        snapshot_fields: dict[str, NDArray[np.float64]] = {}
        for name, slot in field_slots.items():
            snapshot_fields[name] = np.asarray(
                state_view[slot].data, dtype=np.float64
            )
        snapshot_vels: dict[str, NDArray[np.float64]] = {}
        for name, slot in velocity_slots.items():
            snapshot_vels[name] = np.asarray(
                state_view[slot].data, dtype=np.float64
            )
        writer.append(time, snapshot_fields, snapshot_vels)

    return writer, _on_snapshot