"""Fourier modal solver — exact spectral time evolution for linear PDEs.
Transforms the spatial grid to Fourier space, builds a per-mode evolution
matrix, and eigendecomposes to obtain the exact solution y(t) = exp(A·t)·y₀.
Eliminates spatial discretization error entirely and provides machine-precision
solutions for time-independent linear systems.
Applicable to any linear PDE system with:
- Flat (Minkowski) metric
- All-periodic boundary conditions
- Time-independent coefficients (position-dependent OK via convolution)
- Operators with known exact Fourier multipliers
Two algorithm paths are used depending on coefficient structure:
- Constant coefficients: per-mode eigendecomposition with block-aware independent
blocks (machine-precision, ~14x faster).
- Position-dependent coefficients: Krylov matrix exponential (expm_multiply) which
is backward-stable for non-normal convolution matrices where eigendecomposition
gives incorrect results due to pseudospectral overflow.
References
----------
Moler & Van Loan (2003), SIAM Review 45(1):3-49 — matrix exponential.
Hairer, Lubich & Wanner (2006), Geometric Numerical Integration, §4.
Burns et al. (2020), Phys. Rev. Research 2:023068 — pseudo-spectral.
Raffelt & Stodolsky (1988), PRD 37:1237 — mixing-matrix formalism.
"""
# ruff: noqa: N803, N806 — uppercase names for matrices (A, V, T, Z) follow
# standard linear-algebra notation.
# ruff: noqa: PLR0913, PLR0917, PLR0914, PLR0912, PLR0911, PLR0915, PLR2004
# — numerical code inherently requires many arguments, local variables,
# return statements, statements, and literal comparisons.
# ruff: noqa: C901 — complexity and Unicode math symbols.
# ruff: noqa: ERA001, ARG001 — commented-out code serves as documentation;
# unused args (bc, grid) kept for interface consistency with other solvers.
# ruff: noqa: B903, PLR1702 — _OperatorDecomp uses __slots__ for memory efficiency;
# nested block depth is inherent to multi-field modal algebra.
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any
import numpy as np
from numpy.typing import NDArray
from tidal.solver._defaults import DEFAULT_ATOL, DEFAULT_RTOL
from tidal.solver._setup import warn_frozen_constraints
from tidal.solver.operators import get_wavenumbers, is_periodic_bc
from tidal.solver.state import StateLayout
if TYPE_CHECKING:
from tidal.solver._types import SolverResult
from tidal.solver.coefficients import CoefficientEvaluator
from tidal.solver.grid import GridInfo
from tidal.solver.operators import BCSpec
from tidal.solver.progress import SimulationProgress
from tidal.symbolic.json_loader import (
ComponentEquation,
EquationSystem,
OperatorTerm,
)
# ---------------------------------------------------------------------------
# Exact Fourier multipliers (angular wavenumber convention: k = 2π·rfftfreq)
# ---------------------------------------------------------------------------
# These use the EXACT wavenumber k, consistent with operators.py spectral
# mode (gradient → ik, laplacian → -k²). NOT the modified-wavenumber
# convention from constraint_solve.py which matches FD stencils.
_ExactMultFn = Callable[[list[NDArray[np.float64]]], NDArray[Any] | int]
# ---------------------------------------------------------------------------
# Operator decomposition: (spatial Fourier multiplier, time derivative order)
# ---------------------------------------------------------------------------
# Every operator in flat spacetime decomposes as spatial_multiplier(k) × ∂ⁿ_t.
# Derivatives commute in Minkowski: ∂²_t ∂_x = ∂_x ∂²_t.
#
# Time order classification:
# 0 = position operator → stiffness matrix K
# 1 = velocity operator → damping matrix D
# 2 = acceleration operator → mass matrix M (off-diagonal, implicit coupling)
# 3 = jerk operator → eliminated via EOM substitution
#
# References:
# Golub & Van Loan (2013), Matrix Computations, 4th ed. §7.7
# Hairer & Lubich (2003), ZAMM 83(1) — mass matrices in structural dynamics
class _OperatorDecomp:
"""Operator decomposition into spatial multiplier and time derivative order."""
__slots__ = ("spatial_fn", "time_order")
def __init__(self, spatial_fn: _ExactMultFn, time_order: int) -> None:
self.spatial_fn = spatial_fn
self.time_order = time_order
# Table of supported operators. Each multiplier function receives the list of
# broadcastable per-axis wavenumber arrays (see _build_k_grid) and returns the
# operator's exact Fourier symbol; e.g. ∂_x → ik_x, ∂²_x → -k_x², ∂³_x → -ik_x³.
_OPERATOR_DECOMP: dict[str, _OperatorDecomp] = {
    # --- Pure spatial operators (time_order=0) ---
    # identity: symbol is 1 for every mode; ones_like(k_axes[0]) gives an
    # array broadcastable to the full rfft shape.
    "identity": _OperatorDecomp(lambda k_axes: np.ones_like(k_axes[0]), 0),
    "laplacian": _OperatorDecomp(
        lambda k_axes: -sum(ki**2 for ki in k_axes),
        0,
    ),
    "laplacian_x": _OperatorDecomp(lambda k_axes: -(k_axes[0] ** 2), 0),
    "laplacian_y": _OperatorDecomp(lambda k_axes: -(k_axes[1] ** 2), 0),
    "laplacian_z": _OperatorDecomp(lambda k_axes: -(k_axes[2] ** 2), 0),
    "gradient_x": _OperatorDecomp(lambda k_axes: 1j * k_axes[0], 0),
    "gradient_y": _OperatorDecomp(lambda k_axes: 1j * k_axes[1], 0),
    "gradient_z": _OperatorDecomp(lambda k_axes: 1j * k_axes[2], 0),
    # Cross derivatives ∂_i∂_j → (ik_i)(ik_j) = -k_i·k_j
    "cross_derivative_xy": _OperatorDecomp(lambda k_axes: -(k_axes[0] * k_axes[1]), 0),
    "cross_derivative_xz": _OperatorDecomp(lambda k_axes: -(k_axes[0] * k_axes[2]), 0),
    "cross_derivative_yz": _OperatorDecomp(lambda k_axes: -(k_axes[1] * k_axes[2]), 0),
    # Biharmonic ∇⁴ = (∇²)² → (-|k|²)² = |k|⁴
    "biharmonic": _OperatorDecomp(
        lambda k_axes: sum(ki**2 for ki in k_axes) ** 2,
        0,
    ),
    # Third spatial derivative: (ik)³ = -ik³
    "derivative_3_x": _OperatorDecomp(lambda k_axes: -1j * k_axes[0] ** 3, 0),
    "derivative_3_y": _OperatorDecomp(lambda k_axes: -1j * k_axes[1] ** 3, 0),
    "derivative_3_z": _OperatorDecomp(lambda k_axes: -1j * k_axes[2] ** 3, 0),
    # --- Velocity operators (time_order=1) ---
    "first_derivative_t": _OperatorDecomp(lambda k_axes: np.ones_like(k_axes[0]), 1),
    "mixed_T1_S1x": _OperatorDecomp(lambda k_axes: 1j * k_axes[0], 1),
    "mixed_T1_S1y": _OperatorDecomp(lambda k_axes: 1j * k_axes[1], 1),
    "mixed_T1_S1z": _OperatorDecomp(lambda k_axes: 1j * k_axes[2], 1),
    # --- Acceleration operators (time_order=2) ---
    "d2_t": _OperatorDecomp(lambda k_axes: np.ones_like(k_axes[0]), 2),
    "mixed_T2_S1x": _OperatorDecomp(lambda k_axes: 1j * k_axes[0], 2),
    "mixed_T2_S1y": _OperatorDecomp(lambda k_axes: 1j * k_axes[1], 2),
    "mixed_T2_S1z": _OperatorDecomp(lambda k_axes: 1j * k_axes[2], 2),
    "mixed_T2_S2x": _OperatorDecomp(lambda k_axes: -(k_axes[0] ** 2), 2),
    "mixed_T2_S2y": _OperatorDecomp(lambda k_axes: -(k_axes[1] ** 2), 2),
    "mixed_T2_S2z": _OperatorDecomp(lambda k_axes: -(k_axes[2] ** 2), 2),
    # --- Jerk operators (time_order=3, eliminated via EOM substitution) ---
    "d3_t": _OperatorDecomp(lambda k_axes: np.ones_like(k_axes[0]), 3),
    "mixed_T3_S1x": _OperatorDecomp(lambda k_axes: 1j * k_axes[0], 3),
    "mixed_T3_S1y": _OperatorDecomp(lambda k_axes: 1j * k_axes[1], 3),
    "mixed_T3_S1z": _OperatorDecomp(lambda k_axes: 1j * k_axes[2], 3),
}
# Backward-compatible mapping: operator name → spatial multiplier function.
# Used by existing code paths that only need the spatial part.
_EXACT_MULTIPLIERS: dict[str, _ExactMultFn] = {
    name: dec.spatial_fn for name, dec in _OPERATOR_DECOMP.items()
}
# ---------------------------------------------------------------------------
# Eligibility check
# ---------------------------------------------------------------------------
def can_use_modal(
    spec: EquationSystem,
    grid: GridInfo,
    bc: BCSpec | None,
) -> bool:
    """Check whether the modal solver is applicable to this system.

    Requirements (checked in order):
      1. Flat metric (volume_element is None)
      2. Constraints (time_order=0) must be Fourier-eliminable via Schur complement
      3. All boundary conditions periodic
      4. All RHS operators have exact Fourier multipliers
      5. No time-dependent coefficients

    Parameters
    ----------
    spec : EquationSystem
        Symbolic system whose equations and RHS terms are inspected.
    grid : GridInfo
        Spatial grid; only the per-axis periodicity flags are used.
    bc : BCSpec | None
        Explicit boundary-condition override — a single BC name or a
        sequence of per-axis names.

    Returns
    -------
    bool
        True iff all eligibility requirements are met.
    """
    # 1. Flat metric — curved metrics have non-None volume_element
    if spec.canonical is not None and spec.canonical.volume_element is not None:
        return False
    # Also reject if canonical is None but any term is position-dependent
    # with non-Cartesian coordinate names (heuristic for curved metrics
    # without canonical structure).
    if spec.canonical is None:
        cartesian = {"x", "y", "z"}
        for eq in spec.equations:
            for term in eq.rhs_terms:
                if (
                    term.position_dependent
                    and set(term.coordinate_dependent) - cartesian
                ):
                    return False
    # 2. Constraints — allow if Fourier-eliminable via Schur complement
    constraint_eqs = [eq for eq in spec.equations if eq.time_derivative_order == 0]
    if constraint_eqs and not _constraints_fourier_eliminable(spec, constraint_eqs):
        return False
    # 3. All-periodic BCs: every grid axis AND any explicit bc override.
    if not all(grid.periodic):
        return False
    if bc is not None:
        bc_list = [bc] if isinstance(bc, str) else bc
        if not all(is_periodic_bc(b) for b in bc_list):
            return False
    # 4 + 5. Every RHS operator must have an exact Fourier decomposition and
    # time-independent coefficients — checked in a single pass over the terms
    # (the original code walked the equations twice for these two checks).
    return all(
        term.operator in _OPERATOR_DECOMP and not term.time_dependent
        for eq in spec.equations
        for term in eq.rhs_terms
    )
