"""Coefficient evaluation with multi-level caching for TIDAL solvers.
``CoefficientEvaluator`` resolves operator term coefficients, handling:
- Constant numeric coefficients (no symbolic expression)
- Parameter-overridable symbolic coefficients (e.g. ``"m2"``, ``"-kappa"``)
- Position-dependent coefficients (evaluated on the spatial grid)
- Time-dependent coefficients (re-evaluated each timestep)
Caching levels:
L0: Constants pre-resolved at init (no coordinate/time dependence)
L1: Mathematica → Python string cache (computed once)
L2: Spatial-only arrays (position-dependent, NOT time-dependent)
L3: Per-timestep deduplication (time+position dependent)
Reuses the evaluation pipeline from ``tidal.symbolic._eval_utils`` — no
reimplementation of Mathematica→Python conversion or eval() logic.
"""
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, cast
import numpy as np
from tidal.symbolic._eval_utils import (
build_eval_namespace,
evaluate_coefficient,
)
if TYPE_CHECKING:
from numpy.typing import NDArray
from tidal.solver.grid import GridInfo
from tidal.symbolic.json_loader import EquationSystem, OperatorTerm
# ---------------------------------------------------------------------------
# Module-level constants used by CoefficientEvaluator
# ---------------------------------------------------------------------------

# Sentinel passed as the ``dict.get`` default on the cache fast paths; it can
# never collide with a real cached coefficient (a float or an array).
_MISSING: object = object()

