Source code for tidal.measurement._diagnostics
"""Energy conservation and summary diagnostics."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import numpy as np
from tidal.measurement._energy import (
_ENERGY_FLOOR, # pyright: ignore[reportPrivateUsage]
compute_energy_timeseries,
)
if TYPE_CHECKING:
from numpy.typing import NDArray
from tidal.measurement._io import SimulationData
[docs]
@dataclass(frozen=True)
class EnergyDiagnostics:
    """Immutable record describing how well energy density was conserved.

    Produced by :func:`check_energy_conservation`. All array attributes are
    aligned snapshot-by-snapshot.

    Attributes
    ----------
    times : ndarray, shape ``(n_snapshots,)``
        Time of each snapshot the energy was evaluated at.
    total_energy : ndarray, shape ``(n_snapshots,)``
        Spatial mean of the energy density, ⟨ε⟩, per snapshot.
    relative_error : ndarray, shape ``(n_snapshots,)``
        Drift relative to the first snapshot,
        ``(⟨ε⟩(t) - ⟨ε⟩(0)) / ⟨ε⟩(0)``.
    max_relative_error : float
        Largest absolute relative drift over all snapshots.
    is_conserved : bool
        ``True`` when ``max_relative_error`` stayed below the threshold
        used by the check.

    Notes
    -----
    NOTE(review): the dataclass-generated ``__eq__`` and ``__hash__``
    operate on the ndarray fields. Comparing two distinct instances that
    hold multi-element arrays therefore raises (ambiguous array truth
    value), and hashing raises ``TypeError``. Compare individual fields
    instead.
    """

    times: NDArray[np.float64]
    total_energy: NDArray[np.float64]
    relative_error: NDArray[np.float64]
    max_relative_error: float
    is_conserved: bool
[docs]
def check_energy_conservation(
    data: SimulationData,
    threshold: float = 1e-3,
) -> EnergyDiagnostics:
    """Check whether energy density is conserved over the simulation.

    For symplectic (leapfrog) solvers, the physical Hamiltonian oscillates
    by O(dt²) around the shadow Hamiltonian. When ``data.dt`` is available,
    the threshold is automatically raised to ``max(threshold, 10 * dt²)``
    so that the expected shadow-Hamiltonian offset does not cause false
    FAIL results.

    Drift is measured relative to the initial energy density ``⟨ε⟩(0)``:

    - If ``|⟨ε⟩(0)|`` is below the module energy floor, normalising by it
      is meaningless. The relative error is then reported as all zeros, so
      the run counts as conserved.
    - A negative ``⟨ε⟩(0)`` of sufficient magnitude is normalised like a
      positive one.
    - A NaN ``⟨ε⟩(0)`` propagates NaN into the relative error, so the run
      is reported as not conserved.

    Parameters
    ----------
    data : SimulationData
        Simulation output. Only ``data.dt`` is read directly here; the rest
        is consumed by :func:`compute_energy_timeseries`.
    threshold : float
        Maximum allowed ``|ΔE/E₀|``. Default ``1e-3`` (0.1%).

    Returns
    -------
    EnergyDiagnostics

    Raises
    ------
    ValueError
        If *threshold* is not positive (NaN included).
    """
    # ``not threshold > 0`` also rejects NaN. ``threshold <= 0`` would let
    # NaN through, and every ``max_err < threshold`` below would then be
    # False, producing a silent FAIL.
    if not threshold > 0:
        msg = f"threshold must be positive, got {threshold}"
        raise ValueError(msg)

    # Scale threshold by shadow-Hamiltonian bound when dt is known.
    # Störmer-Verlet conserves H̃ = H + O(dt²); the physical H oscillates
    # by O(dt²) around H̃. The factor 10 accounts for O(1) prefactors
    # (wave speed, multi-field interactions).
    if data.dt is not None:
        shadow_bound = 10.0 * data.dt**2
        threshold = max(threshold, shadow_bound)

    times, _per_field, _interaction, total = compute_energy_timeseries(data)
    e0 = total[0]

    # Compare the *magnitude* of E0 with the floor. If ⟨ε⟩(0) is negative
    # (nothing visible here rules that out), ``e0 >= floor`` would
    # misclassify it as "too small". The result would be all-zero errors:
    # a false PASS regardless of drift. Writing the test as ``abs(e0) < floor``
    # also sends a NaN E0 down the division branch, so NaN propagates and
    # the run is reported as not conserved.
    if abs(e0) < _ENERGY_FLOOR:
        relative_error = np.zeros_like(total)
    else:
        relative_error = (total - e0) / e0

    # np.max propagates NaN, and ``NaN < threshold`` is False, so a blown-up
    # run can never be reported as conserved.
    max_err = float(np.max(np.abs(relative_error)))
    return EnergyDiagnostics(
        times=times,
        total_energy=total,
        relative_error=relative_error,
        max_relative_error=max_err,
        is_conserved=max_err < threshold,
    )
[docs]
def summarize(data: SimulationData) -> dict[str, Any]:
    """Collect energy time series, conservation status and field peaks.

    Parameters
    ----------
    data : SimulationData
        Simulation output to summarise.

    Returns
    -------
    dict[str, Any]
        Mapping with the following keys, in this order:

        - ``times``: ``list[float]`` snapshot times
        - ``per_field_energy``: ``dict[str, list[float]]`` energy time series
          per field
        - ``interaction_energy``: ``list[float]``
        - ``total_energy``: ``list[float]``
        - ``energy_conservation``: :class:`EnergyDiagnostics` produced by
          :func:`check_energy_conservation` with its default threshold
        - ``field_peaks``: ``dict[str, tuple[float, float]]``. For each field,
          the largest absolute value in its initial (index ``0``) and final
          (index ``-1``) snapshot.
    """
    times, per_field, interaction, total = compute_energy_timeseries(data)

    def peak(snapshot: Any) -> float:
        # Largest absolute value anywhere in a single snapshot of a field.
        return float(np.max(np.abs(snapshot)))

    fields = data.fields
    field_peaks: dict[str, tuple[float, float]] = {
        name: (peak(fields[name][0]), peak(fields[name][-1]))
        for name in fields
    }

    # Populate in the same key order as the documented schema.
    summary: dict[str, Any] = {}
    summary["times"] = times.tolist()
    summary["per_field_energy"] = {
        field_name: series.tolist() for field_name, series in per_field.items()
    }
    summary["interaction_energy"] = interaction.tolist()
    summary["total_energy"] = total.tolist()
    summary["energy_conservation"] = check_energy_conservation(data)
    summary["field_peaks"] = field_peaks
    return summary