# ---------------------------------------------------------------------------
# FFT state transforms
# ---------------------------------------------------------------------------
def _fft_slots(
y: NDArray[np.float64],
layout: StateLayout,
grid: GridInfo,
) -> NDArray[np.complex128]:
"""Transform each slot from physical space to Fourier space (rfft).
Returns a complex array of shape (n_slots, n_modes) where n_modes is
the rfft output length for the 1D case (n//2+1), or the product of
rfft output lengths for multi-D.
"""
n_slots = layout.num_slots
n_pts = layout.num_points
shape = grid.shape
# For 1D: rfft output length is shape[0]//2 + 1
# For nD: rfftn produces shape[:-1] + (shape[-1]//2+1,)
# Compute analytically instead of probing with a zero FFT.
rfft_shape = list(shape)
rfft_shape[-1] = shape[-1] // 2 + 1
n_modes = int(np.prod(rfft_shape))
y_hat = np.zeros((n_slots, n_modes), dtype=np.complex128)
for slot_idx in range(n_slots):
start = slot_idx * n_pts
end = start + n_pts
field_data = y[start:end].reshape(shape)
y_hat[slot_idx] = np.fft.rfftn(field_data).ravel()
return y_hat
def _ifft_slots(
y_hat: NDArray[np.complex128],
layout: StateLayout,
grid: GridInfo,
) -> NDArray[np.float64]:
"""Transform each slot from Fourier space back to physical space (irfft).
Returns a real flat array of shape (n_slots * n_pts,).
"""
n_slots = layout.num_slots
n_pts = layout.num_points
shape = grid.shape
# Determine rfftn output shape for irfftn reconstruction
rfft_shape = list(shape)
rfft_shape[-1] = shape[-1] // 2 + 1
rfft_shape_tuple = tuple(rfft_shape)
y_out = np.zeros(n_slots * n_pts)
for slot_idx in range(n_slots):
hat_data = y_hat[slot_idx].reshape(rfft_shape_tuple)
physical = np.fft.irfftn(hat_data, s=shape, axes=list(range(len(shape))))
y_out[slot_idx * n_pts : (slot_idx + 1) * n_pts] = physical.ravel()
return y_out
# ---------------------------------------------------------------------------
# Wavenumber grid construction
# ---------------------------------------------------------------------------
def _build_k_axes(grid: GridInfo) -> list[NDArray[np.float64]]:
    """Build one wavenumber array per spatial axis.

    Convention matches operators.get_wavenumbers:
      last axis (rfft, half-complex):  k = 2π · rfftfreq(N, d=dx)
      all other axes (full fft):       k = 2π · fftfreq(N, d=dx)
    """
    last = grid.ndim - 1
    k_axes: list[NDArray[np.float64]] = []
    for ax in range(grid.ndim):
        n, dx = grid.shape[ax], grid.dx[ax]
        if ax == last:
            # Half-complex axis — delegate to the shared helper so the
            # convention is byte-identical to the spectral operators.
            k_axes.append(get_wavenumbers(n, dx))
        else:
            k_axes.append(
                np.asarray(2.0 * np.pi * np.fft.fftfreq(n, d=dx), dtype=np.float64)
            )
    return k_axes
def _build_k_grid(
k_axes: list[NDArray[np.float64]],
) -> list[NDArray[np.float64]]:
"""Build broadcasted k-grid arrays from per-axis wavenumbers.
Returns a list of arrays, one per axis, each broadcastable to the
full rfft output shape.
"""
ndim = len(k_axes)
k_grid: list[NDArray[np.float64]] = []
for ax in range(ndim):
shape = [1] * ndim
shape[ax] = len(k_axes[ax])
k_grid.append(k_axes[ax].reshape(shape))
return k_grid
# ---------------------------------------------------------------------------
# Constraint elimination (Fourier Schur complement)
# ---------------------------------------------------------------------------
# Ref: Hairer & Wanner (1996), Solving ODEs II, Ch. VII — DAE reduction.
# Ref: Ascher & Petzold (1998), Computer Methods for ODEs/DAEs, §10.2.
def _constraints_fourier_eliminable(
    spec: EquationSystem,
    constraint_eqs: Sequence[ComponentEquation],
) -> bool:
    """Check whether every constraint equation is eliminable in Fourier space.

    A constraint is eliminable when each of its operators decomposes as
    (spatial multiplier × time derivative) and none of its coefficients is
    time-dependent. Acceleration operators (mixed_T2_S1x, d2_t) are allowed:
    they are handled by substituting the dynamical equations of motion before
    Schur elimination. ``spec`` is unused; it is kept for signature
    consistency with the other checkers (ARG001 noqa at file top).
    """
    return all(
        term.operator in _OPERATOR_DECOMP and not term.time_dependent
        for eq in constraint_eqs
        for term in eq.rhs_terms
    )
def _build_constraint_eliminated_matrices(
    spec: EquationSystem,
    layout: StateLayout,
    grid: GridInfo,
    coeff_eval: object,  # CoefficientEvaluator
    k_grid: list[NDArray[np.float64]],
    rfft_shape: tuple[int, ...],
) -> tuple[
    NDArray[np.complex128],  # A_reduced (n_modes, n_dyn, n_dyn)
    NDArray[np.complex128],  # recovery (n_modes, n_constraints, n_dyn)
    NDArray[np.complex128],  # v_recovery (n_modes, n_constraints, n_dyn)
    list[str],  # constraint_field_names
    dict[int, int],  # orig_to_reduced slot mapping
]:
    """Build reduced per-mode matrices with constraints algebraically eliminated.

    For a mixed system with dynamical (d) and constraint (c) fields:

        d/dt[d] = A_dd·d + A_dc·c   (dynamical equations)
        0       = S_cd·d + S_cc·c   (constraint: solve for c)

    The constraint gives c = -S_cc⁻¹·S_cd·d. Substituting:

        d/dt[d] = (A_dd - A_dc·S_cc⁻¹·S_cd)·d

    This handles v_A₀ references in dynamical equations by recognizing that
    v_A₀ = dA₀/dt = d/dt[-S_cc⁻¹·S_cd·d] = -S_cc⁻¹·S_cd·d', creating an
    implicit equation (I - A_dc_v·S_cc⁻¹·S_cd)·d' = (A_dd + A_dc_f·f)·d
    which is resolved by matrix inversion of the LHS factor.

    All operations are purely numeric (CoefficientEvaluator returns floats).
    S_cc⁻¹ in Fourier space is diagonal per mode — just 1/(m²+k²).

    ``grid`` is unused in this builder; it is kept for signature consistency
    with the sibling matrix builders (ARG001 noqa at file top).

    Returns
    -------
    A_reduced : ndarray
        Per-mode matrices (n_modes, n_dyn, n_dyn) for dynamical fields only.
    recovery : ndarray
        Per-mode recovery (n_modes, n_constraints, n_dyn) for reconstructing
        constraint fields from dynamical state.
    v_recovery : ndarray
        Per-mode map (n_modes, n_constraints, n_dyn) giving the exact time
        derivative of each constraint field from the dynamical state.
    constraint_field_names : list[str]
        Names of eliminated constraint fields.
    orig_to_reduced : dict
        Mapping from original layout slot index to reduced slot index.
    """
    from tidal.solver.coefficients import CoefficientEvaluator  # noqa: PLC0415

    assert isinstance(coeff_eval, CoefficientEvaluator)
    n_modes = int(np.prod(rfft_shape))
    # Identify constraint and dynamical fields
    constraint_field_names: list[str] = []
    constraint_eq_map: dict[str, int] = {}  # field_name → eq_idx
    for eq_idx, eq in enumerate(spec.equations):
        if eq.time_derivative_order == 0:
            constraint_field_names.append(eq.field_name)
            constraint_eq_map[eq.field_name] = eq_idx
    n_c = len(constraint_field_names)
    # Build dynamical-only slot mapping (excluding constraint field slots)
    orig_to_reduced: dict[int, int] = {}
    red_idx = 0
    for si, slot in enumerate(layout.slots):
        if slot.kind == "constraint":
            continue
        orig_to_reduced[si] = red_idx
        red_idx += 1
    n_dyn = red_idx
    # Map field names to slot indices in the REDUCED layout
    dyn_slot_map: dict[str, int] = {}
    for si, slot in enumerate(layout.slots):
        if si in orig_to_reduced:
            dyn_slot_map[slot.name] = orig_to_reduced[si]
    # Also map velocity names v_X for dynamical fields
    for fname, si in layout.velocity_slot_map.items():
        v_name = f"v_{fname}"
        if si in orig_to_reduced:
            dyn_slot_map[v_name] = orig_to_reduced[si]
    # Constraint slot map (constraint field names → constraint index 0..n_c-1)
    c_idx_map: dict[str, int] = {
        name: i for i, name in enumerate(constraint_field_names)
    }
    # Evaluate Fourier multipliers — one flattened per-mode array per distinct
    # operator, computed once and reused across all terms.
    multiplier_cache: dict[str, NDArray[np.complex128]] = {}
    for eq in spec.equations:
        for term in eq.rhs_terms:
            op = term.operator
            if op not in multiplier_cache:
                mult_fn = _EXACT_MULTIPLIERS[op]
                mult_val = mult_fn(k_grid)
                mult_full = np.broadcast_to(mult_val, rfft_shape)
                multiplier_cache[op] = mult_full.ravel().astype(np.complex128)
    # --- Build the four coupling matrices per mode ---
    # A_dd: dynamical → dynamical (n_modes, n_dyn, n_dyn)
    A_dd = np.zeros((n_modes, n_dyn, n_dyn), dtype=np.complex128)
    # A_dc: constraint → dynamical via FIELD references (n_modes, n_dyn, n_c)
    A_dc_field = np.zeros((n_modes, n_dyn, n_c), dtype=np.complex128)
    # A_dc_vel: constraint VELOCITY → dynamical (n_modes, n_dyn, n_c)
    A_dc_vel = np.zeros((n_modes, n_dyn, n_c), dtype=np.complex128)
    # S_cd: dynamical → constraint source (n_modes, n_c, n_dyn)
    S_cd = np.zeros((n_modes, n_c, n_dyn), dtype=np.complex128)
    # S_cc: constraint self-coupling (n_modes, n_c, n_c)
    S_cc = np.zeros((n_modes, n_c, n_c), dtype=np.complex128)
    for eq_idx, eq in enumerate(spec.equations):
        is_constraint = eq.time_derivative_order == 0
        is_second_order = eq.time_derivative_order >= 2
        if is_constraint:
            # Constraint equation: 0 = Σ coeff·op(target)
            ci = c_idx_map[eq.field_name]
            for term_idx, term in enumerate(eq.rhs_terms):
                coeff = _resolve_constant_coeff(
                    term,
                    coeff_eval,
                    eq_idx=eq_idx,
                    term_idx=term_idx,
                )
                mult = multiplier_cache[term.operator]
                if term.field in c_idx_map:
                    # Self/cross constraint coupling
                    cj = c_idx_map[term.field]
                    S_cc[:, ci, cj] += coeff * mult
                elif term.field in dyn_slot_map:
                    # Source coupling to dynamical state
                    dj = dyn_slot_map[term.field]
                    S_cd[:, ci, dj] += coeff * mult
        elif is_second_order:
            field_slot = orig_to_reduced[layout.field_slot_map[eq.field_name]]
            vel_slot = orig_to_reduced[layout.velocity_slot_map[eq.field_name]]
            # Kinematic: dq/dt = v
            A_dd[:, field_slot, vel_slot] = 1.0
            # RHS terms: dv/dt = Σ coeff·op(target)
            for term_idx, term in enumerate(eq.rhs_terms):
                coeff = _resolve_constant_coeff(
                    term,
                    coeff_eval,
                    eq_idx=eq_idx,
                    term_idx=term_idx,
                )
                mult = multiplier_cache[term.operator]
                if term.field in c_idx_map:
                    # References constraint field directly
                    cj = c_idx_map[term.field]
                    A_dc_field[:, vel_slot, cj] += coeff * mult
                elif term.field.startswith("v_") and term.field[2:] in c_idx_map:
                    # References constraint velocity v_A₀
                    cj = c_idx_map[term.field[2:]]
                    A_dc_vel[:, vel_slot, cj] += coeff * mult
                elif term.field in dyn_slot_map:
                    # Normal dynamical reference
                    dj = dyn_slot_map[term.field]
                    A_dd[:, vel_slot, dj] += coeff * mult
        else:
            # First-order: du/dt = Σ coeff·op(target)
            this_slot = orig_to_reduced[layout.field_slot_map[eq.field_name]]
            for term_idx, term in enumerate(eq.rhs_terms):
                coeff = _resolve_constant_coeff(
                    term,
                    coeff_eval,
                    eq_idx=eq_idx,
                    term_idx=term_idx,
                )
                mult = multiplier_cache[term.operator]
                if term.field in c_idx_map:
                    cj = c_idx_map[term.field]
                    A_dc_field[:, this_slot, cj] += coeff * mult
                elif term.field in dyn_slot_map:
                    dj = dyn_slot_map[term.field]
                    A_dd[:, this_slot, dj] += coeff * mult
    # --- Compute Schur complement ---
    # Batch-invert S_cc across all modes (small matrices, typically 1x1 or 2x2)
    # Detect and regularize singular modes (e.g. k=0 gauge freedom)
    dets = np.linalg.det(S_cc) if n_c > 0 else np.ones(n_modes)
    singular_mask = np.abs(dets) < 1e-14
    S_cc_reg = S_cc.copy()
    if np.any(singular_mask):
        # Tiny diagonal shift keeps the batched inverse finite on the gauge
        # modes; those components carry no physical information.
        S_cc_reg[singular_mask] += 1e-14 * np.eye(n_c, dtype=np.complex128)
    S_cc_inv = np.linalg.inv(S_cc_reg)  # (n_modes, n_c, n_c)
    # Recovery: c = -S_cc⁻¹ · S_cd · d
    # recovery[m, ci, dj] = -Σ_cj S_cc_inv[m,ci,cj] · S_cd[m,cj,dj]
    recovery = -np.einsum("mij,mjk->mik", S_cc_inv, S_cd)
    # Substitution: A_dc_field · c = A_dc_field · recovery · d
    # field_correction[m] = A_dc_field[m] @ recovery[m]
    field_correction = np.einsum("mij,mjk->mik", A_dc_field, recovery)
    # For constraint velocity: v_c = d/dt[c] = recovery · d'
    # where d' = A_reduced · d. So A_dc_vel · v_c = A_dc_vel · recovery · d'.
    # This creates implicit coupling:
    #   d' = A_dd · d + field_correction · d + A_dc_vel · recovery · d'
    #   (I - A_dc_vel · recovery) · d' = (A_dd + field_correction) · d
    #   d' = (I - A_dc_vel · recovery)⁻¹ · (A_dd + field_correction) · d
    vel_coupling = np.einsum("mij,mjk->mik", A_dc_vel, recovery)
    # Check if vel_coupling is nonzero (constraint velocity referenced)
    has_vel_coupling = np.max(np.abs(vel_coupling)) > 1e-15
    A_rhs = A_dd + field_correction
    if has_vel_coupling:
        # Implicit solve: (I - vel_coupling) · d' = A_rhs · d
        eye = np.broadcast_to(
            np.eye(n_dyn, dtype=np.complex128),
            (n_modes, n_dyn, n_dyn),
        ).copy()
        lhs = eye - vel_coupling
        # Batch solve: A_reduced = lhs⁻¹ · A_rhs (all modes at once)
        A_reduced: NDArray[np.complex128] = np.asarray(
            np.linalg.solve(lhs, A_rhs),
            dtype=np.complex128,
        )
    else:
        A_reduced = A_rhs
    # Velocity recovery: v_c_hat = v_recovery @ d_hat gives exact ∂_t(c)
    # Derived from: c = recovery · d, so ∂_t c = recovery · ∂_t d = recovery · A_reduced · d
    # This is machine-precision — no numerical differentiation needed.
    # Stored per-mode: shape (n_modes, n_c, n_dyn). Typically < 50 MB for 128³.
    v_recovery = np.einsum("mci,mij->mcj", recovery, A_reduced)
    return A_reduced, recovery, v_recovery, constraint_field_names, orig_to_reduced
