Source code for tidal.solver.rhs

"""Unified RHS evaluation for TIDAL solvers.

``RHSEvaluator`` applies spatial operators with resolved coefficients,
replacing the duplicated inner loops in ``ida.py`` and ``leapfrog.py``.

Depends on ``FieldSet`` (Phase 1) and ``CoefficientEvaluator`` (Phase 2).
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

from tidal.solver.operators import (
    OPERATOR_REGISTRY,
    AxisBCSpec,
    BCSpec,
    _bc_from_grid,  # pyright: ignore[reportPrivateUsage]
    _normalize_bc,  # pyright: ignore[reportPrivateUsage]
    _resolve_axis_bc,  # pyright: ignore[reportPrivateUsage]
)

if TYPE_CHECKING:
    from tidal.solver.coefficients import CoefficientEvaluator
    from tidal.solver.fields import FieldSet
    from tidal.solver.grid import GridInfo
    from tidal.symbolic.json_loader import EquationSystem


# ---------------------------------------------------------------------------
# RHSEvaluator
# ---------------------------------------------------------------------------


class RHSEvaluator:
    """Evaluate the right-hand side of each field equation.

    For every RHS term, the evaluator produces the term's data:

    - For ``first_derivative_t`` terms, it reads the field's velocity slot
      ``v_<field>``.
    - For every other term, it applies the term's spatial operator to the
      referenced field.

    That data is scaled by the coefficient returned from the coefficient
    evaluator (constant, parameter-overridden, position-dependent, or
    time-dependent). The scaled terms are summed into a reusable output
    buffer.

    Operator functions, velocity slot names and per-axis boundary conditions
    are all resolved once at construction, so the per-call path only does
    array work.

    Parameters
    ----------
    spec : EquationSystem
        Parsed JSON equation specification.
    grid : GridInfo
        Spatial grid; its ``shape`` sizes the internal output buffers.
    coeff_eval : CoefficientEvaluator
        Coefficient resolver with caching.
    bc : str or tuple of str, optional
        Boundary conditions for spatial operators. When omitted, they are
        derived from *grid* via ``_bc_from_grid``.

    Raises
    ------
    ValueError
        If any RHS term names an operator that is not in
        ``OPERATOR_REGISTRY``.
    """

    def __init__(
        self,
        spec: EquationSystem,
        grid: GridInfo,
        coeff_eval: CoefficientEvaluator,
        bc: BCSpec | None = None,
    ) -> None:
        self._spec = spec
        self._grid = grid
        self._coeff_eval = coeff_eval
        self._bc = bc
        self._eq_map: dict[str, int] = spec.equation_map

        # Output and scratch arrays shared by every evaluate() call, so the
        # hot path never allocates grid-sized arrays.
        self._result_buffer = np.zeros(grid.shape)
        self._term_buffer = np.zeros(grid.shape)

        # Normalize the BC spec a single time. An explicit ``bc`` takes
        # precedence over the BCs derived from the grid.
        self._normalized_bc: tuple[str | AxisBCSpec, ...] = (
            _bc_from_grid(grid) if bc is None else _normalize_bc(bc, grid)
        )
        # Convert each per-axis entry to an AxisBCSpec up front. Operators
        # then receive ready-made objects instead of re-resolving each entry
        # on every call.
        self._resolved_bcs: tuple[AxisBCSpec, ...] = tuple(
            map(_resolve_axis_bc, self._normalized_bc)
        )

        # One (operator_fn, field_name, velocity_slot) triple per RHS term,
        # grouped per equation in the same order as ``spec.equations``.
        self._resolved_ops: list[list[tuple[Any, str, str]]] = [
            [self._resolve_term(term) for term in equation.rhs_terms]
            for equation in spec.equations
        ]

    def begin_timestep(self, t: float) -> None:
        """Forward the start of a new timestep at time *t* to the coefficient evaluator."""
        self._coeff_eval.begin_timestep(t)

    def evaluate(
        self,
        eq_idx: int,
        fields: FieldSet,
        t: float = 0.0,
    ) -> np.ndarray:
        """Return the summed RHS of equation *eq_idx*.

        Parameters
        ----------
        eq_idx : int
            Position of the equation within ``spec.equations``.
        fields : FieldSet
            Field state the operators read from.
        t : float
            Simulation time, passed to the coefficient resolver.

        Returns
        -------
        np.ndarray
            Grid-shaped sum of ``coefficient * operated`` over all terms. An
            equation without RHS terms yields zeros.

            **Warning:** the returned array is an internal buffer and may be
            overwritten by the next call to ``evaluate()``. Callers must copy
            if they need to persist it.
        """
        rhs_terms = self._spec.equations[eq_idx].rhs_terms
        out = self._result_buffer
        scratch = self._term_buffer

        if not rhs_terms:
            out.fill(0.0)
            return out

        ops_for_eq = self._resolved_ops[eq_idx]
        for term_idx, term in enumerate(rhs_terms):
            applied = self._apply_resolved(ops_for_eq[term_idx], fields)
            coeff = self._coeff_eval.resolve(
                term, t, eq_idx=eq_idx, term_idx=term_idx
            )
            if term_idx == 0:
                # The leading term overwrites the output buffer, so no
                # explicit zero-fill is needed before accumulating.
                np.multiply(coeff, applied, out=out)
            else:
                np.multiply(coeff, applied, out=scratch)
                out += scratch
        return out

    def evaluate_by_field(
        self,
        field_name: str,
        fields: FieldSet,
        t: float = 0.0,
    ) -> np.ndarray:
        """Return the RHS of the equation that governs *field_name*.

        The result is the same shared buffer that ``evaluate()`` returns, so
        it is overwritten by the next evaluation.

        Raises
        ------
        KeyError
            If no equation is registered for *field_name*.
        """
        eq_idx = self._eq_map.get(field_name)
        if eq_idx is not None:
            return self.evaluate(eq_idx, fields, t)
        raise KeyError(f"No equation for field '{field_name}'")

    # ---- Internal helpers ----

    @staticmethod
    def _resolve_term(term: Any) -> tuple[Any, str, str]:
        """Translate one RHS term into an ``(operator_fn, field, velocity_slot)`` triple.

        A ``first_derivative_t`` term maps to ``None`` plus the velocity slot
        name ``v_<field>``. Any other term maps to its registered operator
        function plus an empty slot name.

        Raises
        ------
        ValueError
            If the term's operator is not registered.
        """
        operator_name = term.operator
        if operator_name == "first_derivative_t":
            return (None, term.field, f"v_{term.field}")
        op_fn = OPERATOR_REGISTRY.get(operator_name)
        if op_fn is None:
            msg = (
                f"Unknown operator {operator_name!r}; "
                f"known: {sorted(OPERATOR_REGISTRY)}"
            )
            raise ValueError(msg)
        return (op_fn, term.field, "")

    def _apply_resolved(
        self,
        resolved: tuple[Any, str, str],
        fields: FieldSet,
    ) -> np.ndarray:
        """Produce the (un-scaled) data for one pre-resolved RHS term.

        Spatial operators are called with the field's data, the grid, and
        the pre-resolved per-axis BCs. ``first_derivative_t`` terms (operator
        ``None``) return the velocity slot's array directly, without copying.

        Raises
        ------
        ValueError
            If a ``first_derivative_t`` term's velocity slot is missing from
            *fields*, or an operator term references an unknown field.
        """
        op_fn, field_name, vel_name = resolved
        if op_fn is not None:
            return op_fn(
                self._get_field_data(field_name, fields),
                self._grid,
                self._resolved_bcs,
            )
        # first_derivative_t: read the velocity slot instead of applying a
        # spatial operator.
        if vel_name in fields:
            return fields[vel_name]
        msg = (
            f"Cannot resolve first_derivative_t({field_name}): "
            f"velocity slot '{vel_name}' not found. "
            f"Available: {sorted(fields.slot_names)}"
        )
        raise ValueError(msg)

    @staticmethod
    def _get_field_data(field_name: str, fields: FieldSet) -> np.ndarray:
        """Look up *field_name* in *fields*, rejecting unknown references.

        Raises
        ------
        ValueError
            If *field_name* is not present in *fields*.
        """
        if field_name not in fields:
            msg = (
                f"Unknown field reference '{field_name}'. "
                f"Available: {sorted(fields.slot_names)}"
            )
            raise ValueError(msg)
        return fields[field_name]