# Periodic-boundary discontinuity thresholds.
#
# ``check_periodic_coefficient_continuity`` computes a leak metric
#     (jump / scale) * (boundary_magnitude / scale)
# as a proxy for the integration-by-parts energy leak
#     dE/dt ~ |beta(L) - beta(0)| * |f(L)|**2.
# For localized problems both the coefficient jump and the boundary
# significance are small, so their product suppresses false positives
# without needing a hard-coded skip.
_LEAK_WARN_THRESHOLD: float = 0.01  # metric above 1%  -> UserWarning
_LEAK_ERROR_THRESHOLD: float = 0.25  # metric above 25% -> ValueError
_BOUNDARY_NEGLIGIBLE: float = 1e-10  # |boundary| below this is noise
class CoefficientEvaluator:
    """Resolve operator term coefficients with multi-level caching.

    Parameters
    ----------
    spec : EquationSystem
        Parsed JSON equation specification.
    grid : GridInfo
        Spatial grid (for position-dependent coefficient evaluation).
    parameters : dict[str, float] | None
        Runtime parameter overrides (e.g. ``{"m2": 1.0}``). The mapping is
        copied at construction. Later mutation by the caller therefore
        cannot make time-dependent evaluations disagree with the constants
        that were pre-resolved from the original values.

    Raises
    ------
    ValueError
        If a symbolic coefficient references an unknown parameter and
        no default numeric value is available.
    TypeError
        If a coefficient classified as constant evaluates to an array.

    Notes
    -----
    Every symbolic coefficient is resolved or test-evaluated during
    construction (fail-fast). Errors from the evaluation pipeline
    therefore surface here rather than mid-simulation.
    """

    def __init__(
        self,
        spec: EquationSystem,
        grid: GridInfo,
        parameters: dict[str, float] | None = None,
    ) -> None:
        self._spec = spec
        self._grid = grid
        # Defensive copy: L0 constants are resolved once from these values,
        # while time-dependent terms re-read them on every evaluation.
        self._parameters: dict[str, float] = (
            dict(parameters) if parameters else {}
        )
        self._coordinates = spec.effective_coordinates
        self._spatial_coords = spec.spatial_coordinates

        # Static eval namespace built once from the parameters (not read by
        # the methods of this class; kept for external consumers — TODO
        # confirm it is still needed).
        self._namespace = build_eval_namespace(self._parameters)

        # Coordinate arrays for position-dependent evaluation: pair each
        # spatial coordinate name with the matching grid array. zip() drops
        # names beyond the grid's dimensionality, like the bounds check it
        # replaces.
        self._coord_arrays: dict[str, NDArray[np.float64]] = dict(
            zip(self._spatial_coords, grid.coord_arrays())
        )

        # L0: pre-resolve constant coefficients.
        self._constants: dict[tuple[int, int], float] = {}
        for eq_idx, eq in enumerate(spec.equations):
            for term_idx, term in enumerate(eq.rhs_terms):
                if self._is_constant(term):
                    self._constants[eq_idx, term_idx] = self._resolve_constant(
                        term
                    )

        # L2: spatial-only cache (position-dependent, NOT time-dependent).
        self._spatial_cache: dict[tuple[int, int], NDArray[np.float64]] = {}
        self._precompute_spatial()

        # L3: per-timestep cache keyed by (eq_idx, term_idx, t).
        self._timestep_cache: dict[
            tuple[int, int, float], float | NDArray[np.float64]
        ] = {}

        # Lets begin_timestep() skip the cache clear when no coefficient
        # depends on time (the common case).
        self._has_time_dependent = not self.all_time_independent()

        # Fail-fast: every remaining symbolic term must evaluate now.
        self._validate_unresolved()

        # Diagnostic: warn about mass-like terms that change sign.
        self._check_mass_sign()

    def resolve(
        self,
        term: OperatorTerm,
        t: float = 0.0,
        *,
        eq_idx: int = -1,
        term_idx: int = -1,
    ) -> float | NDArray[np.float64]:
        """Resolve effective coefficient for an operator term.

        Parameters
        ----------
        term : OperatorTerm
            The operator term whose coefficient to resolve.
        t : float
            Current simulation time.
        eq_idx : int
            Equation index (for cache lookup). -1 disables caching.
        term_idx : int
            Term index within equation (for cache lookup). -1 disables caching.

        Returns
        -------
        float | NDArray[np.float64]
            Scalar for constant coefficients, grid-shaped array for
            position-dependent ones. Cached arrays are returned by
            reference, so callers must treat them as read-only.
        """
        key = (eq_idx, term_idx)

        # L0: pre-resolved constant (single dict.get avoids `in` + [])
        c = self._constants.get(key, _MISSING)
        if c is not _MISSING:
            return cast("float | NDArray[np.float64]", c)

        # No symbolic expression -> numeric coefficient is the answer.
        if term.coefficient_symbolic is None:
            return term.coefficient

        # L2: spatial-only cache (position-dependent, not time-dependent).
        c = self._spatial_cache.get(key, _MISSING)
        if c is not _MISSING:
            return cast("NDArray[np.float64]", c)

        # L3 applies only when BOTH indices are supplied. A partial key such
        # as (eq_idx, -1, t) would be shared by every uncached term of that
        # equation and hand back another term's coefficient.
        cacheable = eq_idx >= 0 and term_idx >= 0
        ts_key = (eq_idx, term_idx, t)
        if cacheable:
            c = self._timestep_cache.get(ts_key, _MISSING)
            if c is not _MISSING:
                return cast("float | NDArray[np.float64]", c)

        # Full evaluation.
        result = self._evaluate_symbolic(term, t)
        if cacheable and term.time_dependent:
            self._timestep_cache[ts_key] = result
        return result

    def begin_timestep(self, t: float) -> None:  # noqa: ARG002
        """Clear per-timestep cache (L3).

        Call at the start of each timestep to ensure time-dependent
        coefficients are re-evaluated. Skipped when all coefficients
        are time-independent (common case — avoids empty dict.clear()).
        """
        if self._has_time_dependent:
            self._timestep_cache.clear()

    def all_constant(self) -> bool:
        """Check if every RHS term has a constant (scalar) coefficient.

        Returns True when all coefficients are pre-resolved at L0 (no
        position or time dependence). This enables the analytical
        Jacobian optimization for constant-coefficient linear systems.
        """
        return all(
            (eq_idx, term_idx) in self._constants
            for eq_idx, eq in enumerate(self._spec.equations)
            for term_idx, _term in enumerate(eq.rhs_terms)
        )

    def all_time_independent(self) -> bool:
        """Check if every RHS term has a time-independent coefficient.

        Returns True when no term has ``time_dependent=True``. Position-
        dependent (but time-independent) coefficients still produce a
        constant Jacobian because the spatial grid is fixed, so the
        analytical Jacobian optimization applies.
        """
        return not any(
            term.time_dependent
            for eq in self._spec.equations
            for term in eq.rhs_terms
        )

    # ---- Internal helpers ----

    def _is_constant(self, term: OperatorTerm) -> bool:
        """Return True if *term*'s coefficient can be resolved once at init.

        A non-symbolic coefficient always qualifies. Otherwise, a term
        flagged time- or position-dependent never qualifies. A bare
        (optionally negated) parameter name does qualify. Any other
        expression qualifies if it evaluates at ``t = 0`` without raising
        ValueError/TypeError/NameError.
        """
        sym = term.coefficient_symbolic
        if sym is None:
            return True
        if term.time_dependent or term.position_dependent:
            return False
        # Simple parameter name or negated parameter.
        if sym in self._parameters:
            return True
        if sym.startswith("-") and sym[1:] in self._parameters:
            return True
        # Compound expression: constant iff it evaluates without coords.
        try:
            self._evaluate_symbolic(term, 0.0)
        except (ValueError, TypeError, NameError):
            return False
        return True

    def _resolve_constant(self, term: OperatorTerm) -> float:
        """Resolve a constant coefficient to a float.

        Raises
        ------
        TypeError
            If a compound expression produces an array instead of scalar.
        """
        sym = term.coefficient_symbolic
        if sym is None:
            return term.coefficient
        # Fast path: direct (optionally negated) parameter lookup.
        if sym in self._parameters:
            return self._parameters[sym]
        if sym.startswith("-") and sym[1:] in self._parameters:
            return -self._parameters[sym[1:]]
        # Compound expression (e.g. "-2*m2").
        result = self._evaluate_symbolic(term, 0.0)
        if isinstance(result, np.ndarray):
            msg = (
                f"Expected scalar for constant coefficient '{sym}', "
                f"got array. Check coordinate_dependent flags."
            )
            raise TypeError(msg)
        return float(result)

    def _precompute_spatial(self) -> None:
        """Pre-compute L2 spatial-only coefficients.

        Symbolic, position-dependent, time-independent terms not already in
        L0 are evaluated on the grid once. A scalar result (despite the
        position_dependent flag) is demoted to an L0 constant instead.
        """
        for eq_idx, eq in enumerate(self._spec.equations):
            for term_idx, term in enumerate(eq.rhs_terms):
                key = (eq_idx, term_idx)
                if key in self._constants or term.coefficient_symbolic is None:
                    continue
                if not term.position_dependent or term.time_dependent:
                    continue
                result = self._evaluate_symbolic(term, 0.0)
                if isinstance(result, np.ndarray):
                    self._spatial_cache[key] = result
                else:
                    self._constants[key] = float(result)

    def _validate_unresolved(self) -> None:
        """Fail fast on symbolic terms not already resolved into L0 or L2.

        Each remaining symbolic term is evaluated once at ``t = 0`` and the
        result discarded. These are the time-dependent terms plus any term
        that could not be classified as constant. Any error raised by the
        evaluator propagates to the caller.
        """
        for eq_idx, eq in enumerate(self._spec.equations):
            for term_idx, term in enumerate(eq.rhs_terms):
                if term.coefficient_symbolic is None:
                    continue
                key = (eq_idx, term_idx)
                if key in self._constants or key in self._spatial_cache:
                    continue
                self._evaluate_symbolic(term, 0.0)

    def _evaluate_symbolic(
        self, term: OperatorTerm, t: float
    ) -> float | NDArray[np.float64]:
        """Evaluate a symbolic coefficient expression at time *t*.

        Coordinate arrays are supplied only for position-dependent terms.
        """
        sym = term.coefficient_symbolic
        if sym is None:
            return term.coefficient
        coord_arrays = self._coord_arrays if term.position_dependent else None
        return evaluate_coefficient(
            sym, self._parameters, self._coordinates, coord_arrays, t
        )

    def _check_mass_sign(self) -> None:
        """Warn if position-dependent mass-like terms change sign on grid.

        Only L2 (spatial-only) coefficients are inspected. A term counts as
        mass-like when it is an ``identity`` operator acting on the field
        its own equation evolves.
        """
        equations = self._spec.equations
        # Iterating the L2 cache directly visits exactly the symbolic,
        # position-dependent, time-independent terms; no id()-based reverse
        # index is needed.
        for (eq_idx, term_idx), arr in self._spatial_cache.items():
            eq = equations[eq_idx]
            term = eq.rhs_terms[term_idx]
            if term.operator != "identity" or term.field != eq.field_name:
                continue
            lo = float(arr.min())
            hi = float(arr.max())
            if lo * hi < 0:
                warnings.warn(
                    f"Position-dependent mass term "
                    f"'{term.coefficient_symbolic}' for field "
                    f"'{eq.field_name}' changes sign across "
                    f"the grid (min={lo:.4g}, "
                    f"max={hi:.4g}). This may "
                    f"cause tachyonic instability at locations "
                    f"where the effective mass² is negative.",
                    UserWarning,
                    # Called from __init__: point at whoever constructed
                    # the evaluator.
                    stacklevel=3,
                )

    def check_periodic_coefficient_continuity(
        self,
        periodic: tuple[bool, ...],
        *,
        rtol: float | None = None,
    ) -> None:
        """Warn if position-dependent coefficients are discontinuous at periodic boundaries.

        When periodic BCs are used, the conservation proof for the PDE system
        requires all coefficient functions to be continuous across the periodic
        boundary (so that integration-by-parts boundary terms vanish).
        Non-periodic coefficients (e.g. ``B(x) = B0/x^3`` on ``[5, 80]``)
        produce O(1) energy non-conservation that is independent of grid
        resolution and solver tolerances.

        The check uses a **leak metric** that estimates the IBP energy leak::

            leak = (jump / scale) * (boundary_magnitude / scale)

        where *jump* is the largest ``|coeff(L) - coeff(0)|`` over the
        boundary slice, *scale* is the peak ``|coeff|`` over the whole array,
        and *boundary_magnitude* is the larger of ``max |coeff(0)|`` and
        ``max |coeff(L)|``. This product naturally suppresses false positives
        for localized coefficients (both factors small at boundaries) while
        preserving detection of genuine discontinuities.

        Parameters
        ----------
        periodic : tuple[bool, ...]
            Per-axis periodicity flags. Axes beyond a coefficient's
            dimensionality are ignored.
        rtol : float, optional
            Solver relative tolerance. When provided, both thresholds are
            multiplied by ``sqrt(rtol / 1e-8)``. ``rtol=1e-8`` reproduces
            the defaults, and smaller tolerances tighten them (``rtol=1e-15``
            gives warn≈3.2e-6, error≈7.9e-5). Larger tolerances relax them
            (``rtol=1e-4`` gives warn=1.0, error=25). When *None*, the
            default thresholds (warn=0.01, error=0.25) are used.

        Warns
        -----
        UserWarning
            If the leak metric exceeds the warn threshold but not the
            error threshold.

        Raises
        ------
        ValueError
            If the leak metric exceeds the error threshold, or if *rtol*
            is negative.
        """
        if not any(periodic):
            return  # no periodic axes -> no check needed

        thresh_warn, thresh_error = self._periodic_leak_thresholds(rtol)

        for (eq_idx, term_idx), arr in self._spatial_cache.items():
            if arr.size == 0:
                continue  # nothing to compare (and .max() would raise)
            # Peak |coeff| does not depend on the axis; the floor avoids 0/0.
            scale = max(float(np.abs(arr).max()), 1e-30)
            for axis, is_periodic in enumerate(periodic):
                if not is_periodic or axis >= arr.ndim:
                    continue
                seam = self._seam_leak(arr, axis, scale)
                if seam is None:
                    continue
                leak_metric, rel_jump, left, right = seam
                # `>` (not `<=` negated) so NaN metrics are skipped.
                if leak_metric > thresh_warn:
                    msg = self._seam_message(
                        eq_idx, term_idx, axis, rel_jump, left, right,
                        leak_metric,
                    )
                    if leak_metric > thresh_error:
                        raise ValueError(msg)
                    warnings.warn(msg, UserWarning, stacklevel=2)

    @staticmethod
    def _periodic_leak_thresholds(rtol: float | None) -> tuple[float, float]:
        """Return the ``(warn, error)`` leak-metric thresholds for *rtol*.

        The defaults are scaled by ``sqrt(rtol / 1e-8)``, normalized to
        ``rtol = 1e-8`` (treated as the CVODE default). The sqrt is a
        heuristic balance: the leak accumulates over many steps, while each
        step's contribution is proportional to the leak metric.
        """
        if rtol is None:
            return _LEAK_WARN_THRESHOLD, _LEAK_ERROR_THRESHOLD
        if rtol < 0:
            msg = f"rtol must be non-negative, got {rtol!r}"
            raise ValueError(msg)
        import math  # noqa: PLC0415

        sf = math.sqrt(rtol / 1e-8)
        return _LEAK_WARN_THRESHOLD * sf, _LEAK_ERROR_THRESHOLD * sf

    @staticmethod
    def _seam_leak(
        arr: NDArray[np.float64], axis: int, scale: float
    ) -> tuple[float, float, float, float] | None:
        """Measure the periodic seam of *arr* along *axis*.

        Returns ``(leak_metric, rel_jump, left, right)``, where *left* and
        *right* are the values on either side of the seam at the point of
        largest jump. Returns ``None`` when the boundary values are below
        ``_BOUNDARY_NEGLIGIBLE`` (machine noise).
        """
        # The first and last slices along this axis meet at the seam.
        first = np.take(arr, 0, axis=axis)
        last = np.take(arr, arr.shape[axis] - 1, axis=axis)
        jumps = np.abs(first - last)
        boundary_magnitude = max(
            float(np.abs(first).max()),
            float(np.abs(last).max()),
        )
        if boundary_magnitude < _BOUNDARY_NEGLIGIBLE:
            return None
        # IBP leak ~ |β(L)-β(0)| * |f(L)|². boundary_magnitude/scale is a
        # proxy for field localisation (localized coefficients ↔ localized
        # fields), giving leak ~ (jump/scale) * (boundary_magnitude/scale).
        rel_jump = float(jumps.max()) / scale
        leak_metric = rel_jump * (boundary_magnitude / scale)
        worst = int(np.argmax(jumps))
        return (
            leak_metric,
            rel_jump,
            float(np.ravel(first)[worst]),
            float(np.ravel(last)[worst]),
        )

    def _seam_message(
        self,
        eq_idx: int,
        term_idx: int,
        axis: int,
        rel_jump: float,
        left: float,
        right: float,
        leak_metric: float,
    ) -> str:
        """Build the diagnostic message for a discontinuous periodic seam."""
        eq = self._spec.equations[eq_idx]
        term = eq.rhs_terms[term_idx]
        axis_name = (
            self._spatial_coords[axis]
            if axis < len(self._spatial_coords)
            else f"axis {axis}"
        )
        return (
            f"Position-dependent coefficient "
            f"'{term.coefficient_symbolic}' in equation for "
            f"'{eq.field_name}' has {rel_jump:.0%} jump at the periodic "
            f"boundary along {axis_name} "
            f"(left={left:.4g}, "
            f"right={right:.4g}, "
            f"leak_metric={leak_metric:.2g}). "
            f"This breaks the integration-by-parts identity "
            f"and causes O(1) energy non-conservation. "
            f"Use a larger domain, non-periodic BCs, or a "
            f"localized coefficient profile."
        )