# ---------------------------------------------------------------------------
# Generalized mass-matrix evolution (M·ẍ = K·x + D·ẋ + J·x⃛)
# ---------------------------------------------------------------------------
# For systems with implicit acceleration coupling (d2_t, mixed_T2_S*) and
# jerk coupling (d3_t, mixed_T3_S*). The mass matrix M may be singular,
# creating hidden constraints analogous to time_order=0 fields.
#
# Algorithm:
# 1. Build M, D, K, J matrices from operator decomposition
# 2. Eigendecompose M per mode — zero eigenvalues → constraints
# 3. Schur-eliminate mass-matrix constraints (same as constraint fields)
# 4. Substitute jerk terms using equations of motion
# 5. Build first-order evolution matrix A = [[0,I],[M⁻¹K, M⁻¹D]]
# 6. Combine with existing constraint field Schur elimination
#
# References:
# Golub & Van Loan (2013), Matrix Computations §7.7 (generalized eigenvalue)
# Hairer & Lubich (2003), ZAMM 83(1) (mass matrices in dynamics)
# Ostrogradsky (1850), Mem. Acad. St. Petersbourg VI 4, 385
def _has_time_derivative_operators(spec: EquationSystem) -> bool:
    """Report whether any RHS term carries a time-derivative factor (∂ⁿ_t, n ≥ 1).

    Operators absent from _OPERATOR_DECOMP are ignored here — they are
    rejected separately by the eligibility check.
    """
    return any(
        (decomp := _OPERATOR_DECOMP.get(term.operator)) is not None
        and decomp.time_order > 0
        for eq in spec.equations
        for term in eq.rhs_terms
    )
