"""scipy.integrate.solve_ivp wrapper for TIDAL — explicit adaptive ODE solver.
Provides DOP853 (8th-order Dormand-Prince) as default method, which excels
for smooth non-stiff wave problems with fewer function evaluations than
implicit BDF for the same accuracy.
Also supports implicit methods (Radau, BDF) with Jacobian sparsity from
the existing TIDAL sparsity builder.
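References: Dormand & Prince, "A family of embedded Runge-Kutta formulae",
J. Comp. Appl. Math. 6, 1980 (the Dormand-Prince family; basis of RK45);
Prince & Dormand, "High order embedded Runge-Kutta formulae",
J. Comp. Appl. Math. 7, 1981 (the 8th-order pair underlying DOP853);
Hairer, Norsett & Wanner, "Solving Ordinary Differential Equations I:
Nonstiff Problems", Sec. II (the DOP853 code with dense output).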
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from tidal.solver._defaults import DEFAULT_ATOL, DEFAULT_RTOL
from tidal.solver._scipy_types import IVPResult, call_solve_ivp
from tidal.solver._setup import build_rhs_evaluator, warn_frozen_constraints
from tidal.solver.fields import FieldSet
from tidal.solver.leapfrog import compute_force
from tidal.solver.state import StateLayout
if TYPE_CHECKING:
from collections.abc import Callable
from tidal.solver._types import SolverResult
from tidal.solver.grid import GridInfo
from tidal.solver.operators import BCSpec
from tidal.solver.progress import SimulationProgress
from tidal.solver.rhs import RHSEvaluator
from tidal.symbolic.json_loader import EquationSystem
# ---------------------------------------------------------------------------
# RHS builder
# ---------------------------------------------------------------------------
# Implicit methods that benefit from Jacobian sparsity
_IMPLICIT_METHODS = {"Radau", "BDF"}
def _build_rhs_fn(  # noqa: PLR0913, PLR0917
    spec: EquationSystem,
    layout: StateLayout,
    grid: GridInfo,
    bc: BCSpec | None,
    rhs_eval: RHSEvaluator,
    progress: SimulationProgress | None = None,
) -> Callable[[float, np.ndarray], np.ndarray]:
    """Build the scipy RHS closure: ``rhs_fn(t, y) -> dydt``.

    The closure assembles ``dy/dt`` for the flat state vector ``y``:

    * velocity slots (``layout.velocity_slot_groups``) receive the force
      computed by :func:`compute_force`;
    * field slots paired with a velocity slot (``layout.drift_slot_pairs``)
      receive that velocity, read directly from ``y``;
    * first-order fields that have an entry in ``spec.equation_map``
      receive ``rhs_eval.evaluate(...)`` flattened;
    * every remaining slot is left at zero, so it stays at its initial value.

    Parameters
    ----------
    spec : EquationSystem
        Parsed equation specification (provides ``equation_map``).
    layout : StateLayout
        Slot layout of the flat state vector.
    grid : GridInfo
        Spatial grid.
    bc : BCSpec or None
        Boundary conditions forwarded to :func:`compute_force`.
    rhs_eval : RHSEvaluator
        Evaluator used for forces and first-order equations.
    progress : SimulationProgress, optional
        If given, ``progress.update(t)`` is called on every RHS evaluation.

    Returns
    -------
    callable
        ``rhs_fn(t, y)`` returning a newly allocated array of length
        ``layout.total_size``.
    """
    eq_map = spec.equation_map
    total_size = layout.total_size
    # Scratch buffers reused across calls. They never escape the closure,
    # so sharing them between evaluations is safe.
    fs = FieldSet.zeros(layout, grid.shape)
    force_buf = np.zeros(total_size)
    drift_pairs = layout.drift_slot_pairs

    def rhs_fn(t: float, y: np.ndarray) -> np.ndarray:
        if progress is not None:
            progress.update(t)

        # Allocate a fresh output array on every call. scipy's solvers keep a
        # reference to the returned derivative (``OdeSolver.f``) and read it
        # again after further ``fun`` calls (rejected-step retries, DOP853
        # dense-output extra stages). A shared buffer would be silently
        # overwritten by those later evaluations.
        dydt = np.zeros(total_size)

        # NOTE(review): assumes compute_force fills ``fs`` from ``y``; the
        # first-order evaluation below reads ``fs`` — confirm.
        compute_force(
            spec,
            layout,
            grid,
            bc,
            y,
            t,
            rhs_eval,
            out=force_buf,
            fieldset=fs,
        )
        for _si, s, _fn in layout.velocity_slot_groups:
            dydt[s] = force_buf[s]

        # d(field)/dt = velocity, taken straight from the state vector.
        for field_slice, vel_slice in drift_pairs:
            dydt[field_slice] = y[vel_slice]

        for _si, s, field_name in layout.first_order_slot_groups:
            eq_idx = eq_map.get(field_name)
            if eq_idx is not None:
                dydt[s] = rhs_eval.evaluate(eq_idx, fs, t).ravel()

        return dydt

    return rhs_fn
# ---------------------------------------------------------------------------
# Solver entry point
# ---------------------------------------------------------------------------
def solve_scipy(  # noqa: PLR0913
    spec: EquationSystem,
    grid: GridInfo,
    y0: np.ndarray,
    t_span: tuple[float, float],
    *,
    bc: BCSpec | None = None,
    parameters: dict[str, float] | None = None,
    method: str = "DOP853",
    rtol: float = DEFAULT_RTOL,
    atol: float = DEFAULT_ATOL,
    max_step: float = np.inf,
    num_snapshots: int = 101,
    snapshot_callback: Callable[[float, np.ndarray], None] | None = None,
    progress: SimulationProgress | None = None,
) -> SolverResult:
    """Solve a TIDAL equation system using scipy.integrate.solve_ivp.

    Parameters
    ----------
    spec : EquationSystem
        Parsed equation specification.
    grid : GridInfo
        Spatial grid.
    y0 : np.ndarray
        Initial state vector (flat).
    t_span : tuple[float, float]
        (t_start, t_end).
    bc : BCSpec, optional
        Boundary conditions.
    parameters : dict[str, float], optional
        Runtime parameter overrides for symbolic coefficients.
    method : str
        Integration method: ``'DOP853'`` (default), ``'RK45'``, ``'Radau'``,
        ``'BDF'``, ``'RK23'``, ``'LSODA'``. For ``'Radau'`` and ``'BDF'`` a
        Jacobian sparsity pattern is built and passed to the solver.
    rtol, atol : float
        Relative and absolute tolerances. ``rtol`` is also forwarded to
        ``build_rhs_evaluator``.
    max_step : float
        Maximum step size. ``np.inf`` (default) = unbounded.
    num_snapshots : int
        Number of output time points, evenly spaced over ``t_span``
        including both endpoints.
    snapshot_callback : callable, optional
        Called as ``callback(t, y)`` at each output time, only after a
        successful integration.
    progress : SimulationProgress, optional
        Updated on every RHS evaluation. ``finish()`` is called once the
        solver returns or raises.

    Returns
    -------
    SolverResult
        Result dictionary with keys: ``t``, ``y``, ``success``, ``message``.

    Warns
    -----
    UserWarning
        If constraint fields (time_order=0) are present — they remain frozen
        at initial values.
    """
    layout = StateLayout.from_spec(spec, grid.num_points)
    warn_frozen_constraints(layout, "scipy")
    rhs_eval = build_rhs_evaluator(spec, grid, parameters, bc, rtol=rtol)

    # Build RHS closure (with optional progress tracking)
    rhs_fn = _build_rhs_fn(spec, layout, grid, bc, rhs_eval, progress=progress)

    # Output times at which solve_ivp reports the solution.
    t_eval = np.linspace(t_span[0], t_span[1], num_snapshots)

    # For implicit methods, provide Jacobian sparsity.
    # NOTE(review): for explicit methods jac_sparsity=None is still forwarded;
    # scipy's explicit solvers warn about unused options unless
    # call_solve_ivp drops None values — confirm.
    jac_sparsity = None
    if method in _IMPLICIT_METHODS:
        from tidal.solver.sparsity import build_jacobian_sparsity  # noqa: PLC0415

        jac_sparsity = build_jacobian_sparsity(spec, layout, grid, bc)

    # Always close out the progress display, even if the solver raises.
    try:
        result: IVPResult = call_solve_ivp(
            rhs_fn,
            t_span,
            y0,
            method=method,
            t_eval=t_eval,
            rtol=rtol,
            atol=atol,
            max_step=max_step,
            jac_sparsity=jac_sparsity,
        )
    finally:
        if progress is not None:
            progress.finish()

    # Replay output times through the callback (successful runs only).
    # NOTE(review): assumes result.y is time-major (row i is the state at
    # result.t[i]); raw scipy OdeResult.y is (n_states, n_times) — confirm
    # call_solve_ivp's layout.
    if snapshot_callback is not None and result.success:
        for i, t_i in enumerate(result.t):
            snapshot_callback(t_i, result.y[i])

    return {
        "t": result.t,
        "y": result.y,
        "success": result.success,
        "message": result.message,
    }