"""SUNDIALS/CVODE integration for TIDAL — adaptive ODE solver for wave systems.
Pure ODE solver using the CVODE module from SUNDIALS via scikit-sundae.
Supports both BDF (stiff, order 1-5) and Adams (non-stiff, order 1-12)
methods with tolerance-controlled adaptive time stepping.
For wave (second-order) equations, the system is reduced to first-order
ODE form using the Euler-Lagrange (E-L) velocity formulation:
dq/dt = v (trivial kinematic)
dv/dt = E-L RHS (from spatial operators)
Constraint fields (time_order=0) are frozen at initial values — correct
for gauge-fixed systems (e.g. Coulomb gauge A_0 = 0).
Reference: Hindmarsh et al., "SUNDIALS: Suite of Nonlinear and
Differential/Algebraic Equation Solvers", ACM TOMS 31(3), 2005.
scikit-sundae: NREL, BSD-3 license.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np
from tidal.solver._defaults import DEFAULT_ATOL, DEFAULT_RTOL
from tidal.solver._setup import (
build_rhs_evaluator,
configure_linear_solver,
warn_frozen_constraints,
)
from tidal.solver._sksundae import SundialsResult, call_cvode, call_cvode_stepwise
from tidal.solver.fields import FieldSet
from tidal.solver.leapfrog import compute_force, compute_velocity
from tidal.solver.state import StateLayout
if TYPE_CHECKING:
from collections.abc import Callable
from tidal.solver._types import SolverResult
from tidal.solver.grid import GridInfo
from tidal.solver.operators import BCSpec
from tidal.solver.progress import SimulationProgress
from tidal.solver.rhs import RHSEvaluator
from tidal.symbolic.json_loader import EquationSystem
# ---------------------------------------------------------------------------
# RHS builder
# ---------------------------------------------------------------------------
def _build_rhsfn(
    spec: EquationSystem,
    layout: StateLayout,
    grid: GridInfo,
    bc: BCSpec | None,
    rhs_eval: RHSEvaluator,
) -> Callable[[float, np.ndarray, np.ndarray], None]:
    """Build the CVODE RHS closure: ``rhsfn(t, y, yp)``.

    The returned callable writes dy/dt for the flat state ``y`` at time ``t``
    into ``yp`` in place, slot group by slot group:

    * velocity slots receive the force returned by ``compute_force``;
    * dynamical-field slots receive the velocity from ``compute_velocity``;
    * first-order slots receive ``rhs_eval.evaluate(eq_idx, fs, t)`` for the
      field's equation, or ``0.0`` when the field has no entry in
      ``spec.equation_map``;
    * constraint slots receive ``0.0``, so they stay at their initial values.

    The work buffers, the shared ``FieldSet`` and the slot/equation lookup
    tables are created once here and reused by every call, because CVODE
    evaluates the RHS many times per step.
    """
    eq_map = spec.equation_map
    fs = FieldSet.zeros(layout, grid.shape)
    force_buf = np.zeros(layout.total_size)
    vel_buf = np.zeros(layout.total_size)

    # Resolve slot slices and equation indices once instead of on every RHS
    # evaluation (assumes ``layout`` is not mutated after construction).
    velocity_slices = [s for _si, s, _fn in layout.velocity_slot_groups]
    dynamical_slices = [s for _si, s, _vs in layout.dynamical_field_slot_groups]
    first_order_eqs = []  # (slot slice, equation index) pairs
    zero_slices = []  # slots whose time derivative is identically zero
    for _si, s, field_name in layout.first_order_slot_groups:
        eq_idx = eq_map.get(field_name)
        if eq_idx is None:
            zero_slices.append(s)
        else:
            first_order_eqs.append((s, eq_idx))
    # Constraint fields (time_order=0) are frozen: zero derivative.
    zero_slices.extend(s for _si, s, _fn in layout.constraint_slot_groups)

    def rhsfn(t: float, y: np.ndarray, yp: np.ndarray) -> None:
        # ``fs`` is handed to compute_force and then read by the first-order
        # evaluations below. NOTE(review): this assumes compute_force
        # refreshes ``fs`` from ``y`` — confirm.
        force = compute_force(
            spec,
            layout,
            grid,
            bc,
            y,
            t,
            rhs_eval,
            out=force_buf,
            fieldset=fs,
        )
        velocity = compute_velocity(layout, y, out=vel_buf)
        for s in velocity_slices:
            yp[s] = force[s]
        for s in dynamical_slices:
            yp[s] = velocity[s]
        for s, eq_idx in first_order_eqs:
            yp[s] = rhs_eval.evaluate(eq_idx, fs, t).ravel()
        # ``yp`` may be a fresh buffer on every call, so zero slots must be
        # rewritten each time rather than once at build time.
        for s in zero_slices:
            yp[s] = 0.0

    return rhsfn
# ---------------------------------------------------------------------------
# Solver entry point
# ---------------------------------------------------------------------------
[docs]
def solve_cvode(  # noqa: PLR0913
    spec: EquationSystem,
    grid: GridInfo,
    y0: np.ndarray,
    t_span: tuple[float, float],
    *,
    bc: BCSpec | None = None,
    parameters: dict[str, float] | None = None,
    method: str = "BDF",
    rtol: float = DEFAULT_RTOL,
    atol: float = DEFAULT_ATOL,
    max_step: float = 0.0,
    max_num_steps: int = 50000,
    num_snapshots: int = 101,
    snapshot_callback: Callable[[float, np.ndarray], None] | None = None,
    progress: SimulationProgress | None = None,
) -> SolverResult:
    """Solve a TIDAL equation system using SUNDIALS/CVODE.

    Parameters
    ----------
    spec : EquationSystem
        Parsed equation specification.
    grid : GridInfo
        Spatial grid.
    y0 : np.ndarray
        Initial state vector (flat).
    t_span : tuple[float, float]
        (t_start, t_end).
    bc : BCSpec, optional
        Boundary conditions.
    parameters : dict[str, float], optional
        Runtime parameter overrides for symbolic coefficients.
    method : str
        Integration method: ``'BDF'`` (stiff, default) or ``'Adams'`` (non-stiff).
    rtol, atol : float
        Relative and absolute tolerances.
    max_step : float
        Maximum step size. ``0`` (default) = unbounded (SUNDIALS default);
        only positive values are passed to the solver.
    max_num_steps : int
        Maximum solver steps per output interval.
    num_snapshots : int
        Number of output time points, evenly spaced over ``t_span``
        (endpoints included).
    snapshot_callback : callable, optional
        Called as ``callback(t, y)`` at each output time. When ``progress``
        is given it is forwarded to ``call_cvode_stepwise``, which is
        responsible for invoking it. Otherwise it is invoked here after the
        batch integration, and only if that integration succeeded.
    progress : SimulationProgress, optional
        If given, integrate step-by-step via ``call_cvode_stepwise`` so that
        progress can be reported between solver steps; otherwise the whole
        span is integrated in one ``call_cvode`` call.

    Returns
    -------
    SolverResult
        Result mapping with keys: ``t``, ``y``, ``success``, ``message``.

    Warns
    -----
    UserWarning
        If constraint fields (time_order=0) are present — they remain frozen
        at initial values.
    """
    layout = StateLayout.from_spec(spec, grid.num_points)
    warn_frozen_constraints(layout, "CVODE")
    rhs_eval = build_rhs_evaluator(spec, grid, parameters, bc, rtol=rtol)

    # Build RHS closure
    rhsfn = _build_rhsfn(spec, layout, grid, bc, rhs_eval)

    # Configure CVODE solver
    options: dict[str, Any] = {
        "method": method,
        "rtol": rtol,
        "atol": atol,
        "max_num_steps": max_num_steps,
    }
    if max_step > 0:
        options["max_step"] = max_step
    configure_linear_solver(
        options,
        layout,
        spec,
        grid,
        bc,
        parameters=parameters,
        solver="cvode",
    )

    # Build time evaluation points
    t_eval = np.linspace(t_span[0], t_span[1], num_snapshots)

    if progress is not None:
        # Step-by-step mode: progress updates between solver steps. The
        # snapshot callback is handed to the step-wise driver here, so it
        # must NOT be replayed again below (that would call it twice).
        result: SundialsResult = call_cvode_stepwise(
            rhsfn,
            t_eval,
            y0,
            progress,
            snapshot_callback=snapshot_callback,
            **options,
        )
    else:
        result = call_cvode(rhsfn, t_eval, y0, **options)
        # Batch mode returns every output at once; replay them through the
        # callback only when the integration succeeded.
        if snapshot_callback is not None and result.success:
            for t_out, y_out in zip(result.t, result.y):
                snapshot_callback(t_out, y_out)

    return {
        "t": result.t,
        "y": result.y,
        "success": result.success,
        "message": result.message,
    }