def _build_generalized_evolution_matrices(
spec: EquationSystem,
layout: StateLayout,
grid: GridInfo,
coeff_eval: object, # CoefficientEvaluator
k_grid: list[NDArray[np.float64]],
rfft_shape: tuple[int, ...],
) -> tuple[
NDArray[np.complex128], # A_rhs (n_modes, n_dyn_slots, n_dyn_slots)
NDArray[np.complex128] | None, # B_lhs (n_modes, n_dyn, n_dyn) or None
NDArray[np.complex128], # recovery (n_modes, n_total_constraints, n_dyn_slots)
NDArray[np.complex128] | None, # v_recovery (n_modes, n_c, n_dyn) or None
list[str], # all constraint field names
dict[int, int], # orig_to_reduced slot mapping
]:
"""Build per-mode matrices for systems with mass-matrix coupling.
Returns A_rhs and optionally B_lhs for the generalized eigenvalue
problem B·d' = A·d, where B = I - vel_coupling may be singular
(gauge freedom from circular constraint velocity dependencies).
When B_lhs is not None, the caller should use scipy.linalg.eig(A, B)
(QZ decomposition) instead of np.linalg.eig(A) for eigendecomposition.
Infinite eigenvalues correspond to gauge-constrained directions.
Handles the generalized second-order system:
M(k)·ẍ = K(k)·x + D(k)·ẋ + J(k)·x⃛
where M may be singular (creating hidden algebraic constraints) and J
encodes jerk coupling from d3_t/mixed_T3_S* operators.
The algorithm:
1. Separates constraint (time_order=0) and dynamical fields
2. Builds M, D, K matrices for dynamical fields from operator decomposition
3. Eigendecomposes M per mode to detect singular directions
4. Treats zero-eigenvalue directions as additional constraints (Schur)
5. Substitutes jerk terms using the (now-invertible) dynamical equations
6. Combines both constraint levels and builds the first-order evolution matrix
Returns the same tuple as ``_build_constraint_eliminated_matrices``.
"""
import logging # noqa: PLC0415
from tidal.solver.coefficients import CoefficientEvaluator # noqa: PLC0415
assert isinstance(coeff_eval, CoefficientEvaluator)
logger = logging.getLogger(__name__)
n_modes = int(np.prod(rfft_shape))
# ---- Identify constraint and dynamical fields ----
constraint_field_names: list[str] = [
eq.field_name for eq in spec.equations if eq.time_derivative_order == 0
]
c_idx_map: dict[str, int] = {
name: i for i, name in enumerate(constraint_field_names)
}
n_c = len(constraint_field_names)
# Build dynamical-only slot mapping (excluding constraint field slots)
orig_to_reduced: dict[int, int] = {}
red_idx = 0
for si, slot in enumerate(layout.slots):
if slot.kind == "constraint":
continue
orig_to_reduced[si] = red_idx
red_idx += 1
n_dyn_slots = red_idx
# Map field/velocity names → reduced slot indices
dyn_slot_map: dict[str, int] = {}
for si, slot in enumerate(layout.slots):
if si in orig_to_reduced:
dyn_slot_map[slot.name] = orig_to_reduced[si]
for fname, si in layout.velocity_slot_map.items():
v_name = f"v_{fname}"
if si in orig_to_reduced:
dyn_slot_map[v_name] = orig_to_reduced[si]
# ---- Evaluate spatial Fourier multipliers ----
multiplier_cache: dict[str, NDArray[np.complex128]] = {}
for eq in spec.equations:
for term in eq.rhs_terms:
op = term.operator
if op not in multiplier_cache:
decomp = _OPERATOR_DECOMP[op]
mult_val = decomp.spatial_fn(k_grid)
mult_full = np.broadcast_to(mult_val, rfft_shape)
multiplier_cache[op] = mult_full.ravel().astype(np.complex128)
# ---- Identify dynamical fields and their indices ----
# Map: dynamical field name → index in the n_f dynamical field array
dyn_field_names: list[str] = []
dyn_field_idx: dict[str, int] = {}
for eq in spec.equations:
if eq.time_derivative_order > 0:
dyn_field_idx[eq.field_name] = len(dyn_field_names)
dyn_field_names.append(eq.field_name)
n_f = len(dyn_field_names) # number of dynamical FIELDS (not slots)
# ---- Build M, D, K matrices for dynamical fields (n_f x n_f) ----
# These are the FIELD-level matrices, not slot-level.
# M·ẍ = K·x + D·ẋ where x is the vector of field values.
M_mat = np.zeros((n_modes, n_f, n_f), dtype=np.complex128)
D_mat = np.zeros((n_modes, n_f, n_f), dtype=np.complex128)
K_mat = np.zeros((n_modes, n_f, n_f), dtype=np.complex128)
J_mat = np.zeros((n_modes, n_f, n_f), dtype=np.complex128)
# Diagonal of M: each 2nd-order field has ẍ_i on the LHS
for fi in range(n_f):
M_mat[:, fi, fi] = 1.0
# Constraint matrices — built in two phases:
# Phase 1: collect terms from constraint equations
# Phase 2: substitute acceleration/velocity terms after M inversion
S_cd = np.zeros((n_modes, n_c, n_dyn_slots), dtype=np.complex128)
S_cc = np.zeros((n_modes, n_c, n_c), dtype=np.complex128)
# A_dc: constraint → dynamical (field + velocity references)
A_dc_field = np.zeros((n_modes, n_dyn_slots, n_c), dtype=np.complex128)
A_dc_vel = np.zeros((n_modes, n_dyn_slots, n_c), dtype=np.complex128)
# Deferred constraint terms with time_order > 0 on dynamical fields.
# These need acceleration/velocity substitution after M inversion.
# Each entry: (ci, coeff, spatial_mult, time_order, target_field_idx)
deferred_constraint_terms: list[
tuple[int, complex, NDArray[np.complex128], int, int]
] = []
# ---- Populate matrices from equations ----
for eq_idx, eq in enumerate(spec.equations):
is_constraint = eq.time_derivative_order == 0
if is_constraint:
ci = c_idx_map[eq.field_name]
for term_idx, term in enumerate(eq.rhs_terms):
coeff = _resolve_constant_coeff(
term, coeff_eval, eq_idx=eq_idx, term_idx=term_idx
)
mult = multiplier_cache[term.operator]
decomp = _OPERATOR_DECOMP[term.operator]
t_order = decomp.time_order
if term.field in c_idx_map:
cj = c_idx_map[term.field]
S_cc[:, ci, cj] += coeff * mult
elif t_order == 0:
# Pure spatial operator on dynamical field/velocity
if term.field in dyn_slot_map:
dj = dyn_slot_map[term.field]
S_cd[:, ci, dj] += coeff * mult
elif t_order == 1 and term.field in dyn_field_idx:
# Velocity operator on dynamical field (e.g. mixed_T1_S1x)
# This references ẋ_field → use velocity slot
fj = dyn_field_idx[term.field]
vel_j = orig_to_reduced[layout.velocity_slot_map[term.field]]
S_cd[:, ci, vel_j] += coeff * mult
elif t_order >= 2 and term.field in dyn_field_idx:
# Acceleration/jerk on dynamical field — defer until M inverted
fj = dyn_field_idx[term.field]
deferred_constraint_terms.append(
(ci, complex(coeff), mult, t_order, fj)
)
elif term.field in dyn_slot_map:
# Fallback: direct slot reference
dj = dyn_slot_map[term.field]
S_cd[:, ci, dj] += coeff * mult
continue
# Dynamical equation
fi = dyn_field_idx[eq.field_name]
field_slot = orig_to_reduced[layout.field_slot_map[eq.field_name]]
vel_slot = orig_to_reduced[layout.velocity_slot_map[eq.field_name]]
for term_idx, term in enumerate(eq.rhs_terms):
coeff = _resolve_constant_coeff(
term, coeff_eval, eq_idx=eq_idx, term_idx=term_idx
)
mult = multiplier_cache[term.operator]
decomp = _OPERATOR_DECOMP[term.operator]
t_order = decomp.time_order
# Determine which field this term targets
target_field = term.field
# Strip v_ prefix to get base field name for velocity references
is_vel_ref = target_field.startswith("v_")
base_field = target_field[2:] if is_vel_ref else target_field
if target_field in c_idx_map:
# Direct reference to constraint field
cj = c_idx_map[target_field]
A_dc_field[:, vel_slot, cj] += coeff * mult
elif is_vel_ref and base_field in c_idx_map:
# Velocity of constraint field
cj = c_idx_map[base_field]
A_dc_vel[:, vel_slot, cj] += coeff * mult
elif base_field in dyn_field_idx:
fj = dyn_field_idx[base_field]
if t_order == 0:
if is_vel_ref:
# Velocity reference with spatial operator
# → damping matrix D[fi, fj]
D_mat[:, fi, fj] += coeff * mult
else:
# Position reference with spatial operator
# → stiffness matrix K[fi, fj]
K_mat[:, fi, fj] += coeff * mult
elif t_order == 1:
if is_vel_ref:
# first_derivative_t(v_X) = ẍ_X → acceleration
# This should be rare; treat as M coupling
M_mat[:, fi, fj] -= coeff * mult
else:
# first_derivative_t(X) = ẋ_X → velocity
D_mat[:, fi, fj] += coeff * mult
elif t_order == 2:
# d2_t or mixed_T2: acceleration coupling → mass matrix
# RHS has coeff·ẍ_j, move to LHS: M[fi,fj] -= coeff·mult
M_mat[:, fi, fj] -= coeff * mult
elif t_order == 3:
# d3_t or mixed_T3: jerk coupling → substitute later
J_mat[:, fi, fj] += coeff * mult
elif target_field in dyn_slot_map:
# Direct slot reference (velocity name like v_h_3)
dj = dyn_slot_map[target_field]
# Just put it in the A matrix directly later
# For now, track separately if needed
# ---- Mass-matrix constraint elimination ----
# Eigendecompose M per mode to find singular directions.
# Zero eigenvalues → hidden constraints; nonzero → dynamical.
#
# For each mode k:
# M(k) = Q(k) · Λ(k) · Q(k)ᵀ
# Rotate: K̃ = QᵀKQ, D̃ = QᵀDQ
# Singular rows (λ=0) → constraint: 0 = K̃_c·z + D̃_c·ż
# Dynamical rows (λ≠0) → ODE: Λ_d·z̈ = K̃_d·z + D̃_d·ż
# Check if any mode has singular M
dets = np.linalg.det(M_mat)
has_singular_M = np.any(np.abs(dets) < 1e-12)
m_k_independent = False
con_mask = np.zeros(n_f, dtype=bool) # will be updated if singular
if has_singular_M:
logger.info(
"Generalized mass matrix: singular M detected — "
"applying mass-matrix Schur elimination"
)
# Build the first-order evolution matrix A in the FULL dynamical slot space.
# A has shape (n_modes, n_dyn_slots, n_dyn_slots).
# For 2nd-order fields: rows for field_slot get dq/dt = v (kinematic),
# rows for vel_slot get dv/dt = M⁻¹(K·x + D·v).
A_dd = np.zeros((n_modes, n_dyn_slots, n_dyn_slots), dtype=np.complex128)
# Kinematic equations: dq/dt = v
for fname in dyn_field_names:
field_slot = orig_to_reduced[layout.field_slot_map[fname]]
vel_slot = orig_to_reduced[layout.velocity_slot_map[fname]]
A_dd[:, field_slot, vel_slot] = 1.0
if has_singular_M:
# Use eigendecomposition to handle singular M
# We work with each mode separately for modes where M is singular,
# and batch-process modes where M is invertible.
# For simplicity and correctness, process per-mode where needed.
# M is typically k-independent for d2_t coupling (spatial_mult=1),
# so use the k=0 mode's eigenstructure as representative.
# For k-dependent M (from mixed_T2_S*), process per mode.
# Check if M is k-independent
M_spread = np.max(np.abs(M_mat - M_mat[0:1, :, :]))
m_k_independent = M_spread < 1e-14
if m_k_independent:
# M is the same for all modes — single eigendecomposition
M0 = M_mat[0]
eigvals, Q = np.linalg.eigh(M0.real) # M is real symmetric
# Threshold for zero eigenvalue
tol = 1e-10 * max(1.0, np.max(np.abs(eigvals)))
dyn_mask = np.abs(eigvals) > tol
con_mask = ~dyn_mask
n_mass_con = int(np.sum(con_mask))
n_mass_dyn = int(np.sum(dyn_mask))
logger.info(
"Mass matrix eigenvalues: %s (dynamical: %d, constrained: %d)",
eigvals,
n_mass_dyn,
n_mass_con,
)
if n_mass_con > 0:
# Rotate K, D, J into eigenspace
Q[:, con_mask] # (n_f, n_mass_con)
Q_d = Q[:, dyn_mask] # (n_f, n_mass_dyn)
np.diag(eigvals[dyn_mask]) # (n_mass_dyn, n_mass_dyn)
Lambda_d_inv = np.diag(
1.0 / eigvals[dyn_mask]
) # (n_mass_dyn, n_mass_dyn)
# Rotate per-mode matrices
# K̃ = QᵀKQ, D̃ = QᵀDQ, J̃ = QᵀJQ
K_rot = np.einsum("ij,mjk,kl->mil", Q.T, K_mat, Q)
D_rot = np.einsum("ij,mjk,kl->mil", Q.T, D_mat, Q)
J_rot = np.einsum("ij,mjk,kl->mil", Q.T, J_mat, Q)
# Partition into dynamical (d) and constrained (c) blocks
d_idx = np.where(dyn_mask)[0]
c_idx = np.where(con_mask)[0]
K_dd = K_rot[:, np.ix_(d_idx, d_idx)[0], np.ix_(d_idx, d_idx)[1]]
K_dc = K_rot[:, np.ix_(d_idx, c_idx)[0], np.ix_(d_idx, c_idx)[1]]
K_cd = K_rot[:, np.ix_(c_idx, d_idx)[0], np.ix_(c_idx, d_idx)[1]]
K_cc = K_rot[:, np.ix_(c_idx, c_idx)[0], np.ix_(c_idx, c_idx)[1]]
D_dd = D_rot[:, np.ix_(d_idx, d_idx)[0], np.ix_(d_idx, d_idx)[1]]
D_dc = D_rot[:, np.ix_(d_idx, c_idx)[0], np.ix_(d_idx, c_idx)[1]]
D_rot[:, np.ix_(c_idx, d_idx)[0], np.ix_(c_idx, d_idx)[1]]
D_rot[:, np.ix_(c_idx, c_idx)[0], np.ix_(c_idx, c_idx)[1]]
J_dd = J_rot[:, np.ix_(d_idx, d_idx)[0], np.ix_(d_idx, d_idx)[1]]
J_dc = J_rot[:, np.ix_(d_idx, c_idx)[0], np.ix_(d_idx, c_idx)[1]]
J_rot[:, np.ix_(c_idx, d_idx)[0], np.ix_(c_idx, d_idx)[1]]
J_rot[:, np.ix_(c_idx, c_idx)[0], np.ix_(c_idx, c_idx)[1]]
# Constraint rows: 0 = K_cd·z_d + K_cc·z_c + D_cd·ż_d + D_cc·ż_c
# Solve for z_c (position-only constraint, ignoring velocity for now):
# If K_cc is invertible: z_c = -K_cc⁻¹·K_cd·z_d
# If velocity terms are present, handle as implicit coupling.
# Check if constraint is purely positional (K_cc invertible, D_cc ~ 0)
K_cc_det = np.linalg.det(K_cc) if n_mass_con > 0 else np.ones(n_modes)
has_k_con = np.any(np.abs(K_cc_det) > 1e-14)
if has_k_con:
# Standard case: K_cc invertible → z_c = -K_cc⁻¹·K_cd·z_d
K_cc_reg = K_cc.copy()
singular = np.abs(K_cc_det) < 1e-14
if np.any(singular):
K_cc_reg[singular] += 1e-14 * np.eye(
n_mass_con, dtype=np.complex128
)
K_cc_inv = np.linalg.inv(K_cc_reg) # pyright: ignore[reportUnknownVariableType]
# Recovery: z_c = -K_cc⁻¹·K_cd·z_d
mass_recovery = -np.einsum("mij,mjk->mik", K_cc_inv, K_cd) # pyright: ignore[reportUnknownArgumentType]
# Substitute into dynamical equations:
# Λ_d·z̈_d = K_dd·z_d + K_dc·z_c + D_dd·ż_d + D_dc·ż_c
# z_c = mass_recovery·z_d → ż_c = mass_recovery·ż_d
K_eff = K_dd + np.einsum("mij,mjk->mik", K_dc, mass_recovery)
D_eff = D_dd + np.einsum("mij,mjk->mik", D_dc, mass_recovery)
J_eff = J_dd + np.einsum("mij,mjk->mik", J_dc, mass_recovery)
else:
# No positional constraint coupling — mass constraint
# modes decouple trivially (zero rows)
K_eff = K_dd
D_eff = D_dd
J_eff = J_dd
mass_recovery = np.zeros(
(n_modes, n_mass_con, n_mass_dyn), dtype=np.complex128
)
# Now invert Λ_d (diagonal, all nonzero)
# E = Λ_d⁻¹·K_eff, F = Λ_d⁻¹·D_eff
E = np.einsum("ij,mjk->mik", Lambda_d_inv, K_eff)
F = np.einsum("ij,mjk->mik", Lambda_d_inv, D_eff)
# Jerk substitution:
# d3_t(z_j) = E_j·ẋ + F_j·(E·x + F·ẋ) = F_j·E·x + (E_j + F_j·F)·ẋ
J_eff_inv = np.einsum("ij,mjk->mik", Lambda_d_inv, J_eff)
has_jerk = np.max(np.abs(J_eff_inv)) > 1e-15
if has_jerk:
logger.info("Jerk substitution: applying d3_t elimination")
# K_final += J_eff_inv · F · E (position correction from jerk)
FE = np.einsum("mij,mjk->mik", F, E)
K_jerk = np.einsum("mij,mjk->mik", J_eff_inv, FE)
# D_final += J_eff_inv · (E + F²) (velocity correction from jerk)
FF = np.einsum("mij,mjk->mik", F, F)
D_jerk = np.einsum("mij,mjk->mik", J_eff_inv, E + FF)
E_final = E + K_jerk
F_final = F + D_jerk
else:
E_final = E
F_final = F
# Build evolution matrix in the ROTATED field basis
# State vector in rotated basis: (z_d, ż_d)
# dz_d/dt = ż_d
# dż_d/dt = E_final·z_d + F_final·ż_d
# Now map back to the ORIGINAL slot-level evolution matrix A_dd.
# The rotation Q maps field-level indices to slot-level indices.
# For each dynamical field, there's a field_slot and vel_slot.
# Build the Q_d mapping: original field index → rotated dynamical index
# Q_d[original_i, rotated_j] = transformation coefficient
# For the velocity-slot rows (dynamics), fill in:
# dv_i/dt = Σ_j Q_d[i,a] · E_final[a,b] · Q_d[j,b] · field_j
# + Σ_j Q_d[i,a] · F_final[a,b] · Q_d[j,b] · vel_j
# Effective K and D in original field basis:
K_orig = np.einsum("ia,mab,jb->mij", Q_d, E_final, Q_d)
D_orig = np.einsum("ia,mab,jb->mij", Q_d, F_final, Q_d)
# Fill A_dd velocity rows
for i, fname_i in enumerate(dyn_field_names):
vel_i = orig_to_reduced[layout.velocity_slot_map[fname_i]]
for j, fname_j in enumerate(dyn_field_names):
field_j = orig_to_reduced[layout.field_slot_map[fname_j]]
vel_j = orig_to_reduced[layout.velocity_slot_map[fname_j]]
A_dd[:, vel_i, field_j] += K_orig[:, i, j]
A_dd[:, vel_i, vel_j] += D_orig[:, i, j]
else:
# No singular directions — M is invertible
m_inv = np.linalg.inv(M_mat)
eff_k = np.einsum("mij,mjk->mik", m_inv, K_mat)
eff_d = np.einsum("mij,mjk->mik", m_inv, D_mat)
# Jerk substitution
j_inv = np.einsum("mij,mjk->mik", m_inv, J_mat)
has_jerk = np.max(np.abs(j_inv)) > 1e-15
if has_jerk:
fd_k = np.einsum("mij,mjk->mik", eff_d, eff_k)
k_jerk = np.einsum("mij,mjk->mik", j_inv, fd_k)
fd_d = np.einsum("mij,mjk->mik", eff_d, eff_d)
d_jerk = np.einsum("mij,mjk->mik", j_inv, eff_k + fd_d)
eff_k += k_jerk
eff_d += d_jerk
for i, fname_i in enumerate(dyn_field_names):
vel_i = orig_to_reduced[layout.velocity_slot_map[fname_i]]
for j, fname_j in enumerate(dyn_field_names):
field_j = orig_to_reduced[layout.field_slot_map[fname_j]]
vel_j = orig_to_reduced[layout.velocity_slot_map[fname_j]]
A_dd[:, vel_i, field_j] += eff_k[:, i, j]
A_dd[:, vel_i, vel_j] += eff_d[:, i, j]
else:
# M is k-dependent — process per mode
# For now, treat each mode independently
for m in range(n_modes):
M_m = M_mat[m]
eigvals_m, _Q_m = np.linalg.eigh(M_m.real)
tol = 1e-10 * max(1.0, np.max(np.abs(eigvals_m)))
dyn_m = np.abs(eigvals_m) > tol
if np.all(dyn_m):
# Invertible for this mode
m_inv_m = np.linalg.inv(M_m) # pyright: ignore[reportUnknownVariableType]
ek_m = m_inv_m @ K_mat[m] # pyright: ignore[reportUnknownVariableType]
ed_m = m_inv_m @ D_mat[m] # pyright: ignore[reportUnknownVariableType]
j_inv_m = m_inv_m @ J_mat[m] # pyright: ignore[reportUnknownVariableType]
if np.max(np.abs(j_inv_m)) > 1e-15: # pyright: ignore[reportUnknownArgumentType]
fd_k_m = ed_m @ ek_m # pyright: ignore[reportUnknownVariableType]
ek_m += j_inv_m @ fd_k_m # pyright: ignore[reportUnknownVariableType]
ed_m += j_inv_m @ (ek_m + ed_m @ ed_m) # pyright: ignore[reportUnknownVariableType]
for i, fname_i in enumerate(dyn_field_names):
vi = orig_to_reduced[layout.velocity_slot_map[fname_i]]
for j, fname_j in enumerate(dyn_field_names):
fj = orig_to_reduced[layout.field_slot_map[fname_j]]
vj = orig_to_reduced[layout.velocity_slot_map[fname_j]]
A_dd[m, vi, fj] += ek_m[i, j]
A_dd[m, vi, vj] += ed_m[i, j]
else:
# Singular mode — would need per-mode Schur elimination
# This is rare for k-dependent M; log and use pseudoinverse
m_pinv = np.linalg.pinv(M_m) # pyright: ignore[reportUnknownVariableType]
ek_m2 = m_pinv @ K_mat[m] # pyright: ignore[reportUnknownVariableType]
ed_m2 = m_pinv @ D_mat[m] # pyright: ignore[reportUnknownVariableType]
for i, fname_i in enumerate(dyn_field_names):
vi = orig_to_reduced[layout.velocity_slot_map[fname_i]]
for j, fname_j in enumerate(dyn_field_names):
fj = orig_to_reduced[layout.field_slot_map[fname_j]]
vj = orig_to_reduced[layout.velocity_slot_map[fname_j]]
A_dd[m, vi, fj] += ek_m2[i, j]
A_dd[m, vi, vj] += ed_m2[i, j]
else:
# M is invertible for all modes — standard path
m_inv = np.linalg.inv(M_mat)
eff_k = np.einsum("mij,mjk->mik", m_inv, K_mat)
eff_d = np.einsum("mij,mjk->mik", m_inv, D_mat)
# Jerk substitution
j_inv = np.einsum("mij,mjk->mik", m_inv, J_mat)
has_jerk = np.max(np.abs(j_inv)) > 1e-15
if has_jerk:
logger.info("Jerk substitution: applying d3_t elimination")
fd_k = np.einsum("mij,mjk->mik", eff_d, eff_k)
k_jerk = np.einsum("mij,mjk->mik", j_inv, fd_k)
fd_d = np.einsum("mij,mjk->mik", eff_d, eff_d)
d_jerk = np.einsum("mij,mjk->mik", j_inv, eff_k + fd_d)
eff_k += k_jerk
eff_d += d_jerk
for i, fname_i in enumerate(dyn_field_names):
vel_i = orig_to_reduced[layout.velocity_slot_map[fname_i]]
for j, fname_j in enumerate(dyn_field_names):
field_j = orig_to_reduced[layout.field_slot_map[fname_j]]
vel_j = orig_to_reduced[layout.velocity_slot_map[fname_j]]
A_dd[:, vel_i, field_j] += eff_k[:, i, j]
A_dd[:, vel_i, vel_j] += eff_d[:, i, j]
# ---- Substitute deferred constraint acceleration/velocity terms ----
# Constraints may contain time_order>=2 operators on dynamical fields
# (e.g., mixed_T2_S1x(t_3) = ik_x x ẍ_{t_3}). After mass-matrix
# inversion, ẍ_j = Σ_k E[j,k]·field_k + F[j,k]·vel_k. Substitute
# this into the constraint's S_cd matrix.
if deferred_constraint_terms:
# Extract effective acceleration matrices from A_dd.
# A_dd[m, vel_i, field_j] = K_eff[i,j] (position → acceleration)
# A_dd[m, vel_i, vel_j] = D_eff[i,j] (velocity → acceleration)
K_eff = np.zeros((n_modes, n_f, n_f), dtype=np.complex128)
D_eff = np.zeros((n_modes, n_f, n_f), dtype=np.complex128)
for i, fname_i in enumerate(dyn_field_names):
vel_i = orig_to_reduced[layout.velocity_slot_map[fname_i]]
for j, fname_j in enumerate(dyn_field_names):
field_j = orig_to_reduced[layout.field_slot_map[fname_j]]
vel_j = orig_to_reduced[layout.velocity_slot_map[fname_j]]
K_eff[:, i, j] = A_dd[:, vel_i, field_j]
D_eff[:, i, j] = A_dd[:, vel_i, vel_j]
for ci, coeff_val, spatial_mult, t_order, fj in deferred_constraint_terms:
if t_order == 2:
# ẍ_fj = Σ_k K_eff[fj,k]·field_k + D_eff[fj,k]·vel_k
for k, fname_k in enumerate(dyn_field_names):
fk_slot = orig_to_reduced[layout.field_slot_map[fname_k]]
vk_slot = orig_to_reduced[layout.velocity_slot_map[fname_k]]
# Position contribution: coeff x spatial x K_eff[fj, k]
S_cd[:, ci, fk_slot] += coeff_val * spatial_mult * K_eff[:, fj, k]
# Velocity contribution: coeff x spatial x D_eff[fj, k]
S_cd[:, ci, vk_slot] += coeff_val * spatial_mult * D_eff[:, fj, k]
# time_order=3 in constraints is very rare; log and skip
elif t_order >= 3:
logger.warning(
"Constraint has time_order=%d operator — not yet handled",
t_order,
)
# ---- Constraint field Schur elimination ----
if n_c > 0:
# Batch-invert S_cc
cc_dets = np.linalg.det(S_cc) if n_c > 0 else np.ones(n_modes)
singular_mask = np.abs(cc_dets) < 1e-14
S_cc_reg = S_cc.copy()
if np.any(singular_mask):
S_cc_reg[singular_mask] += 1e-14 * np.eye(n_c, dtype=np.complex128)
S_cc_inv = np.linalg.inv(S_cc_reg)
# Recovery: c = -S_cc⁻¹ · S_cd · d
recovery = -np.einsum("mij,mjk->mik", S_cc_inv, S_cd)
# Field correction: A_dc_field · recovery
field_correction = np.einsum("mij,mjk->mik", A_dc_field, recovery)
# Velocity coupling: A_dc_vel · recovery
vel_coupling = np.einsum("mij,mjk->mik", A_dc_vel, recovery)
has_vel = np.max(np.abs(vel_coupling)) > 1e-15
A_rhs = A_dd + field_correction
if has_vel:
eye = np.broadcast_to(
np.eye(n_dyn_slots, dtype=np.complex128),
(n_modes, n_dyn_slots, n_dyn_slots),
).copy()
B_lhs: NDArray[np.complex128] | None = eye - vel_coupling
else:
B_lhs = None
else:
recovery = np.zeros((n_modes, 0, n_dyn_slots), dtype=np.complex128)
A_rhs = A_dd
B_lhs = None
n_mass_con_total = int(np.sum(con_mask))
logger.info(
"Generalized evolution: %d constraint fields, %d mass-matrix constraints, "
"%d dynamical slots, jerk=%s, vel_coupling=%s",
n_c,
n_mass_con_total,
n_dyn_slots,
"yes" if np.max(np.abs(J_mat)) > 1e-15 else "no",
"generalized_eig" if B_lhs is not None else "none",
)
# Velocity recovery for generalized eigenvalue (B·d' = A·d):
# d' = B⁻¹·A·d, so v_c = recovery · d' = recovery · B⁻¹·A · d
# For singular B modes, use least-squares to get the best A_eff.
if B_lhs is not None and recovery.size > 0:
n_dyn_slots = A_rhs.shape[1]
A_eff = np.zeros_like(A_rhs)
for m in range(n_modes):
try:
A_eff[m] = np.linalg.solve(B_lhs[m], A_rhs[m])
except np.linalg.LinAlgError:
# Singular B at this mode — use lstsq
A_eff[m] = np.asarray(
np.linalg.lstsq(B_lhs[m], A_rhs[m], rcond=None)[0], # type: ignore[reportUnknownMemberType]
dtype=np.complex128,
)
v_recovery = np.einsum("mci,mij->mcj", recovery, A_eff)
elif recovery.size > 0:
v_recovery = np.einsum("mci,mij->mcj", recovery, A_rhs)
else:
v_recovery = None
return A_rhs, B_lhs, recovery, v_recovery, constraint_field_names, orig_to_reduced
# ---------------------------------------------------------------------------
# Evolution matrix construction
# ---------------------------------------------------------------------------
def _build_per_mode_matrices(
    spec: EquationSystem,
    layout: StateLayout,
    grid: GridInfo,
    coeff_eval: CoefficientEvaluator,
    k_grid: list[NDArray[np.float64]],
    rfft_shape: tuple[int, ...],
) -> NDArray[np.complex128]:
    """Build evolution matrices for the all-constant-coefficient case.

    Returns an array of shape (n_modes, n_state_slots, n_state_slots);
    slice [m, :, :] is the evolution matrix of Fourier mode m.

    Block structure per field:
    - second-order fields contribute [0, 1; L(k), 0] (kinematic row + RHS row)
    - first-order fields contribute [L(k)] (direct evolution)
    """
    n_slots = layout.num_slots
    n_modes = int(np.prod(rfft_shape))

    # Exact Fourier multipliers, evaluated once per distinct operator and
    # flattened to (n_modes,) for direct broadcasting below.
    mults: dict[str, NDArray[np.complex128]] = {}
    for eq in spec.equations:
        for term in eq.rhs_terms:
            if term.operator not in mults:
                raw = _EXACT_MULTIPLIERS[term.operator](k_grid)
                mults[term.operator] = (
                    np.broadcast_to(raw, rfft_shape).ravel().astype(np.complex128)
                )

    A = np.zeros((n_modes, n_slots, n_slots), dtype=np.complex128)
    for eq_idx, eq in enumerate(spec.equations):
        if eq.time_derivative_order >= 2:
            # Second-order field: RHS terms feed the velocity row, and the
            # field row gets the kinematic coupling dq/dt = v.
            q_slot = layout.field_slot_map[eq.field_name]
            row_slot = layout.velocity_slot_map[eq.field_name]
            A[:, q_slot, row_slot] = 1.0
        else:
            # First-order field: RHS terms feed the field row directly.
            row_slot = layout.field_slot_map[eq.field_name]
        for term_idx, term in enumerate(eq.rhs_terms):
            col_slot = layout.field_slot_map[term.field]
            c = _resolve_constant_coeff(
                term,
                coeff_eval,
                eq_idx=eq_idx,
                term_idx=term_idx,
            )
            A[:, row_slot, col_slot] += c * mults[term.operator]
    return A
def _resolve_constant_coeff(
term: OperatorTerm,
coeff_eval: CoefficientEvaluator,
*,
eq_idx: int = -1,
term_idx: int = -1,
) -> complex:
"""Resolve a constant (non-position-dependent) coefficient to a scalar.
Uses CoefficientEvaluator.resolve() which returns a float for constant
coefficients or an ndarray for position-dependent ones (the latter should
not reach this function).
"""
resolved = coeff_eval.resolve(term, 0.0, eq_idx=eq_idx, term_idx=term_idx)
if isinstance(resolved, np.ndarray):
# Position-dependent — should not happen for constant-coeff path
return complex(resolved.ravel()[0])
return complex(resolved)
def _build_convolution_matrix(
    spec: EquationSystem,
    layout: StateLayout,
    grid: GridInfo,
    coeff_eval: CoefficientEvaluator,
    k_grid: list[NDArray[np.float64]],
    rfft_shape: tuple[int, ...],
) -> NDArray[np.complex128]:
    """Build full evolution matrix for position-dependent coefficient case.

    Position-dependent coefficients c(x) create convolution coupling in
    k-space: FFT[c(x)·u(x)] = ĉ * û (convolution). This couples
    different k-modes, producing a full (n_total x n_total) matrix where
    n_total = n_slots x n_modes.

    For localized c(x) (e.g. Gaussian B₀), the convolution kernel ĉ(q)
    decays exponentially, making the matrix effectively banded. The
    downstream ``_evolve_full_matrix`` exploits this by thresholding small
    entries and converting to sparse CSC format for faster expm_multiply.

    Constant-coefficient terms are diagonal in mode space; they are filled
    with a single vectorized fancy-index assignment per term instead of a
    Python loop over modes (O(n_modes) interpreter iterations avoided).

    Reference: Burns et al. (2020), Phys. Rev. Research 2:023068.
    """
    n_slots = layout.num_slots
    n_modes = int(np.prod(rfft_shape))
    n_total = n_slots * n_modes
    # Evaluate Fourier multipliers on the k-grid (for constant terms),
    # flattened to (n_modes,) — one cached entry per distinct operator.
    multiplier_cache: dict[str, NDArray[np.complex128]] = {}
    for eq in spec.equations:
        for term in eq.rhs_terms:
            op = term.operator
            if op not in multiplier_cache:
                mult_val = _EXACT_MULTIPLIERS[op](k_grid)
                multiplier_cache[op] = (
                    np.broadcast_to(mult_val, rfft_shape).ravel().astype(np.complex128)
                )
    A = np.zeros((n_total, n_total), dtype=np.complex128)
    # Per-mode diagonal offsets, hoisted out of the equation loop: entry m of
    # slot s lives at flat index s*n_modes + m (slot-major layout).
    diag = np.arange(n_modes)
    for eq_idx, eq in enumerate(spec.equations):
        field_name = eq.field_name
        if eq.time_derivative_order >= 2:
            # Second-order field: kinematic identity dq/dt = v on the field
            # rows; RHS terms feed the velocity rows.
            field_slot = layout.field_slot_map[field_name]
            row_slot = layout.velocity_slot_map[field_name]
            A[field_slot * n_modes + diag, row_slot * n_modes + diag] = 1.0
        else:
            # First-order field: RHS terms feed the field rows directly.
            row_slot = layout.field_slot_map[field_name]
        for term_idx, term in enumerate(eq.rhs_terms):
            target_slot = layout.field_slot_map[term.field]
            mult = multiplier_cache[term.operator]
            if not term.position_dependent:
                # Constant coefficient: diagonal in mode space.
                coeff = _resolve_constant_coeff(
                    term,
                    coeff_eval,
                    eq_idx=eq_idx,
                    term_idx=term_idx,
                )
                A[row_slot * n_modes + diag, target_slot * n_modes + diag] += (
                    coeff * mult
                )
            else:
                # Position-dependent: full convolution coupling across modes.
                _add_convolution_coupling(
                    A,
                    row_slot,
                    target_slot,
                    term,
                    coeff_eval,
                    mult,
                    grid,
                    rfft_shape,
                    n_modes,
                    eq_idx=eq_idx,
                    term_idx=term_idx,
                )
    return A
def _add_convolution_coupling(
A: NDArray[np.complex128],
row_slot: int,
col_slot: int,
term: OperatorTerm,
coeff_eval: CoefficientEvaluator,
operator_mult: NDArray[np.complex128],
grid: GridInfo,
rfft_shape: tuple[int, ...],
n_modes: int,
*,
eq_idx: int = -1,
term_idx: int = -1,
) -> None:
"""Add convolution coupling from a position-dependent coefficient.
The product c(x)·op(u(x)) in k-space becomes a convolution:
FFT[c·op(u)]_k = Σ_k' ĉ(k-k') · mult(k') · û(k')
This creates off-diagonal entries in the evolution matrix coupling
different k-modes.
"""
# Get the coefficient array on the spatial grid
coeff_array = coeff_eval.resolve(term, 0.0, eq_idx=eq_idx, term_idx=term_idx)
if isinstance(coeff_array, (int, float)):
coeff_array = np.full(grid.shape, float(coeff_array))
# For each pair of output mode m and input mode m',
# the coupling is (1/N) * ĉ(m-m') * mult(m')
# This is a Toeplitz-like structure in 1D.
#
# Build via outer product approach for efficiency:
# We compute the full convolution matrix using FFT properties.
#
# For rfftn: the convolution of real functions in rfft space requires
# care with the Hermitian symmetry. We use the identity:
# FFT[c·u]_k = (1/N) Σ_{k'} ĉ_{k-k'} · û_{k'}
#
# Build the convolution matrix C where C[k, k'] = (1/N) * ĉ_{k-k'}
# using probe vectors (unit impulse per mode).
for m_prime in range(n_modes):
# Probe: unit impulse at mode m_prime
probe_hat = np.zeros(n_modes, dtype=np.complex128)
probe_hat[m_prime] = 1.0
# Reconstruct to physical space, multiply by coefficient, FFT back
probe_physical = np.fft.irfftn(
probe_hat.reshape(rfft_shape),
s=grid.shape,
axes=list(range(len(grid.shape))),
)
product = coeff_array * probe_physical
result_hat = np.fft.rfftn(product).ravel()
# result_hat[m] = Σ_{k'} (1/N) ĉ_{m-k'} δ_{k',m'} = (1/N) ĉ_{m-m'}
# multiplied by operator multiplier at m'
row_start = row_slot * n_modes
col = col_slot * n_modes + m_prime
A[row_start : row_start + n_modes, col] += result_hat * operator_mult[m_prime]
# ---------------------------------------------------------------------------
# Block decomposition
# ---------------------------------------------------------------------------
def _find_independent_blocks(
A: NDArray[np.complex128],
threshold: float = 1e-14,
) -> list[list[int]]:
"""Find independent (decoupled) blocks in an evolution matrix.
Analyzes the sparsity pattern of A: slots i and j are coupled if
|A[i,j]| > threshold or |A[j,i]| > threshold. Returns a list of
slot-index groups (connected components).
This prevents degenerate-eigenvalue mixing when ``np.linalg.eig``
processes a block-diagonal matrix with repeated eigenvalues across
independent blocks — a common situation in symmetric multi-field
theories (e.g. Gertsenshtein h₅↔a₁ + h₇↔a₂).
"""
n = A.shape[0]
# Union-find (path compression + union by rank)
parent = list(range(n))
rank = [0] * n
def _find(x: int) -> int:
while parent[x] != x:
parent[x] = parent[parent[x]] # path compression
x = parent[x]
return x
def _union(x: int, y: int) -> None:
rx, ry = _find(x), _find(y)
if rx == ry:
return
if rank[rx] < rank[ry]:
rx, ry = ry, rx
parent[ry] = rx
if rank[rx] == rank[ry]:
rank[rx] += 1
# Build coupling graph from matrix entries
for i in range(n):
for j in range(i + 1, n):
if abs(A[i, j]) > threshold or abs(A[j, i]) > threshold:
_union(i, j)
# Group by root
groups: dict[int, list[int]] = {}
for i in range(n):
root = _find(i)
groups.setdefault(root, []).append(i)
return list(groups.values())
# ---------------------------------------------------------------------------
# Eigendecomposition and time evolution
# ---------------------------------------------------------------------------
def _has_position_dependent_terms(spec: EquationSystem) -> bool:
"""Check if any RHS term has a position-dependent coefficient."""
for eq in spec.equations:
for term in eq.rhs_terms:
if term.position_dependent:
return True
return False
def _warn_eigenvalue_growth(
eigenvalues: NDArray[np.complex128],
dt_total: float,
context: str = "",
) -> None:
"""Warn if eigenvalues have positive real parts that could cause overflow."""
import warnings # noqa: PLC0415
max_real = float(np.max(np.real(eigenvalues)))
if max_real > 1e-10:
max_growth = max_real * dt_total
if max_growth > 30: # exp(30) ≈ 1e13
ctx = f" ({context})" if context else ""
warnings.warn(
f"Modal solver{ctx}: eigenvalues with positive real parts "
f"(max Re(λ)={max_real:.3e}). Growth factor exp({max_growth:.1f}) "
f"over Δt={dt_total:.1f} may cause overflow. "
f"Consider --scheme cvode for numerical stability.",
stacklevel=3,
)
def _evolve_per_mode(
    A_modes: NDArray[np.complex128],
    y0_hat: NDArray[np.complex128],
    t_eval: NDArray[np.float64],
    layout: StateLayout,
    grid: GridInfo,
    snapshot_callback: Callable[[float, NDArray[np.float64]], None] | None,
    progress: SimulationProgress | None,
    *,
    return_fourier: bool = False,
    return_derivative_fourier: bool = False,
    B_modes: NDArray[np.complex128] | None = None,
) -> tuple[
    NDArray[np.float64],
    NDArray[np.float64],
    NDArray[np.complex128] | None,
    NDArray[np.complex128] | None,
]:
    """Evolve system with per-mode independent matrices (constant coefficients).

    Each Fourier mode evolves under its own small matrix, so the solution
    y(t) = V·exp(Λ·(t−t₀))·V⁻¹·y₀ is evaluated per mode via
    eigendecomposition. Uses block-aware eigendecomposition: independent
    field blocks are detected from the sparsity pattern and eigendecomposed
    separately to prevent degenerate-eigenvalue mixing. Blocks with
    all-zero initial conditions are skipped entirely (their output stays 0).

    Parameters
    ----------
    A_modes : shape (n_modes, n_slots, n_slots)
        Per-mode evolution matrices.
    y0_hat : shape (n_slots, n_modes)
        Initial condition in Fourier space.
    t_eval
        Output times; t_eval[0] is taken as the initial time t₀.
    layout, grid
        State layout and spatial grid, used by ``_ifft_slots`` to transform
        each snapshot back to physical space.
    snapshot_callback, progress
        Optional hooks invoked per snapshot: callback(t, y_physical) and
        progress.update(t).
    return_fourier, return_derivative_fourier
        When True, additionally collect Fourier-space states / exact time
        derivatives at each snapshot (otherwise those outputs are None).
    B_modes : shape (n_modes, n_slots, n_slots), optional
        Left-hand-side matrices for the generalized problem B·y' = A·y,
        solved per mode via QZ (``scipy.linalg.eig``); infinite/huge
        eigenvalues are treated as gauge modes and zeroed.

    Returns
    -------
    tuple
        (times, snapshots, fourier_snaps, deriv_fourier_snaps); snapshots
        has shape (n_snapshots, n_slots·n_pts) in physical space.

    Ref: Golub & Van Loan (1996), Matrix Computations, §4.8.
    """
    n_slots = layout.num_slots
    n_pts = layout.num_points
    n_snapshots = len(t_eval)
    t0 = t_eval[0]
    dt_total = float(t_eval[-1] - t0)
    # Detect independent blocks from the first mode's matrix.
    # Block structure is k-independent for constant coefficients, so we only
    # need to analyze one representative mode (use max across a few modes for
    # robustness against accidental zeros at specific k).
    n_check = min(3, A_modes.shape[0])
    combined = np.max(np.abs(A_modes[:n_check]), axis=0)
    blocks = _find_independent_blocks(combined)
    # Pre-compute eigendecomposition for each active block
    block_data: list[
        tuple[
            list[int],  # slot indices
            NDArray[np.complex128],  # eigenvalues (n_modes, block_size)
            NDArray[np.complex128],  # V (n_modes, block_size, block_size)
            NDArray[np.complex128],  # y0_eigen (n_modes, block_size)
        ]
    ] = []
    for block_slots in blocks:
        # Extract IC for this block
        y0_block = y0_hat[block_slots, :]  # (block_size, n_modes)
        # Skip blocks with all-zero IC — output stays at zero
        if np.max(np.abs(y0_block)) < 1e-15:
            continue
        # Extract block sub-matrices: (n_modes, block_size, block_size)
        idx = np.array(block_slots)
        A_block = A_modes[:, idx[:, None], idx[None, :]]
        if B_modes is not None:
            # Generalized eigenvalue problem: B · d' = A · d
            # Uses QZ decomposition via scipy.linalg.eig(A, B).
            # Infinite eigenvalues (gauge DOF) are zeroed — they don't evolve.
            # Ref: Golub & Van Loan (2013), Matrix Computations §7.7.6
            import scipy.linalg as sla  # type: ignore[import-untyped] # noqa: PLC0415

            B_block = B_modes[:, idx[:, None], idx[None, :]]
            bs = len(block_slots)
            n_block_modes = A_block.shape[0]
            eig_vals = np.zeros((n_block_modes, bs), dtype=np.complex128)
            v_mat = np.zeros((n_block_modes, bs, bs), dtype=np.complex128)
            n_gauge_total = 0
            # QZ has no batched form — loop over modes individually.
            for m in range(A_block.shape[0]):
                eig_result = sla.eig(A_block[m], B_block[m], right=True)  # pyright: ignore[reportUnknownVariableType]
                ev_m = eig_result[0]  # pyright: ignore[reportUnknownVariableType]
                vr_m = eig_result[1]  # pyright: ignore[reportUnknownVariableType]
                # Filter infinite/very-large eigenvalues (gauge modes)
                gauge = ~np.isfinite(ev_m) | (np.abs(ev_m) > 1e12)  # pyright: ignore[reportUnknownArgumentType]
                ev_m[gauge] = 0.0  # gauge modes frozen at IC
                n_gauge_total += int(np.sum(gauge))
                eig_vals[m] = ev_m  # pyright: ignore[reportUnknownArgumentType]
                v_mat[m] = vr_m
            if n_gauge_total > 0:
                import logging as _log  # noqa: PLC0415

                _log.getLogger(__name__).info(
                    "Generalized eigenvalue: %d gauge modes zeroed across %d modes",
                    n_gauge_total,
                    A_block.shape[0],
                )
            v_inv = np.linalg.inv(v_mat)
        else:
            # Standard eigendecomposition (existing path)
            eig_vals, v_mat = np.linalg.eig(A_block)
            v_inv = np.linalg.inv(v_mat)
        # Warn about potential overflow
        _warn_eigenvalue_growth(eig_vals, dt_total, context="per-mode")
        # Transform IC to eigenbasis
        y0_eigen = np.einsum("mij,mj->mi", v_inv, y0_block.T)
        block_data.append((block_slots, eig_vals, v_mat, y0_eigen))
    # Evolve at each time point.
    # Pre-multiply V @ diag(y0_eigen) for each block so the inner loop only
    # needs element-wise exp + matrix-vector product, not a full einsum.
    block_evolved: list[
        tuple[
            list[int],  # slot indices
            NDArray[np.complex128],  # V_y0: V * y0_eigen, (n_modes, bs, bs)
            NDArray[np.complex128] | None,  # V_y0_deriv (n_modes, bs, bs) or None
            NDArray[np.complex128],  # eigenvalues (n_modes, bs)
        ]
    ] = []
    for block_slots, eig_vals, v_mat, y0_eigen in block_data:
        # V_y0[m, i, j] = v_mat[m, i, j] * y0_eigen[m, j]
        # so y(t) = V_y0 @ exp(λ*dt) is just a matvec
        V_y0 = v_mat * y0_eigen[:, np.newaxis, :]  # (n_modes, bs, bs)
        # V_y0_deriv[m, i, j] = V_y0[m, i, j] * λ[m, j]
        # so y'(t) = V_y0_deriv @ exp(λ*dt) gives exact time derivative
        V_y0_deriv = (
            V_y0 * eig_vals[:, np.newaxis, :] if return_derivative_fourier else None
        )
        block_evolved.append((block_slots, V_y0, V_y0_deriv, eig_vals))
    snapshots = np.zeros((n_snapshots, n_slots * n_pts))
    times = np.zeros(n_snapshots)
    n_modes = y0_hat.shape[1]
    # Optionally collect Fourier-space snapshots (avoids re-FFT in constraint
    # recovery — the Fourier data is already computed here).
    fourier_snaps: NDArray[np.complex128] | None = None
    if return_fourier:
        fourier_snaps = np.zeros(
            (n_snapshots, n_slots, n_modes),
            dtype=np.complex128,
        )
    # Optionally collect Fourier-space TIME DERIVATIVE snapshots.
    # d'(t) = V · diag(λ · exp(λt)) · y0_eigen — exact, no numerical diff.
    # Used for machine-precision constraint velocity: v_c = recovery · d'.
    deriv_fourier_snaps: NDArray[np.complex128] | None = None
    dy_hat_t: NDArray[np.complex128] | None = None
    if return_derivative_fourier:
        deriv_fourier_snaps = np.zeros(
            (n_snapshots, n_slots, n_modes),
            dtype=np.complex128,
        )
        dy_hat_t = np.zeros((n_slots, n_modes), dtype=np.complex128)
    # Pre-allocate buffer reused each timestep (avoids n_snapshots allocations)
    y_hat_t = np.zeros((n_slots, n_modes), dtype=np.complex128)
    for ti, t in enumerate(t_eval):
        dt = t - t0
        # Zero the reused buffers — skipped (inactive) blocks must stay 0.
        y_hat_t[:] = 0.0
        if dy_hat_t is not None:
            dy_hat_t[:] = 0.0
        for block_slots, V_y0, V_y0_deriv, eig_vals in block_evolved:
            # exp_lambda shape: (n_modes, block_size)
            exp_lambda = np.exp(eig_vals * dt)
            # y_evolved[m, i] = Σ_j V_y0[m, i, j] * exp(λ_j * dt)
            y_evolved = np.einsum("mij,mj->mi", V_y0, exp_lambda)
            y_hat_t[block_slots, :] = y_evolved.T
            # Exact time derivative: dy[m,i] = Σ_j V_y0_deriv[m,i,j] * exp(λ_j*dt)
            if V_y0_deriv is not None and dy_hat_t is not None:
                dy_evolved = np.einsum("mij,mj->mi", V_y0_deriv, exp_lambda)
                dy_hat_t[block_slots, :] = dy_evolved.T
        if fourier_snaps is not None:
            fourier_snaps[ti] = y_hat_t
        if deriv_fourier_snaps is not None and dy_hat_t is not None:
            deriv_fourier_snaps[ti] = dy_hat_t
        # Inverse FFT back to the physical grid for the snapshot record.
        y_physical = _ifft_slots(y_hat_t, layout, grid)
        snapshots[ti] = y_physical
        times[ti] = t
        if snapshot_callback is not None:
            snapshot_callback(t, y_physical)
        if progress is not None:
            progress.update(t)
    return times, snapshots, fourier_snaps, deriv_fourier_snaps
def _evolve_full_matrix(
    A_full: NDArray[np.complex128],
    y0_hat: NDArray[np.complex128],
    t_eval: NDArray[np.float64],
    layout: StateLayout,
    grid: GridInfo,
    snapshot_callback: Callable[[float, NDArray[np.float64]], None] | None,
    progress: SimulationProgress | None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Evolve system with full coupled matrix (position-dependent coefficients).

    A_full has shape (n_total, n_total) where n_total = n_slots x n_modes.
    y0_hat has shape (n_slots, n_modes).

    Uses ``scipy.sparse.linalg.expm_multiply`` to compute exp(A·t)·y₀ at each
    output time without eigendecomposition. This is backward-stable for
    non-normal matrices (position-dependent gradient coupling creates non-normal
    convolution matrices whose eigenvalues have large real parts, but the true
    dynamics are bounded). The algorithm uses scaling + truncated Taylor series
    in matrix-vector products, avoiding individual exp(λ·t) overflow.

    **Why not eigendecomposition?** The original full-matrix eigendecomposition
    gave incorrect physics for localized Gertsenshtein (P=0.477 vs correct
    P=0.3437) because non-normal convolution matrices have eigenvalues with
    significant positive real parts despite conservative physics — individual
    exp(λ·t) overflow while exp(A·t)·y₀ is bounded (pseudospectral phenomenon;
    Trefethen & Embree 2005, Ch. 14).

    **Sparse optimization:** For localized coefficients (e.g. Gaussian B₀), the
    convolution kernel ĉ(q) decays exponentially, making the matrix effectively
    banded. Entries below a relative threshold (1e-14 x max|A|) are zeroed, and
    if density < 30% the matrix is converted to sparse CSC format. This
    accelerates expm_multiply's internal matrix-vector products.

    **Time-grid fix:** ``expm_multiply``'s multi-point mode evaluates at
    ``np.linspace(start, stop, num)``, so it is only valid when ``t_eval`` is
    uniformly spaced. The fast path is now guarded by an explicit uniformity
    check; non-uniform grids fall back to one ``expm_multiply`` call per output
    time (previously their snapshots would have been silently attributed to
    the wrong times).

    Ref: Al-Mohy & Higham (2011), "Computing the Action of the Matrix
    Exponential", SIAM J. Sci. Comput. 33(2):488-511.
    """
    import scipy.sparse  # noqa: PLC0415 # pyright: ignore[reportMissingTypeStubs]
    from scipy.sparse.linalg import (  # noqa: PLC0415 # pyright: ignore[reportMissingTypeStubs]
        expm_multiply,  # pyright: ignore[reportUnknownVariableType]
    )

    n_slots = layout.num_slots
    n_pts = layout.num_points
    n_modes = y0_hat.shape[1]
    n_snapshots = len(t_eval)
    # Flatten y0_hat to (n_total,) — slot-major order
    y0_flat = y0_hat.ravel()

    # --- Sparse matrix optimization ---
    # Position-dependent convolution matrices are effectively banded: for
    # Gaussian B₀(x) the kernel ĉ(q) decays exponentially, so most off-
    # diagonal entries are negligibly small. Thresholding and converting to
    # sparse CSC format accelerates expm_multiply's internal matrix-vector
    # products (the dominant cost) without affecting accuracy.
    A_op: NDArray[np.complex128] | scipy.sparse.csc_array
    abs_max = float(np.max(np.abs(A_full)))
    if abs_max > 0:
        A_work = A_full.copy()
        A_work[np.abs(A_work) < abs_max * 1e-14] = 0.0
        density = np.count_nonzero(A_work) / A_work.size
        A_op = scipy.sparse.csc_array(A_work) if density < 0.3 else A_work
    else:
        A_op = A_full  # all-zero matrix: nothing to threshold or sparsify

    snapshots = np.zeros((n_snapshots, n_slots * n_pts))
    times = np.zeros(n_snapshots)

    def _record(ti: int, t: float, y_flat: NDArray[np.complex128]) -> None:
        # Inverse-FFT one Fourier-space snapshot to physical space, store it,
        # and fire the callback/progress hooks.
        y_physical = _ifft_slots(y_flat.reshape(n_slots, n_modes), layout, grid)
        snapshots[ti] = y_physical
        times[ti] = t
        if snapshot_callback is not None:
            snapshot_callback(t, y_physical)
        if progress is not None:
            progress.update(t)

    t0 = float(t_eval[0])
    t_end = float(t_eval[-1])
    # The fast path requires (a) more than one output time, (b) forward
    # evolution, and (c) uniform spacing — expm_multiply internally evaluates
    # at np.linspace(start, stop, num), so mismatched spacing would mislabel
    # every snapshot.
    uniform = (
        n_snapshots > 1
        and t_end > t0
        and bool(np.allclose(np.diff(t_eval), (t_end - t0) / (n_snapshots - 1)))
    )
    if uniform:
        # Single expm_multiply call for all output times (Krylov/Taylor with
        # internal scaling — no manual dt stepping).
        y_all: NDArray[np.complex128] = np.asarray(
            expm_multiply(A_op, y0_flat, start=t0, stop=t_end, num=n_snapshots),
            dtype=np.complex128,
        )
        # y_all has shape (n_snapshots, n_total)
        for ti in range(n_snapshots):
            _record(ti, float(t_eval[ti]), y_all[ti])
    else:
        # Single time point, t0 == t_end, or non-uniform spacing: evolve to
        # each requested time independently from the initial condition.
        for ti, t in enumerate(t_eval):
            if t == t0:
                y_evolved = y0_flat.copy()
            else:
                y_evolved = np.asarray(
                    expm_multiply(A_op, y0_flat, start=t0, stop=float(t), num=2)[-1],
                    dtype=np.complex128,
                )
            _record(ti, float(t), y_evolved)
    return times, snapshots
# ---------------------------------------------------------------------------
# Public solver entry point
# ---------------------------------------------------------------------------
# [docs] — Sphinx "view source" link artifact from the HTML export; commented
# out because a bare `[docs]` expression would raise NameError at import time.
def solve_modal(
    spec: EquationSystem,
    grid: GridInfo,
    y0: np.ndarray,
    t_span: tuple[float, float],
    *,
    bc: BCSpec | None = None,
    parameters: dict[str, float] | None = None,
    rtol: float = DEFAULT_RTOL,
    atol: float = DEFAULT_ATOL,
    num_snapshots: int = 101,
    snapshot_callback: Callable[[float, np.ndarray], None] | None = None,
    progress: SimulationProgress | None = None,
) -> SolverResult:
    """Solve a TIDAL equation system using Fourier modal decomposition.

    Transforms the spatial grid to Fourier space, builds the per-mode or
    full evolution matrix, eigendecomposes, and evaluates the exact solution
    at each output time.

    Parameters
    ----------
    spec : EquationSystem
        Parsed equation specification (from JSON).
    grid : GridInfo
        Spatial grid (must be all-periodic).
    y0 : np.ndarray
        Initial state vector (flat).
    t_span : tuple[float, float]
        (t_start, t_end).
    bc : str or tuple, optional
        Boundary conditions (must be all-periodic). Not read by this
        function — accepted for interface parity with the other solvers.
    parameters : dict[str, float], optional
        Runtime parameter overrides for symbolic coefficients.
    rtol, atol : float
        Tolerances (unused for eigendecomposition; reserved for solve_ivp
        fallback with time-dependent coefficients).
    num_snapshots : int
        Number of output time points (uniformly spaced over ``t_span``).
    snapshot_callback : callable, optional
        Called as ``callback(t, y)`` at each output time.
    progress : SimulationProgress, optional
        Progress tracker for tqdm display.

    Returns
    -------
    SolverResult
        Dict with keys: ``t``, ``y``, ``success``, ``message``; plus
        ``constraint_velocities`` when the Schur-elimination paths produced
        exact constraint-field time derivatives.

    Notes
    -----
    Dispatches on the structure of ``spec``:

    - time-derivative operators, constant coefficients → generalized
      mass-matrix Schur elimination;
    - algebraic constraints, constant coefficients → Schur constraint
      elimination;
    - constant coefficients, no constraints → per-mode eigendecomposition;
    - position-dependent coefficients → full convolution matrix evolved
      via ``expm_multiply``.
    """
    from tidal.solver.coefficients import CoefficientEvaluator  # noqa: PLC0415
    layout = StateLayout.from_spec(spec, grid.num_points)
    coeff_eval = CoefficientEvaluator(spec, grid, parameters or {})
    # Detect constraint fields: an equation with time_derivative_order == 0
    # is algebraic (no time evolution of its own).
    has_constraints = any(eq.time_derivative_order == 0 for eq in spec.equations)
    if not has_constraints:
        # NOTE(review): this warns when NO algebraic equations exist —
        # presumably it alerts that any constraint-like slots in `layout`
        # would stay frozen; confirm against warn_frozen_constraints' contract.
        warn_frozen_constraints(layout, "modal")
    # Build time evaluation points — uniformly spaced, as required by the
    # expm_multiply fast path in _evolve_full_matrix.
    t_eval = np.linspace(t_span[0], t_span[1], num_snapshots)
    # Build k-grid (Fourier wavenumbers per axis, then the full mesh)
    k_axes = _build_k_axes(grid)
    k_grid = _build_k_grid(k_axes)
    # Compute rfft output shape: real FFT halves the last axis to n//2 + 1
    rfft_shape_list = list(grid.shape)
    rfft_shape_list[-1] = grid.shape[-1] // 2 + 1
    rfft_shape = tuple(rfft_shape_list)
    # FFT initial conditions — y0_hat indexed as (slot, flat mode)
    y0_hat = _fft_slots(y0, layout, grid)
    # Zero the Nyquist mode(s) in the IC. The rfft Nyquist bin (last mode
    # in each dimension) must be real for real-valued fields. The modal
    # evolution matrix has complex entries (from gradient coupling ik),
    # which creates imaginary components at the Nyquist bin. irfft
    # silently drops these, causing energy non-conservation proportional
    # to the Nyquist power. Zeroing the Nyquist IC prevents this entirely.
    # This is standard practice in spectral methods — the Nyquist mode
    # aliases with its conjugate and cannot represent physical content.
    # Ref: Boyd (2001), Chebyshev & Fourier Spectral Methods, §11.5.
    for _dim_idx, n in enumerate(grid.shape):
        if n % 2 == 0:  # Nyquist mode exists only for even N
            nyq_mode = n // 2  # last rfft bin
            if len(grid.shape) == 1:
                y0_hat[:, nyq_mode] = 0.0
            else:
                # Multi-D: zero along the last-axis Nyquist slice
                # NOTE(review): in this branch `_dim_idx`/`nyq_mode` are
                # unused and the same last-axis slice is rewritten on every
                # loop iteration; also, if y0_hat is 2-D (slot, flat mode)
                # as elsewhere in this function, `[:, ..., rfft_last]`
                # reduces to `[:, rfft_last]` — a single flat mode, not the
                # full Nyquist plane. Confirm the multi-D case against a
                # 2-D energy-conservation test.
                rfft_last = grid.shape[-1] // 2
                y0_hat[:, ..., rfft_last] = 0.0
    has_pos_dep = _has_position_dependent_terms(spec)
    has_time_ops = _has_time_derivative_operators(spec)
    # Determine which matrix builder to use (see Notes in the docstring)
    use_generalized = has_time_ops and not has_pos_dep
    use_constraint = has_constraints and not has_pos_dep and not use_generalized
    B_lhs_modes: NDArray[np.complex128] | None = None  # set by generalized path
    constraint_vel_arrays: dict[
        str, NDArray[np.float64]
    ] = {}  # populated by Schur path
    if use_generalized or use_constraint:
        # Both paths produce: A_reduced, recovery, constraint names, slot mapping
        if use_generalized:
            # Generalized mass-matrix system: M·ẍ = K·x + D·ẋ + J·x⃛
            # Returns A_rhs and optional B_lhs for generalized eigenvalue.
            # Ref: Golub & Van Loan (2013), Matrix Computations §7.7
            (
                A_reduced,
                B_lhs_modes,  # CRITICAL: was _B_lhs_modes (discarded) — #177
                recovery_matrix,
                _v_recovery_gen,  # unused — constraint vel from eigendata
                c_names,
                orig_to_reduced,
            ) = _build_generalized_evolution_matrices(
                spec,
                layout,
                grid,
                coeff_eval,
                k_grid,
                rfft_shape,
            )
        else:
            # Constraint elimination via Fourier Schur complement
            # Ref: Hairer & Wanner (1996), Solving ODEs II, Ch. VII
            (
                A_reduced,
                recovery_matrix,
                _v_recovery_matrix,  # recovery @ A_reduced (exact constraint velocities)
                c_names,
                orig_to_reduced,
            ) = _build_constraint_eliminated_matrices(
                spec,
                layout,
                grid,
                coeff_eval,
                k_grid,
                rfft_shape,
            )
        n_dyn = A_reduced.shape[1]
        n_modes = y0_hat.shape[1]
        n_pts = layout.num_points
        # Extract dynamical IC in reduced ordering (orig slot → reduced row)
        y0_hat_dyn = np.zeros((n_dyn, n_modes), dtype=np.complex128)
        for orig_si, red_pos in orig_to_reduced.items():
            y0_hat_dyn[red_pos] = y0_hat[orig_si]
        # Build a reduced StateLayout for eigendecomposition — same slots,
        # minus the eliminated constraint fields.
        sorted_orig = sorted(orig_to_reduced.keys())
        red_slots = tuple(layout.slots[si] for si in sorted_orig)
        red_field_map: dict[str, int] = {}
        red_vel_map: dict[str, int] = {}
        for new_i, si in enumerate(sorted_orig):
            s = layout.slots[si]
            if s.kind == "field":
                red_field_map[s.field_name] = new_i
            elif s.kind == "velocity":
                red_vel_map[s.field_name] = new_i
        dyn_layout = StateLayout(
            slots=red_slots,
            num_points=n_pts,
            field_slot_map=red_field_map,
            velocity_slot_map=red_vel_map,
            dynamical_fields=layout.dynamical_fields,
        )
        # Evolve dynamical fields (return Fourier + derivative data)
        times, dyn_snapshots, dyn_fourier, dyn_deriv_fourier = _evolve_per_mode(
            A_reduced,
            y0_hat_dyn,
            t_eval,
            dyn_layout,
            grid,
            None,
            progress,  # callback handled below with full state
            return_fourier=True,
            return_derivative_fourier=True,  # for exact constraint velocities
            B_modes=B_lhs_modes,  # generalized eigenvalue if vel coupling
        )
        # Reconstruct full state (including constraints) at each snapshot
        n_full = layout.num_slots * n_pts
        snapshots = np.zeros((len(t_eval), n_full))
        assert dyn_fourier is not None  # guaranteed by return_fourier=True
        # Populate constraint velocity arrays: exact ∂_t(c) from eigendata.
        # "Constraint" is a solver concept (algebraic evolution), not a physics
        # statement — these fields have physically meaningful velocities.
        # v_c(t) = recovery · d'(t), where d'(t) is computed from eigendata
        # inside _evolve_per_mode (V·diag(λ·exp(λt))·y0_eigen — exact).
        for c_name in c_names:
            constraint_vel_arrays[c_name] = np.zeros((len(t_eval), *grid.shape))
        for ti in range(len(t_eval)):
            dyn_phys = dyn_snapshots[ti]
            # Use Fourier data directly (already computed in _evolve_per_mode)
            y_hat_dyn_t = dyn_fourier[ti]  # (n_dyn, n_modes)
            # Recover constraint fields: c_hat = recovery @ d_hat,
            # applied per mode m (recovery_matrix axes: mode, constraint, dyn)
            c_hat = np.einsum("mcj,jm->cm", recovery_matrix, y_hat_dyn_t)
            # Recover constraint velocities: v_c_hat = recovery @ d'_hat
            # d'_hat comes from eigendata — exact, no numerical differentiation.
            assert dyn_deriv_fourier is not None
            dy_hat_dyn_t = dyn_deriv_fourier[ti]  # (n_dyn, n_modes)
            v_c_hat = np.einsum("mcj,jm->cm", recovery_matrix, dy_hat_dyn_t)
            # Assemble full physical state: dynamical slots back into their
            # original slot positions, constraint slots from the recovery.
            full_state = np.zeros(n_full)
            for orig_si, red_pos in orig_to_reduced.items():
                full_state[orig_si * n_pts : (orig_si + 1) * n_pts] = dyn_phys[
                    red_pos * n_pts : (red_pos + 1) * n_pts
                ]
            for ci, c_name in enumerate(c_names):
                c_slot = layout.field_slot_map[c_name]
                c_phys = np.fft.irfftn(
                    c_hat[ci].reshape(rfft_shape),
                    s=grid.shape,
                    axes=list(range(len(grid.shape))),
                ).ravel()
                full_state[c_slot * n_pts : (c_slot + 1) * n_pts] = np.real(
                    c_phys,
                )
                # Store exact constraint velocity (from eigendata d')
                v_c_phys = np.fft.irfftn(
                    v_c_hat[ci].reshape(rfft_shape),
                    s=grid.shape,
                    axes=list(range(len(grid.shape))),
                )
                constraint_vel_arrays[c_name][ti] = np.real(v_c_phys)
            snapshots[ti] = full_state
            if snapshot_callback is not None:
                snapshot_callback(t_eval[ti], full_state)
        n_c = len(c_names)
        if use_generalized:
            method_desc = (
                f"per-mode eigendecomposition with generalized Schur elimination "
                f"({n_c} constraints, {n_dyn} dynamical slots, mass-matrix)"
            )
        else:
            method_desc = (
                f"per-mode eigendecomposition with Schur constraint elimination "
                f"({n_c} constraints, {n_dyn} dynamical slots)"
            )
    elif not has_pos_dep:
        # All-constant coefficients: per-mode independent evolution
        A_modes = _build_per_mode_matrices(
            spec,
            layout,
            grid,
            coeff_eval,
            k_grid,
            rfft_shape,
        )
        times, snapshots, _, _ = _evolve_per_mode(
            A_modes,
            y0_hat,
            t_eval,
            layout,
            grid,
            snapshot_callback,
            progress,
        )
        method_desc = "per-mode eigendecomposition (constant coefficients)"
    else:
        # Position-dependent coefficients: full convolution matrix
        A_full = _build_convolution_matrix(
            spec,
            layout,
            grid,
            coeff_eval,
            k_grid,
            rfft_shape,
        )
        times, snapshots = _evolve_full_matrix(
            A_full,
            y0_hat,
            t_eval,
            layout,
            grid,
            snapshot_callback,
            progress,
        )
        n_total = A_full.shape[0]
        method_desc = f"expm_multiply ({n_total}x{n_total}, position-dependent)"
    if progress is not None:
        progress.finish()
    result: SolverResult = {
        "t": times,
        "y": snapshots,
        "success": True,
        "message": f"Modal solver completed ({method_desc})",
    }
    # Attach constraint velocity arrays (exact ∂_t for constraint fields).
    # These are computed from v_recovery = recovery @ A_reduced inside the
    # Schur elimination path. For generalized eigenvalue or non-constraint
    # systems, constraint_vel_arrays is empty.
    if constraint_vel_arrays:
        result["constraint_velocities"] = constraint_vel_arrays  # type: ignore[typeddict-unknown-key]
    return result