# Source code for tidal.solver.modal
# (Sphinx "source code for" header retained as a comment so the module
# remains valid Python.)

"""Fourier modal solver — exact spectral time evolution for linear PDEs.

Transforms the spatial grid to Fourier space, builds a per-mode evolution
matrix, and eigendecomposes to obtain the exact solution y(t) = exp(A·t)·y₀.
Eliminates spatial discretization error entirely and provides machine-precision
solutions for time-independent linear systems.

Applicable to any linear PDE system with:
- Flat (Minkowski) metric
- All-periodic boundary conditions
- Time-independent coefficients (position-dependent OK via convolution)
- Operators with known exact Fourier multipliers

Two algorithm paths are used depending on coefficient structure:
- Constant coefficients: per-mode eigendecomposition with block-aware independent
  blocks (machine-precision, ~14x faster).
- Position-dependent coefficients: Krylov matrix exponential (expm_multiply) which
  is backward-stable for non-normal convolution matrices where eigendecomposition
  gives incorrect results due to pseudospectral overflow.

References
----------
    Moler & Van Loan (2003), SIAM Review 45(1):3-49 — matrix exponential.
    Hairer, Lubich & Wanner (2006), Geometric Numerical Integration, §4.
    Burns et al. (2020), Phys. Rev. Research 2:023068 — pseudo-spectral.
    Raffelt & Stodolsky (1988), PRD 37:1237 — mixing-matrix formalism.
"""

# ruff: noqa: N803, N806 — uppercase names for matrices (A, V, T, Z) follow
#   standard linear-algebra notation.
# ruff: noqa: PLR0913, PLR0917, PLR0914, PLR0912, PLR0911, PLR0915, PLR2004
#   — numerical code inherently requires many arguments, local variables,
#   return statements, statements, and literal comparisons.
# ruff: noqa: C901 — complexity and Unicode math symbols.
# ruff: noqa: ERA001, ARG001 — commented-out code serves as documentation;
#   unused args (bc, grid) kept for interface consistency with other solvers.
# ruff: noqa: B903, PLR1702 — _OperatorDecomp uses __slots__ for memory efficiency;
#   nested block depth is inherent to multi-field modal algebra.

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any

import numpy as np
from numpy.typing import NDArray

from tidal.solver._defaults import DEFAULT_ATOL, DEFAULT_RTOL
from tidal.solver._setup import warn_frozen_constraints
from tidal.solver.operators import get_wavenumbers, is_periodic_bc
from tidal.solver.state import StateLayout

if TYPE_CHECKING:
    from tidal.solver._types import SolverResult
    from tidal.solver.coefficients import CoefficientEvaluator
    from tidal.solver.grid import GridInfo
    from tidal.solver.operators import BCSpec
    from tidal.solver.progress import SimulationProgress
    from tidal.symbolic.json_loader import (
        ComponentEquation,
        EquationSystem,
        OperatorTerm,
    )

# ---------------------------------------------------------------------------
# Exact Fourier multipliers (angular wavenumber convention: k = 2π·rfftfreq)
# ---------------------------------------------------------------------------
# These use the EXACT wavenumber k, consistent with operators.py spectral
# mode (gradient → ik, laplacian → -k²).  NOT the modified-wavenumber
# convention from constraint_solve.py which matches FD stencils.

_ExactMultFn = Callable[[list[NDArray[np.float64]]], NDArray[Any] | int]

# ---------------------------------------------------------------------------
# Operator decomposition: (spatial Fourier multiplier, time derivative order)
# ---------------------------------------------------------------------------
# Every operator in flat spacetime decomposes as spatial_multiplier(k) x ∂ⁿ_t.
# Derivatives commute in Minkowski: ∂²_t ∂_x = ∂_x ∂²_t.
#
# Time order classification:
#   0 = position operator   → stiffness matrix K
#   1 = velocity operator   → damping matrix D
#   2 = acceleration operator → mass matrix M (off-diagonal, implicit coupling)
#   3 = jerk operator       → eliminated via EOM substitution
#
# References:
#   Golub & Van Loan (2013), Matrix Computations, 4th ed. §7.7
#   Hairer & Lubich (2003), ZAMM 83(1) — mass matrices in structural dynamics


class _OperatorDecomp:
    """Operator decomposition into spatial multiplier and time derivative order."""

    __slots__ = ("spatial_fn", "time_order")

    def __init__(self, spatial_fn: _ExactMultFn, time_order: int) -> None:
        self.spatial_fn = spatial_fn
        self.time_order = time_order


# Lookup table: operator name → (spatial Fourier multiplier, time order).
# Multipliers follow the exact-wavenumber convention stated above
# (gradient → ik, laplacian → -k²); k_axes holds one broadcastable
# wavenumber array per spatial axis.
_OPERATOR_DECOMP: dict[str, _OperatorDecomp] = {
    # --- Pure spatial operators (time_order=0) ---
    "identity": _OperatorDecomp(lambda k_axes: np.ones_like(k_axes[0]), 0),
    # ∇² → -|k|² (sum over all axes)
    "laplacian": _OperatorDecomp(
        lambda k_axes: -sum(ki**2 for ki in k_axes),
        0,
    ),
    # Single-axis second derivatives: ∂²_i → -k_i²
    "laplacian_x": _OperatorDecomp(lambda k_axes: -(k_axes[0] ** 2), 0),
    "laplacian_y": _OperatorDecomp(lambda k_axes: -(k_axes[1] ** 2), 0),
    "laplacian_z": _OperatorDecomp(lambda k_axes: -(k_axes[2] ** 2), 0),
    # First derivatives: ∂_i → ik_i
    "gradient_x": _OperatorDecomp(lambda k_axes: 1j * k_axes[0], 0),
    "gradient_y": _OperatorDecomp(lambda k_axes: 1j * k_axes[1], 0),
    "gradient_z": _OperatorDecomp(lambda k_axes: 1j * k_axes[2], 0),
    # Mixed second derivatives: ∂_i∂_j → (ik_i)(ik_j) = -k_i·k_j
    "cross_derivative_xy": _OperatorDecomp(lambda k_axes: -(k_axes[0] * k_axes[1]), 0),
    "cross_derivative_xz": _OperatorDecomp(lambda k_axes: -(k_axes[0] * k_axes[2]), 0),
    "cross_derivative_yz": _OperatorDecomp(lambda k_axes: -(k_axes[1] * k_axes[2]), 0),
    # (∇²)² → +|k|⁴
    "biharmonic": _OperatorDecomp(
        lambda k_axes: sum(ki**2 for ki in k_axes) ** 2,
        0,
    ),
    # Third derivatives: ∂³_i → (ik_i)³ = -ik_i³
    "derivative_3_x": _OperatorDecomp(lambda k_axes: -1j * k_axes[0] ** 3, 0),
    "derivative_3_y": _OperatorDecomp(lambda k_axes: -1j * k_axes[1] ** 3, 0),
    "derivative_3_z": _OperatorDecomp(lambda k_axes: -1j * k_axes[2] ** 3, 0),
    # --- Velocity operators (time_order=1) ---
    # ∂_t alone (unit spatial part) and ∂_t∂_i → ik_i · ∂_t
    "first_derivative_t": _OperatorDecomp(lambda k_axes: np.ones_like(k_axes[0]), 1),
    "mixed_T1_S1x": _OperatorDecomp(lambda k_axes: 1j * k_axes[0], 1),
    "mixed_T1_S1y": _OperatorDecomp(lambda k_axes: 1j * k_axes[1], 1),
    "mixed_T1_S1z": _OperatorDecomp(lambda k_axes: 1j * k_axes[2], 1),
    # --- Acceleration operators (time_order=2) ---
    # ∂²_t with spatial parts 1, ik_i, or -k_i² respectively
    "d2_t": _OperatorDecomp(lambda k_axes: np.ones_like(k_axes[0]), 2),
    "mixed_T2_S1x": _OperatorDecomp(lambda k_axes: 1j * k_axes[0], 2),
    "mixed_T2_S1y": _OperatorDecomp(lambda k_axes: 1j * k_axes[1], 2),
    "mixed_T2_S1z": _OperatorDecomp(lambda k_axes: 1j * k_axes[2], 2),
    "mixed_T2_S2x": _OperatorDecomp(lambda k_axes: -(k_axes[0] ** 2), 2),
    "mixed_T2_S2y": _OperatorDecomp(lambda k_axes: -(k_axes[1] ** 2), 2),
    "mixed_T2_S2z": _OperatorDecomp(lambda k_axes: -(k_axes[2] ** 2), 2),
    # --- Jerk operators (time_order=3, eliminated via EOM substitution) ---
    "d3_t": _OperatorDecomp(lambda k_axes: np.ones_like(k_axes[0]), 3),
    "mixed_T3_S1x": _OperatorDecomp(lambda k_axes: 1j * k_axes[0], 3),
    "mixed_T3_S1y": _OperatorDecomp(lambda k_axes: 1j * k_axes[1], 3),
    "mixed_T3_S1z": _OperatorDecomp(lambda k_axes: 1j * k_axes[2], 3),
}

# Backward-compatible mapping: operator name → spatial multiplier function.
# Used by existing code paths that only need the spatial part.
_EXACT_MULTIPLIERS: dict[str, _ExactMultFn] = {
    name: dec.spatial_fn for name, dec in _OPERATOR_DECOMP.items()
}


# ---------------------------------------------------------------------------
# Eligibility check
# ---------------------------------------------------------------------------


def can_use_modal(
    spec: EquationSystem,
    grid: GridInfo,
    bc: BCSpec | None,
) -> bool:
    """Check whether the modal solver is applicable to this system.

    Requirements (checked in order):
    1. Flat metric (volume_element is None)
    2. Constraints (time_order=0) must be Fourier-eliminable via Schur complement
    3. All boundary conditions periodic
    4. All RHS operators have exact Fourier multipliers
    5. No time-dependent coefficients
    """
    # 1. Flat metric — curved metrics have non-None volume_element
    if spec.canonical is not None and spec.canonical.volume_element is not None:
        return False
    # Also reject if canonical is None but any term is position-dependent
    # with non-Cartesian coordinate names (heuristic for curved metrics
    # without canonical structure)
    if spec.canonical is None:
        for eq in spec.equations:
            for term in eq.rhs_terms:
                if term.position_dependent:
                    coords = set(term.coordinate_dependent)
                    # Non-Cartesian coordinate references suggest curved metric
                    cartesian = {"x", "y", "z"}
                    if coords - cartesian:
                        return False
    # 2. Constraints — allow if Fourier-eliminable via Schur complement
    constraint_eqs = [eq for eq in spec.equations if eq.time_derivative_order == 0]
    if constraint_eqs and not _constraints_fourier_eliminable(spec, constraint_eqs):
        return False
    # 3. All-periodic BCs (both the grid's own flags and any explicit bc spec)
    if not all(grid.periodic):
        return False
    if bc is not None:
        if isinstance(bc, str):
            if not is_periodic_bc(bc):
                return False
        else:
            for b in bc:
                if not is_periodic_bc(b):
                    return False
    # 4. All operators supported (spatial or time-derivative decomposable)
    for eq in spec.equations:
        for term in eq.rhs_terms:
            if term.operator not in _OPERATOR_DECOMP:
                return False
    # 5. No time-dependent coefficients
    for eq in spec.equations:
        for term in eq.rhs_terms:
            if term.time_dependent:
                return False
    return True
# --------------------------------------------------------------------------- # FFT state transforms # --------------------------------------------------------------------------- def _fft_slots( y: NDArray[np.float64], layout: StateLayout, grid: GridInfo, ) -> NDArray[np.complex128]: """Transform each slot from physical space to Fourier space (rfft). Returns a complex array of shape (n_slots, n_modes) where n_modes is the rfft output length for the 1D case (n//2+1), or the product of rfft output lengths for multi-D. """ n_slots = layout.num_slots n_pts = layout.num_points shape = grid.shape # For 1D: rfft output length is shape[0]//2 + 1 # For nD: rfftn produces shape[:-1] + (shape[-1]//2+1,) # Compute analytically instead of probing with a zero FFT. rfft_shape = list(shape) rfft_shape[-1] = shape[-1] // 2 + 1 n_modes = int(np.prod(rfft_shape)) y_hat = np.zeros((n_slots, n_modes), dtype=np.complex128) for slot_idx in range(n_slots): start = slot_idx * n_pts end = start + n_pts field_data = y[start:end].reshape(shape) y_hat[slot_idx] = np.fft.rfftn(field_data).ravel() return y_hat def _ifft_slots( y_hat: NDArray[np.complex128], layout: StateLayout, grid: GridInfo, ) -> NDArray[np.float64]: """Transform each slot from Fourier space back to physical space (irfft). Returns a real flat array of shape (n_slots * n_pts,). 
""" n_slots = layout.num_slots n_pts = layout.num_points shape = grid.shape # Determine rfftn output shape for irfftn reconstruction rfft_shape = list(shape) rfft_shape[-1] = shape[-1] // 2 + 1 rfft_shape_tuple = tuple(rfft_shape) y_out = np.zeros(n_slots * n_pts) for slot_idx in range(n_slots): hat_data = y_hat[slot_idx].reshape(rfft_shape_tuple) physical = np.fft.irfftn(hat_data, s=shape, axes=list(range(len(shape)))) y_out[slot_idx * n_pts : (slot_idx + 1) * n_pts] = physical.ravel() return y_out # --------------------------------------------------------------------------- # Wavenumber grid construction # --------------------------------------------------------------------------- def _build_k_axes(grid: GridInfo) -> list[NDArray[np.float64]]: """Build wavenumber arrays for each spatial axis. Uses the same convention as operators.get_wavenumbers: k = 2π · rfftfreq(N, d=dx) for the last axis (rfft), k = 2π · fftfreq(N, d=dx) for all other axes (full fft). """ k_axes: list[NDArray[np.float64]] = [] ndim = grid.ndim for ax in range(ndim): n = grid.shape[ax] dx = grid.dx[ax] if ax == ndim - 1: # Last axis uses rfft (half-complex) k = get_wavenumbers(n, dx) else: # Other axes use full fft k = np.asarray(2.0 * np.pi * np.fft.fftfreq(n, d=dx), dtype=np.float64) k_axes.append(k) return k_axes def _build_k_grid( k_axes: list[NDArray[np.float64]], ) -> list[NDArray[np.float64]]: """Build broadcasted k-grid arrays from per-axis wavenumbers. Returns a list of arrays, one per axis, each broadcastable to the full rfft output shape. """ ndim = len(k_axes) k_grid: list[NDArray[np.float64]] = [] for ax in range(ndim): shape = [1] * ndim shape[ax] = len(k_axes[ax]) k_grid.append(k_axes[ax].reshape(shape)) return k_grid # --------------------------------------------------------------------------- # Constraint elimination (Fourier Schur complement) # --------------------------------------------------------------------------- # Ref: Hairer & Wanner (1996), Solving ODEs II, Ch. 
VII — DAE reduction. # Ref: Ascher & Petzold (1998), Computer Methods for ODEs/DAEs, §10.2. def _constraints_fourier_eliminable( spec: EquationSystem, constraint_eqs: Sequence[ComponentEquation], ) -> bool: """Check if all constraint equations can be eliminated in Fourier space. Requirements: - Each constraint operator must be decomposable (spatial x time) - No time-dependent coefficients in constraints Constraints may contain acceleration operators (mixed_T2_S1x, d2_t) which are handled by substituting the dynamical equations of motion before Schur elimination. """ for eq in constraint_eqs: for term in eq.rhs_terms: if term.operator not in _OPERATOR_DECOMP: return False if term.time_dependent: return False return True def _build_constraint_eliminated_matrices( spec: EquationSystem, layout: StateLayout, grid: GridInfo, coeff_eval: object, # CoefficientEvaluator k_grid: list[NDArray[np.float64]], rfft_shape: tuple[int, ...], ) -> tuple[ NDArray[np.complex128], # A_reduced (n_modes, n_dyn, n_dyn) NDArray[np.complex128], # recovery (n_modes, n_constraints, n_dyn) NDArray[np.complex128], # v_recovery (n_modes, n_constraints, n_dyn) list[str], # constraint_field_names dict[int, int], # orig_to_reduced slot mapping ]: """Build reduced per-mode matrices with constraints algebraically eliminated. For a mixed system with dynamical (d) and constraint (c) fields: d/dt[d] = A_dd·d + A_dc·c (dynamical equations) 0 = S_cd·d + S_cc·c (constraint: solve for c) The constraint gives c = -S_cc⁻¹·S_cd·d. Substituting: d/dt[d] = (A_dd - A_dc·S_cc⁻¹·S_cd)·d This handles v_A₀ references in dynamical equations by recognizing that v_A₀ = dA₀/dt = d/dt[-S_cc⁻¹·S_cd·d] = -S_cc⁻¹·S_cd·d', creating an implicit equation (I - A_dc_v·S_cc⁻¹·S_cd)·d' = (A_dd + A_dc_f·f)·d which is resolved by matrix inversion of the LHS factor. All operations are purely numeric (CoefficientEvaluator returns floats). S_cc⁻¹ in Fourier space is diagonal per mode — just 1/(m²+k²). 
Returns ------- A_reduced : ndarray Per-mode matrices (n_modes, n_dyn, n_dyn) for dynamical fields only. recovery : ndarray Per-mode recovery (n_modes, n_constraints, n_dyn) for reconstructing constraint fields from dynamical state. constraint_field_names : list[str] Names of eliminated constraint fields. orig_to_reduced : dict Mapping from original layout slot index to reduced slot index. """ from tidal.solver.coefficients import CoefficientEvaluator # noqa: PLC0415 assert isinstance(coeff_eval, CoefficientEvaluator) n_modes = int(np.prod(rfft_shape)) # Identify constraint and dynamical fields constraint_field_names: list[str] = [] constraint_eq_map: dict[str, int] = {} # field_name → eq_idx for eq_idx, eq in enumerate(spec.equations): if eq.time_derivative_order == 0: constraint_field_names.append(eq.field_name) constraint_eq_map[eq.field_name] = eq_idx n_c = len(constraint_field_names) # Build dynamical-only slot mapping (excluding constraint field slots) orig_to_reduced: dict[int, int] = {} red_idx = 0 for si, slot in enumerate(layout.slots): if slot.kind == "constraint": continue orig_to_reduced[si] = red_idx red_idx += 1 n_dyn = red_idx # Map field names to slot indices in the REDUCED layout dyn_slot_map: dict[str, int] = {} for si, slot in enumerate(layout.slots): if si in orig_to_reduced: dyn_slot_map[slot.name] = orig_to_reduced[si] # Also map velocity names v_X for dynamical fields for fname, si in layout.velocity_slot_map.items(): v_name = f"v_{fname}" if si in orig_to_reduced: dyn_slot_map[v_name] = orig_to_reduced[si] # Constraint slot map (constraint field names → constraint index 0..n_c-1) c_idx_map: dict[str, int] = { name: i for i, name in enumerate(constraint_field_names) } # Evaluate Fourier multipliers multiplier_cache: dict[str, NDArray[np.complex128]] = {} for eq in spec.equations: for term in eq.rhs_terms: op = term.operator if op not in multiplier_cache: mult_fn = _EXACT_MULTIPLIERS[op] mult_val = mult_fn(k_grid) mult_full = 
np.broadcast_to(mult_val, rfft_shape) multiplier_cache[op] = mult_full.ravel().astype(np.complex128) # --- Build the four coupling matrices per mode --- # A_dd: dynamical → dynamical (n_modes, n_dyn, n_dyn) A_dd = np.zeros((n_modes, n_dyn, n_dyn), dtype=np.complex128) # A_dc: constraint → dynamical via FIELD references (n_modes, n_dyn, n_c) A_dc_field = np.zeros((n_modes, n_dyn, n_c), dtype=np.complex128) # A_dc_vel: constraint VELOCITY → dynamical (n_modes, n_dyn, n_c) A_dc_vel = np.zeros((n_modes, n_dyn, n_c), dtype=np.complex128) # S_cd: dynamical → constraint source (n_modes, n_c, n_dyn) S_cd = np.zeros((n_modes, n_c, n_dyn), dtype=np.complex128) # S_cc: constraint self-coupling (n_modes, n_c, n_c) S_cc = np.zeros((n_modes, n_c, n_c), dtype=np.complex128) for eq_idx, eq in enumerate(spec.equations): is_constraint = eq.time_derivative_order == 0 is_second_order = eq.time_derivative_order >= 2 if is_constraint: # Constraint equation: 0 = Σ coeff·op(target) ci = c_idx_map[eq.field_name] for term_idx, term in enumerate(eq.rhs_terms): coeff = _resolve_constant_coeff( term, coeff_eval, eq_idx=eq_idx, term_idx=term_idx, ) mult = multiplier_cache[term.operator] if term.field in c_idx_map: # Self/cross constraint coupling cj = c_idx_map[term.field] S_cc[:, ci, cj] += coeff * mult elif term.field in dyn_slot_map: # Source coupling to dynamical state dj = dyn_slot_map[term.field] S_cd[:, ci, dj] += coeff * mult elif is_second_order: field_slot = orig_to_reduced[layout.field_slot_map[eq.field_name]] vel_slot = orig_to_reduced[layout.velocity_slot_map[eq.field_name]] # Kinematic: dq/dt = v A_dd[:, field_slot, vel_slot] = 1.0 # RHS terms: dv/dt = Σ coeff·op(target) for term_idx, term in enumerate(eq.rhs_terms): coeff = _resolve_constant_coeff( term, coeff_eval, eq_idx=eq_idx, term_idx=term_idx, ) mult = multiplier_cache[term.operator] if term.field in c_idx_map: # References constraint field directly cj = c_idx_map[term.field] A_dc_field[:, vel_slot, cj] += coeff * mult elif 
term.field.startswith("v_") and term.field[2:] in c_idx_map: # References constraint velocity v_A₀ cj = c_idx_map[term.field[2:]] A_dc_vel[:, vel_slot, cj] += coeff * mult elif term.field in dyn_slot_map: # Normal dynamical reference dj = dyn_slot_map[term.field] A_dd[:, vel_slot, dj] += coeff * mult else: # First-order: du/dt = Σ coeff·op(target) this_slot = orig_to_reduced[layout.field_slot_map[eq.field_name]] for term_idx, term in enumerate(eq.rhs_terms): coeff = _resolve_constant_coeff( term, coeff_eval, eq_idx=eq_idx, term_idx=term_idx, ) mult = multiplier_cache[term.operator] if term.field in c_idx_map: cj = c_idx_map[term.field] A_dc_field[:, this_slot, cj] += coeff * mult elif term.field in dyn_slot_map: dj = dyn_slot_map[term.field] A_dd[:, this_slot, dj] += coeff * mult # --- Compute Schur complement --- # Batch-invert S_cc across all modes (small matrices, typically 1x1 or 2x2) # Detect and regularize singular modes (e.g. k=0 gauge freedom) dets = np.linalg.det(S_cc) if n_c > 0 else np.ones(n_modes) singular_mask = np.abs(dets) < 1e-14 S_cc_reg = S_cc.copy() if np.any(singular_mask): S_cc_reg[singular_mask] += 1e-14 * np.eye(n_c, dtype=np.complex128) S_cc_inv = np.linalg.inv(S_cc_reg) # (n_modes, n_c, n_c) # Recovery: c = -S_cc⁻¹ · S_cd · d # recovery[m, ci, dj] = -Σ_cj S_cc_inv[m,ci,cj] · S_cd[m,cj,dj] recovery = -np.einsum("mij,mjk->mik", S_cc_inv, S_cd) # Substitution: A_dc_field · c = A_dc_field · recovery · d # field_correction[m] = A_dc_field[m] @ recovery[m] field_correction = np.einsum("mij,mjk->mik", A_dc_field, recovery) # For constraint velocity: v_c = d/dt[c] = recovery · d' # where d' = A_reduced · d. So A_dc_vel · v_c = A_dc_vel · recovery · d'. 
# This creates implicit coupling: # d' = A_dd · d + field_correction · d + A_dc_vel · recovery · d' # (I - A_dc_vel · recovery) · d' = (A_dd + field_correction) · d # d' = (I - A_dc_vel · recovery)⁻¹ · (A_dd + field_correction) · d vel_coupling = np.einsum("mij,mjk->mik", A_dc_vel, recovery) # Check if vel_coupling is nonzero (constraint velocity referenced) has_vel_coupling = np.max(np.abs(vel_coupling)) > 1e-15 A_rhs = A_dd + field_correction if has_vel_coupling: # Implicit solve: (I - vel_coupling) · d' = A_rhs · d eye = np.broadcast_to( np.eye(n_dyn, dtype=np.complex128), (n_modes, n_dyn, n_dyn), ).copy() lhs = eye - vel_coupling # Batch solve: A_reduced = lhs⁻¹ · A_rhs (all modes at once) A_reduced: NDArray[np.complex128] = np.asarray( np.linalg.solve(lhs, A_rhs), dtype=np.complex128, ) else: A_reduced = A_rhs # Velocity recovery: v_c_hat = v_recovery @ d_hat gives exact ∂_t(c) # Derived from: c = recovery · d, so ∂_t c = recovery · ∂_t d = recovery · A_reduced · d # This is machine-precision — no numerical differentiation needed. # Stored per-mode: shape (n_modes, n_c, n_dyn). Typically < 50 MB for 128³. v_recovery = np.einsum("mci,mij->mcj", recovery, A_reduced) return A_reduced, recovery, v_recovery, constraint_field_names, orig_to_reduced # --------------------------------------------------------------------------- # Generalized mass-matrix evolution (M·ẍ = K·x + D·ẋ + J·x⃛) # --------------------------------------------------------------------------- # For systems with implicit acceleration coupling (d2_t, mixed_T2_S*) and # jerk coupling (d3_t, mixed_T3_S*). The mass matrix M may be singular, # creating hidden constraints analogous to time_order=0 fields. # # Algorithm: # 1. Build M, D, K, J matrices from operator decomposition # 2. Eigendecompose M per mode — zero eigenvalues → constraints # 3. Schur-eliminate mass-matrix constraints (same as constraint fields) # 4. Substitute jerk terms using equations of motion # 5. 
Build first-order evolution matrix A = [[0,I],[M⁻¹K, M⁻¹D]] # 6. Combine with existing constraint field Schur elimination # # References: # Golub & Van Loan (2013), Matrix Computations §7.7 (generalized eigenvalue) # Hairer & Lubich (2003), ZAMM 83(1) (mass matrices in dynamics) # Ostrogradsky (1850), Mem. Acad. St. Petersbourg VI 4, 385 def _has_time_derivative_operators(spec: EquationSystem) -> bool: """Check whether any equation has time-derivative operators on its RHS.""" for eq in spec.equations: for term in eq.rhs_terms: decomp = _OPERATOR_DECOMP.get(term.operator) if decomp is not None and decomp.time_order > 0: return True return False def _build_generalized_evolution_matrices( spec: EquationSystem, layout: StateLayout, grid: GridInfo, coeff_eval: object, # CoefficientEvaluator k_grid: list[NDArray[np.float64]], rfft_shape: tuple[int, ...], ) -> tuple[ NDArray[np.complex128], # A_rhs (n_modes, n_dyn_slots, n_dyn_slots) NDArray[np.complex128] | None, # B_lhs (n_modes, n_dyn, n_dyn) or None NDArray[np.complex128], # recovery (n_modes, n_total_constraints, n_dyn_slots) NDArray[np.complex128] | None, # v_recovery (n_modes, n_c, n_dyn) or None list[str], # all constraint field names dict[int, int], # orig_to_reduced slot mapping ]: """Build per-mode matrices for systems with mass-matrix coupling. Returns A_rhs and optionally B_lhs for the generalized eigenvalue problem B·d' = A·d, where B = I - vel_coupling may be singular (gauge freedom from circular constraint velocity dependencies). When B_lhs is not None, the caller should use scipy.linalg.eig(A, B) (QZ decomposition) instead of np.linalg.eig(A) for eigendecomposition. Infinite eigenvalues correspond to gauge-constrained directions. Handles the generalized second-order system: M(k)·ẍ = K(k)·x + D(k)·ẋ + J(k)·x⃛ where M may be singular (creating hidden algebraic constraints) and J encodes jerk coupling from d3_t/mixed_T3_S* operators. The algorithm: 1. 
Separates constraint (time_order=0) and dynamical fields 2. Builds M, D, K matrices for dynamical fields from operator decomposition 3. Eigendecomposes M per mode to detect singular directions 4. Treats zero-eigenvalue directions as additional constraints (Schur) 5. Substitutes jerk terms using the (now-invertible) dynamical equations 6. Combines both constraint levels and builds the first-order evolution matrix Returns the same tuple as ``_build_constraint_eliminated_matrices``. """ import logging # noqa: PLC0415 from tidal.solver.coefficients import CoefficientEvaluator # noqa: PLC0415 assert isinstance(coeff_eval, CoefficientEvaluator) logger = logging.getLogger(__name__) n_modes = int(np.prod(rfft_shape)) # ---- Identify constraint and dynamical fields ---- constraint_field_names: list[str] = [ eq.field_name for eq in spec.equations if eq.time_derivative_order == 0 ] c_idx_map: dict[str, int] = { name: i for i, name in enumerate(constraint_field_names) } n_c = len(constraint_field_names) # Build dynamical-only slot mapping (excluding constraint field slots) orig_to_reduced: dict[int, int] = {} red_idx = 0 for si, slot in enumerate(layout.slots): if slot.kind == "constraint": continue orig_to_reduced[si] = red_idx red_idx += 1 n_dyn_slots = red_idx # Map field/velocity names → reduced slot indices dyn_slot_map: dict[str, int] = {} for si, slot in enumerate(layout.slots): if si in orig_to_reduced: dyn_slot_map[slot.name] = orig_to_reduced[si] for fname, si in layout.velocity_slot_map.items(): v_name = f"v_{fname}" if si in orig_to_reduced: dyn_slot_map[v_name] = orig_to_reduced[si] # ---- Evaluate spatial Fourier multipliers ---- multiplier_cache: dict[str, NDArray[np.complex128]] = {} for eq in spec.equations: for term in eq.rhs_terms: op = term.operator if op not in multiplier_cache: decomp = _OPERATOR_DECOMP[op] mult_val = decomp.spatial_fn(k_grid) mult_full = np.broadcast_to(mult_val, rfft_shape) multiplier_cache[op] = mult_full.ravel().astype(np.complex128) 
# ---- Identify dynamical fields and their indices ---- # Map: dynamical field name → index in the n_f dynamical field array dyn_field_names: list[str] = [] dyn_field_idx: dict[str, int] = {} for eq in spec.equations: if eq.time_derivative_order > 0: dyn_field_idx[eq.field_name] = len(dyn_field_names) dyn_field_names.append(eq.field_name) n_f = len(dyn_field_names) # number of dynamical FIELDS (not slots) # ---- Build M, D, K matrices for dynamical fields (n_f x n_f) ---- # These are the FIELD-level matrices, not slot-level. # M·ẍ = K·x + D·ẋ where x is the vector of field values. M_mat = np.zeros((n_modes, n_f, n_f), dtype=np.complex128) D_mat = np.zeros((n_modes, n_f, n_f), dtype=np.complex128) K_mat = np.zeros((n_modes, n_f, n_f), dtype=np.complex128) J_mat = np.zeros((n_modes, n_f, n_f), dtype=np.complex128) # Diagonal of M: each 2nd-order field has ẍ_i on the LHS for fi in range(n_f): M_mat[:, fi, fi] = 1.0 # Constraint matrices — built in two phases: # Phase 1: collect terms from constraint equations # Phase 2: substitute acceleration/velocity terms after M inversion S_cd = np.zeros((n_modes, n_c, n_dyn_slots), dtype=np.complex128) S_cc = np.zeros((n_modes, n_c, n_c), dtype=np.complex128) # A_dc: constraint → dynamical (field + velocity references) A_dc_field = np.zeros((n_modes, n_dyn_slots, n_c), dtype=np.complex128) A_dc_vel = np.zeros((n_modes, n_dyn_slots, n_c), dtype=np.complex128) # Deferred constraint terms with time_order > 0 on dynamical fields. # These need acceleration/velocity substitution after M inversion. 
# Each entry: (ci, coeff, spatial_mult, time_order, target_field_idx) deferred_constraint_terms: list[ tuple[int, complex, NDArray[np.complex128], int, int] ] = [] # ---- Populate matrices from equations ---- for eq_idx, eq in enumerate(spec.equations): is_constraint = eq.time_derivative_order == 0 if is_constraint: ci = c_idx_map[eq.field_name] for term_idx, term in enumerate(eq.rhs_terms): coeff = _resolve_constant_coeff( term, coeff_eval, eq_idx=eq_idx, term_idx=term_idx ) mult = multiplier_cache[term.operator] decomp = _OPERATOR_DECOMP[term.operator] t_order = decomp.time_order if term.field in c_idx_map: cj = c_idx_map[term.field] S_cc[:, ci, cj] += coeff * mult elif t_order == 0: # Pure spatial operator on dynamical field/velocity if term.field in dyn_slot_map: dj = dyn_slot_map[term.field] S_cd[:, ci, dj] += coeff * mult elif t_order == 1 and term.field in dyn_field_idx: # Velocity operator on dynamical field (e.g. mixed_T1_S1x) # This references ẋ_field → use velocity slot fj = dyn_field_idx[term.field] vel_j = orig_to_reduced[layout.velocity_slot_map[term.field]] S_cd[:, ci, vel_j] += coeff * mult elif t_order >= 2 and term.field in dyn_field_idx: # Acceleration/jerk on dynamical field — defer until M inverted fj = dyn_field_idx[term.field] deferred_constraint_terms.append( (ci, complex(coeff), mult, t_order, fj) ) elif term.field in dyn_slot_map: # Fallback: direct slot reference dj = dyn_slot_map[term.field] S_cd[:, ci, dj] += coeff * mult continue # Dynamical equation fi = dyn_field_idx[eq.field_name] field_slot = orig_to_reduced[layout.field_slot_map[eq.field_name]] vel_slot = orig_to_reduced[layout.velocity_slot_map[eq.field_name]] for term_idx, term in enumerate(eq.rhs_terms): coeff = _resolve_constant_coeff( term, coeff_eval, eq_idx=eq_idx, term_idx=term_idx ) mult = multiplier_cache[term.operator] decomp = _OPERATOR_DECOMP[term.operator] t_order = decomp.time_order # Determine which field this term targets target_field = term.field # Strip v_ 
# NOTE(review): the formatting below was reconstructed from a whitespace-mangled
# source; all code tokens are unchanged. This opening fragment continues the
# preceding term-dispatch loop of the generalized evolution builder (its `def`
# line lies before this chunk).
            # Strip the "v_" prefix to get the base field name.
            is_vel_ref = target_field.startswith("v_")
            base_field = target_field[2:] if is_vel_ref else target_field
            if target_field in c_idx_map:
                # Direct reference to constraint field
                cj = c_idx_map[target_field]
                A_dc_field[:, vel_slot, cj] += coeff * mult
            elif is_vel_ref and base_field in c_idx_map:
                # Velocity of constraint field
                cj = c_idx_map[base_field]
                A_dc_vel[:, vel_slot, cj] += coeff * mult
            elif base_field in dyn_field_idx:
                fj = dyn_field_idx[base_field]
                if t_order == 0:
                    if is_vel_ref:
                        # Velocity reference with spatial operator
                        # → damping matrix D[fi, fj]
                        D_mat[:, fi, fj] += coeff * mult
                    else:
                        # Position reference with spatial operator
                        # → stiffness matrix K[fi, fj]
                        K_mat[:, fi, fj] += coeff * mult
                elif t_order == 1:
                    if is_vel_ref:
                        # first_derivative_t(v_X) = ẍ_X → acceleration
                        # This should be rare; treat as M coupling
                        M_mat[:, fi, fj] -= coeff * mult
                    else:
                        # first_derivative_t(X) = ẋ_X → velocity
                        D_mat[:, fi, fj] += coeff * mult
                elif t_order == 2:
                    # d2_t or mixed_T2: acceleration coupling → mass matrix
                    # RHS has coeff·ẍ_j, move to LHS: M[fi,fj] -= coeff·mult
                    M_mat[:, fi, fj] -= coeff * mult
                elif t_order == 3:
                    # d3_t or mixed_T3: jerk coupling → substitute later
                    J_mat[:, fi, fj] += coeff * mult
            elif target_field in dyn_slot_map:
                # Direct slot reference (velocity name like v_h_3)
                # NOTE(review): `dj` is assigned but never read in this chunk —
                # presumably a placeholder for the deferred handling described
                # below; confirm before removing.
                dj = dyn_slot_map[target_field]
                # Just put it in the A matrix directly later
                # For now, track separately if needed

    # ---- Mass-matrix constraint elimination ----
    # Eigendecompose M per mode to find singular directions.
    # Zero eigenvalues → hidden constraints; nonzero → dynamical.
    #
    # For each mode k:
    #     M(k) = Q(k) · Λ(k) · Q(k)ᵀ
    # Rotate: K̃ = QᵀKQ, D̃ = QᵀDQ
    # Singular rows (λ=0) → constraint: 0 = K̃_c·z + D̃_c·ż
    # Dynamical rows (λ≠0) → ODE: Λ_d·z̈ = K̃_d·z + D̃_d·ż

    # Check if any mode has singular M
    dets = np.linalg.det(M_mat)
    has_singular_M = np.any(np.abs(dets) < 1e-12)
    m_k_independent = False
    con_mask = np.zeros(n_f, dtype=bool)  # will be updated if singular
    if has_singular_M:
        logger.info(
            "Generalized mass matrix: singular M detected — "
            "applying mass-matrix Schur elimination"
        )

    # Build the first-order evolution matrix A in the FULL dynamical slot space.
    # A has shape (n_modes, n_dyn_slots, n_dyn_slots).
    # For 2nd-order fields: rows for field_slot get dq/dt = v (kinematic),
    # rows for vel_slot get dv/dt = M⁻¹(K·x + D·v).
    A_dd = np.zeros((n_modes, n_dyn_slots, n_dyn_slots), dtype=np.complex128)

    # Kinematic equations: dq/dt = v
    for fname in dyn_field_names:
        field_slot = orig_to_reduced[layout.field_slot_map[fname]]
        vel_slot = orig_to_reduced[layout.velocity_slot_map[fname]]
        A_dd[:, field_slot, vel_slot] = 1.0

    if has_singular_M:
        # Use eigendecomposition to handle singular M
        # We work with each mode separately for modes where M is singular,
        # and batch-process modes where M is invertible.
        # For simplicity and correctness, process per-mode where needed.
        # M is typically k-independent for d2_t coupling (spatial_mult=1),
        # so use the k=0 mode's eigenstructure as representative.
        # For k-dependent M (from mixed_T2_S*), process per mode.

        # Check if M is k-independent
        M_spread = np.max(np.abs(M_mat - M_mat[0:1, :, :]))
        m_k_independent = M_spread < 1e-14
        if m_k_independent:
            # M is the same for all modes — single eigendecomposition
            M0 = M_mat[0]
            eigvals, Q = np.linalg.eigh(M0.real)  # M is real symmetric
            # Threshold for zero eigenvalue
            tol = 1e-10 * max(1.0, np.max(np.abs(eigvals)))
            dyn_mask = np.abs(eigvals) > tol
            con_mask = ~dyn_mask
            n_mass_con = int(np.sum(con_mask))
            n_mass_dyn = int(np.sum(dyn_mask))
            logger.info(
                "Mass matrix eigenvalues: %s (dynamical: %d, constrained: %d)",
                eigvals,
                n_mass_dyn,
                n_mass_con,
            )
            if n_mass_con > 0:
                # Rotate K, D, J into eigenspace
                # NOTE(review): the two bare expressions below compute and
                # discard values (Q_c and Λ_d); presumably vestigial — confirm
                # before removing.
                Q[:, con_mask]  # (n_f, n_mass_con)
                Q_d = Q[:, dyn_mask]  # (n_f, n_mass_dyn)
                np.diag(eigvals[dyn_mask])  # (n_mass_dyn, n_mass_dyn)
                Lambda_d_inv = np.diag(
                    1.0 / eigvals[dyn_mask]
                )  # (n_mass_dyn, n_mass_dyn)

                # Rotate per-mode matrices
                # K̃ = QᵀKQ, D̃ = QᵀDQ, J̃ = QᵀJQ
                K_rot = np.einsum("ij,mjk,kl->mil", Q.T, K_mat, Q)
                D_rot = np.einsum("ij,mjk,kl->mil", Q.T, D_mat, Q)
                J_rot = np.einsum("ij,mjk,kl->mil", Q.T, J_mat, Q)

                # Partition into dynamical (d) and constrained (c) blocks
                d_idx = np.where(dyn_mask)[0]
                c_idx = np.where(con_mask)[0]
                K_dd = K_rot[:, np.ix_(d_idx, d_idx)[0], np.ix_(d_idx, d_idx)[1]]
                K_dc = K_rot[:, np.ix_(d_idx, c_idx)[0], np.ix_(d_idx, c_idx)[1]]
                K_cd = K_rot[:, np.ix_(c_idx, d_idx)[0], np.ix_(c_idx, d_idx)[1]]
                K_cc = K_rot[:, np.ix_(c_idx, c_idx)[0], np.ix_(c_idx, c_idx)[1]]
                D_dd = D_rot[:, np.ix_(d_idx, d_idx)[0], np.ix_(d_idx, d_idx)[1]]
                D_dc = D_rot[:, np.ix_(d_idx, c_idx)[0], np.ix_(d_idx, c_idx)[1]]
                # NOTE(review): the next two (and the J analogues below) are
                # computed and discarded (would be D_cd/D_cc, J_cd/J_cc) —
                # presumably vestigial; confirm before removing.
                D_rot[:, np.ix_(c_idx, d_idx)[0], np.ix_(c_idx, d_idx)[1]]
                D_rot[:, np.ix_(c_idx, c_idx)[0], np.ix_(c_idx, c_idx)[1]]
                J_dd = J_rot[:, np.ix_(d_idx, d_idx)[0], np.ix_(d_idx, d_idx)[1]]
                J_dc = J_rot[:, np.ix_(d_idx, c_idx)[0], np.ix_(d_idx, c_idx)[1]]
                J_rot[:, np.ix_(c_idx, d_idx)[0], np.ix_(c_idx, d_idx)[1]]
                J_rot[:, np.ix_(c_idx, c_idx)[0], np.ix_(c_idx, c_idx)[1]]

                # Constraint rows: 0 = K_cd·z_d + K_cc·z_c + D_cd·ż_d + D_cc·ż_c
                # Solve for z_c (position-only constraint, ignoring velocity for now):
                #     If K_cc is invertible: z_c = -K_cc⁻¹·K_cd·z_d
                # If velocity terms are present, handle as implicit coupling.

                # Check if constraint is purely positional (K_cc invertible, D_cc ~ 0)
                K_cc_det = np.linalg.det(K_cc) if n_mass_con > 0 else np.ones(n_modes)
                has_k_con = np.any(np.abs(K_cc_det) > 1e-14)
                if has_k_con:
                    # Standard case: K_cc invertible → z_c = -K_cc⁻¹·K_cd·z_d
                    K_cc_reg = K_cc.copy()
                    singular = np.abs(K_cc_det) < 1e-14
                    if np.any(singular):
                        # Tikhonov-style nudge so np.linalg.inv succeeds on
                        # (near-)singular modes.
                        K_cc_reg[singular] += 1e-14 * np.eye(
                            n_mass_con, dtype=np.complex128
                        )
                    K_cc_inv = np.linalg.inv(K_cc_reg)  # pyright: ignore[reportUnknownVariableType]
                    # Recovery: z_c = -K_cc⁻¹·K_cd·z_d
                    mass_recovery = -np.einsum("mij,mjk->mik", K_cc_inv, K_cd)  # pyright: ignore[reportUnknownArgumentType]
                    # Substitute into dynamical equations:
                    #     Λ_d·z̈_d = K_dd·z_d + K_dc·z_c + D_dd·ż_d + D_dc·ż_c
                    # z_c = mass_recovery·z_d → ż_c = mass_recovery·ż_d
                    K_eff = K_dd + np.einsum("mij,mjk->mik", K_dc, mass_recovery)
                    D_eff = D_dd + np.einsum("mij,mjk->mik", D_dc, mass_recovery)
                    J_eff = J_dd + np.einsum("mij,mjk->mik", J_dc, mass_recovery)
                else:
                    # No positional constraint coupling — mass constraint
                    # modes decouple trivially (zero rows)
                    K_eff = K_dd
                    D_eff = D_dd
                    J_eff = J_dd
                    mass_recovery = np.zeros(
                        (n_modes, n_mass_con, n_mass_dyn), dtype=np.complex128
                    )

                # Now invert Λ_d (diagonal, all nonzero)
                # E = Λ_d⁻¹·K_eff, F = Λ_d⁻¹·D_eff
                E = np.einsum("ij,mjk->mik", Lambda_d_inv, K_eff)
                F = np.einsum("ij,mjk->mik", Lambda_d_inv, D_eff)

                # Jerk substitution:
                # d3_t(z_j) = E_j·ẋ + F_j·(E·x + F·ẋ) = F_j·E·x + (E_j + F_j·F)·ẋ
                J_eff_inv = np.einsum("ij,mjk->mik", Lambda_d_inv, J_eff)
                has_jerk = np.max(np.abs(J_eff_inv)) > 1e-15
                if has_jerk:
                    logger.info("Jerk substitution: applying d3_t elimination")
                    # K_final += J_eff_inv · F · E (position correction from jerk)
                    FE = np.einsum("mij,mjk->mik", F, E)
                    K_jerk = np.einsum("mij,mjk->mik", J_eff_inv, FE)
                    # D_final += J_eff_inv · (E + F²) (velocity correction from jerk)
                    FF = np.einsum("mij,mjk->mik", F, F)
                    D_jerk = np.einsum("mij,mjk->mik", J_eff_inv, E + FF)
                    E_final = E + K_jerk
                    F_final = F + D_jerk
                else:
                    E_final = E
                    F_final = F

                # Build evolution matrix in the ROTATED field basis
                # State vector in rotated basis: (z_d, ż_d)
                #     dz_d/dt = ż_d
                #     dż_d/dt = E_final·z_d + F_final·ż_d
                # Now map back to the ORIGINAL slot-level evolution matrix A_dd.
                # The rotation Q maps field-level indices to slot-level indices.
                # For each dynamical field, there's a field_slot and vel_slot.
                # Build the Q_d mapping: original field index → rotated dynamical index
                #     Q_d[original_i, rotated_j] = transformation coefficient
                # For the velocity-slot rows (dynamics), fill in:
                #     dv_i/dt = Σ_j Q_d[i,a] · E_final[a,b] · Q_d[j,b] · field_j
                #             + Σ_j Q_d[i,a] · F_final[a,b] · Q_d[j,b] · vel_j
                # Effective K and D in original field basis:
                K_orig = np.einsum("ia,mab,jb->mij", Q_d, E_final, Q_d)
                D_orig = np.einsum("ia,mab,jb->mij", Q_d, F_final, Q_d)

                # Fill A_dd velocity rows
                for i, fname_i in enumerate(dyn_field_names):
                    vel_i = orig_to_reduced[layout.velocity_slot_map[fname_i]]
                    for j, fname_j in enumerate(dyn_field_names):
                        field_j = orig_to_reduced[layout.field_slot_map[fname_j]]
                        vel_j = orig_to_reduced[layout.velocity_slot_map[fname_j]]
                        A_dd[:, vel_i, field_j] += K_orig[:, i, j]
                        A_dd[:, vel_i, vel_j] += D_orig[:, i, j]
            else:
                # No singular directions — M is invertible
                m_inv = np.linalg.inv(M_mat)
                eff_k = np.einsum("mij,mjk->mik", m_inv, K_mat)
                eff_d = np.einsum("mij,mjk->mik", m_inv, D_mat)
                # Jerk substitution
                j_inv = np.einsum("mij,mjk->mik", m_inv, J_mat)
                has_jerk = np.max(np.abs(j_inv)) > 1e-15
                if has_jerk:
                    fd_k = np.einsum("mij,mjk->mik", eff_d, eff_k)
                    k_jerk = np.einsum("mij,mjk->mik", j_inv, fd_k)
                    fd_d = np.einsum("mij,mjk->mik", eff_d, eff_d)
                    d_jerk = np.einsum("mij,mjk->mik", j_inv, eff_k + fd_d)
                    eff_k += k_jerk
                    eff_d += d_jerk
                for i, fname_i in enumerate(dyn_field_names):
                    vel_i = orig_to_reduced[layout.velocity_slot_map[fname_i]]
                    for j, fname_j in enumerate(dyn_field_names):
                        field_j = orig_to_reduced[layout.field_slot_map[fname_j]]
                        vel_j = orig_to_reduced[layout.velocity_slot_map[fname_j]]
                        A_dd[:, vel_i, field_j] += eff_k[:, i, j]
                        A_dd[:, vel_i, vel_j] += eff_d[:, i, j]
        else:
            # M is k-dependent — process per mode
            # For now, treat each mode independently
            for m in range(n_modes):
                M_m = M_mat[m]
                eigvals_m, _Q_m = np.linalg.eigh(M_m.real)
                tol = 1e-10 * max(1.0, np.max(np.abs(eigvals_m)))
                dyn_m = np.abs(eigvals_m) > tol
                if np.all(dyn_m):
                    # Invertible for this mode
                    m_inv_m = np.linalg.inv(M_m)  # pyright: ignore[reportUnknownVariableType]
                    ek_m = m_inv_m @ K_mat[m]  # pyright: ignore[reportUnknownVariableType]
                    ed_m = m_inv_m @ D_mat[m]  # pyright: ignore[reportUnknownVariableType]
                    j_inv_m = m_inv_m @ J_mat[m]  # pyright: ignore[reportUnknownVariableType]
                    if np.max(np.abs(j_inv_m)) > 1e-15:  # pyright: ignore[reportUnknownArgumentType]
                        # NOTE(review): ek_m is updated BEFORE being used in
                        # the ed_m jerk term, unlike the batched path above
                        # which uses the pre-update eff_k — confirm intended.
                        fd_k_m = ed_m @ ek_m  # pyright: ignore[reportUnknownVariableType]
                        ek_m += j_inv_m @ fd_k_m  # pyright: ignore[reportUnknownVariableType]
                        ed_m += j_inv_m @ (ek_m + ed_m @ ed_m)  # pyright: ignore[reportUnknownVariableType]
                    for i, fname_i in enumerate(dyn_field_names):
                        vi = orig_to_reduced[layout.velocity_slot_map[fname_i]]
                        for j, fname_j in enumerate(dyn_field_names):
                            fj = orig_to_reduced[layout.field_slot_map[fname_j]]
                            vj = orig_to_reduced[layout.velocity_slot_map[fname_j]]
                            A_dd[m, vi, fj] += ek_m[i, j]
                            A_dd[m, vi, vj] += ed_m[i, j]
                else:
                    # Singular mode — would need per-mode Schur elimination
                    # This is rare for k-dependent M; log and use pseudoinverse
                    m_pinv = np.linalg.pinv(M_m)  # pyright: ignore[reportUnknownVariableType]
                    ek_m2 = m_pinv @ K_mat[m]  # pyright: ignore[reportUnknownVariableType]
                    ed_m2 = m_pinv @ D_mat[m]  # pyright: ignore[reportUnknownVariableType]
                    for i, fname_i in enumerate(dyn_field_names):
                        vi = orig_to_reduced[layout.velocity_slot_map[fname_i]]
                        for j, fname_j in enumerate(dyn_field_names):
                            fj = orig_to_reduced[layout.field_slot_map[fname_j]]
                            vj = orig_to_reduced[layout.velocity_slot_map[fname_j]]
                            A_dd[m, vi, fj] += ek_m2[i, j]
                            A_dd[m, vi, vj] += ed_m2[i, j]
    else:
        # M is invertible for all modes — standard path
        m_inv = np.linalg.inv(M_mat)
        eff_k = np.einsum("mij,mjk->mik", m_inv, K_mat)
        eff_d = np.einsum("mij,mjk->mik", m_inv, D_mat)
        # Jerk substitution
        j_inv = np.einsum("mij,mjk->mik", m_inv, J_mat)
        has_jerk = np.max(np.abs(j_inv)) > 1e-15
        if has_jerk:
            logger.info("Jerk substitution: applying d3_t elimination")
            fd_k = np.einsum("mij,mjk->mik", eff_d, eff_k)
            k_jerk = np.einsum("mij,mjk->mik", j_inv, fd_k)
            fd_d = np.einsum("mij,mjk->mik", eff_d, eff_d)
            d_jerk = np.einsum("mij,mjk->mik", j_inv, eff_k + fd_d)
            eff_k += k_jerk
            eff_d += d_jerk
        for i, fname_i in enumerate(dyn_field_names):
            vel_i = orig_to_reduced[layout.velocity_slot_map[fname_i]]
            for j, fname_j in enumerate(dyn_field_names):
                field_j = orig_to_reduced[layout.field_slot_map[fname_j]]
                vel_j = orig_to_reduced[layout.velocity_slot_map[fname_j]]
                A_dd[:, vel_i, field_j] += eff_k[:, i, j]
                A_dd[:, vel_i, vel_j] += eff_d[:, i, j]

    # ---- Substitute deferred constraint acceleration/velocity terms ----
    # Constraints may contain time_order>=2 operators on dynamical fields
    # (e.g., mixed_T2_S1x(t_3) = ik_x x ẍ_{t_3}). After mass-matrix
    # inversion, ẍ_j = Σ_k E[j,k]·field_k + F[j,k]·vel_k. Substitute
    # this into the constraint's S_cd matrix.
    if deferred_constraint_terms:
        # Extract effective acceleration matrices from A_dd.
        # A_dd[m, vel_i, field_j] = K_eff[i,j] (position → acceleration)
        # A_dd[m, vel_i, vel_j] = D_eff[i,j] (velocity → acceleration)
        K_eff = np.zeros((n_modes, n_f, n_f), dtype=np.complex128)
        D_eff = np.zeros((n_modes, n_f, n_f), dtype=np.complex128)
        for i, fname_i in enumerate(dyn_field_names):
            vel_i = orig_to_reduced[layout.velocity_slot_map[fname_i]]
            for j, fname_j in enumerate(dyn_field_names):
                field_j = orig_to_reduced[layout.field_slot_map[fname_j]]
                vel_j = orig_to_reduced[layout.velocity_slot_map[fname_j]]
                K_eff[:, i, j] = A_dd[:, vel_i, field_j]
                D_eff[:, i, j] = A_dd[:, vel_i, vel_j]
        for ci, coeff_val, spatial_mult, t_order, fj in deferred_constraint_terms:
            if t_order == 2:
                # ẍ_fj = Σ_k K_eff[fj,k]·field_k + D_eff[fj,k]·vel_k
                for k, fname_k in enumerate(dyn_field_names):
                    fk_slot = orig_to_reduced[layout.field_slot_map[fname_k]]
                    vk_slot = orig_to_reduced[layout.velocity_slot_map[fname_k]]
                    # Position contribution: coeff x spatial x K_eff[fj, k]
                    S_cd[:, ci, fk_slot] += coeff_val * spatial_mult * K_eff[:, fj, k]
                    # Velocity contribution: coeff x spatial x D_eff[fj, k]
                    S_cd[:, ci, vk_slot] += coeff_val * spatial_mult * D_eff[:, fj, k]
            # time_order=3 in constraints is very rare; log and skip
            elif t_order >= 3:
                logger.warning(
                    "Constraint has time_order=%d operator — not yet handled",
                    t_order,
                )

    # ---- Constraint field Schur elimination ----
    if n_c > 0:
        # Batch-invert S_cc
        cc_dets = np.linalg.det(S_cc) if n_c > 0 else np.ones(n_modes)
        singular_mask = np.abs(cc_dets) < 1e-14
        S_cc_reg = S_cc.copy()
        if np.any(singular_mask):
            S_cc_reg[singular_mask] += 1e-14 * np.eye(n_c, dtype=np.complex128)
        S_cc_inv = np.linalg.inv(S_cc_reg)
        # Recovery: c = -S_cc⁻¹ · S_cd · d
        recovery = -np.einsum("mij,mjk->mik", S_cc_inv, S_cd)
        # Field correction: A_dc_field · recovery
        field_correction = np.einsum("mij,mjk->mik", A_dc_field, recovery)
        # Velocity coupling: A_dc_vel · recovery
        vel_coupling = np.einsum("mij,mjk->mik", A_dc_vel, recovery)
        has_vel = np.max(np.abs(vel_coupling)) > 1e-15
        A_rhs = A_dd + field_correction
        if has_vel:
            eye = np.broadcast_to(
                np.eye(n_dyn_slots, dtype=np.complex128),
                (n_modes, n_dyn_slots, n_dyn_slots),
            ).copy()
            B_lhs: NDArray[np.complex128] | None = eye - vel_coupling
        else:
            B_lhs = None
    else:
        recovery = np.zeros((n_modes, 0, n_dyn_slots), dtype=np.complex128)
        A_rhs = A_dd
        B_lhs = None

    n_mass_con_total = int(np.sum(con_mask))
    logger.info(
        "Generalized evolution: %d constraint fields, %d mass-matrix constraints, "
        "%d dynamical slots, jerk=%s, vel_coupling=%s",
        n_c,
        n_mass_con_total,
        n_dyn_slots,
        "yes" if np.max(np.abs(J_mat)) > 1e-15 else "no",
        "generalized_eig" if B_lhs is not None else "none",
    )

    # Velocity recovery for generalized eigenvalue (B·d' = A·d):
    #     d' = B⁻¹·A·d, so v_c = recovery · d' = recovery · B⁻¹·A · d
    # For singular B modes, use least-squares to get the best A_eff.
    if B_lhs is not None and recovery.size > 0:
        # NOTE(review): rebinds n_dyn_slots from A_rhs; value should already
        # match the earlier n_dyn_slots — confirm this rebind is intentional.
        n_dyn_slots = A_rhs.shape[1]
        A_eff = np.zeros_like(A_rhs)
        for m in range(n_modes):
            try:
                A_eff[m] = np.linalg.solve(B_lhs[m], A_rhs[m])
            except np.linalg.LinAlgError:
                # Singular B at this mode — use lstsq
                A_eff[m] = np.asarray(
                    np.linalg.lstsq(B_lhs[m], A_rhs[m], rcond=None)[0],  # type: ignore[reportUnknownMemberType]
                    dtype=np.complex128,
                )
        v_recovery = np.einsum("mci,mij->mcj", recovery, A_eff)
    elif recovery.size > 0:
        v_recovery = np.einsum("mci,mij->mcj", recovery, A_rhs)
    else:
        v_recovery = None

    return A_rhs, B_lhs, recovery, v_recovery, constraint_field_names, orig_to_reduced


# ---------------------------------------------------------------------------
# Evolution matrix construction
# ---------------------------------------------------------------------------


def _build_per_mode_matrices(
    spec: EquationSystem,
    layout: StateLayout,
    grid: GridInfo,
    coeff_eval: CoefficientEvaluator,
    k_grid: list[NDArray[np.float64]],
    rfft_shape: tuple[int, ...],
) -> NDArray[np.complex128]:
    """Build evolution matrices for the all-constant-coefficient case.

    Returns array of shape (n_modes, n_state_slots, n_state_slots) where
    each [m, :, :] is the evolution matrix for mode m. The matrix has block
    structure:

        For second-order fields: [0, 1; L(k), 0] (velocity coupling)
        For first-order fields: [L(k)] (direct evolution)
    """
    n_slots = layout.num_slots
    n_modes = int(np.prod(rfft_shape))

    # Evaluate Fourier multipliers on the k-grid
    multiplier_cache: dict[str, NDArray[np.complex128]] = {}
    for eq in spec.equations:
        for term in eq.rhs_terms:
            op = term.operator
            if op not in multiplier_cache:
                mult_fn = _EXACT_MULTIPLIERS[op]
                mult_val = mult_fn(k_grid)
                # Broadcast to full rfft shape and flatten
                mult_full = np.broadcast_to(mult_val, rfft_shape)
                multiplier_cache[op] = mult_full.ravel().astype(np.complex128)

    # Build matrices: A[m, i, j] for each mode m
    A = np.zeros((n_modes, n_slots, n_slots), dtype=np.complex128)
    for _eq_idx, eq in enumerate(spec.equations):
        field_name = eq.field_name
        is_second_order = eq.time_derivative_order >= 2
        if is_second_order:
            # Field slot and velocity slot
            field_slot = layout.field_slot_map[field_name]
            vel_slot = layout.velocity_slot_map[field_name]
            # dq/dt = v → A[field_slot, vel_slot] = 1
            A[:, field_slot, vel_slot] = 1.0
            # dv/dt = Σ coeff * operator(target_field)
            for _term_idx, term in enumerate(eq.rhs_terms):
                target_slot = layout.field_slot_map[term.field]
                coeff = _resolve_constant_coeff(
                    term,
                    coeff_eval,
                    eq_idx=_eq_idx,
                    term_idx=_term_idx,
                )
                mult = multiplier_cache[term.operator]
                A[:, vel_slot, target_slot] += coeff * mult
        else:
            # First-order: du/dt = Σ coeff * operator(target_field)
            this_slot = layout.field_slot_map[field_name]
            for _term_idx, term in enumerate(eq.rhs_terms):
                target_slot = layout.field_slot_map[term.field]
                coeff = _resolve_constant_coeff(
                    term,
                    coeff_eval,
                    eq_idx=_eq_idx,
                    term_idx=_term_idx,
                )
                mult = multiplier_cache[term.operator]
                A[:, this_slot, target_slot] += coeff * mult
    return A


def _resolve_constant_coeff(
    term: OperatorTerm,
    coeff_eval: CoefficientEvaluator,
    *,
    eq_idx: int = -1,
    term_idx: int = -1,
) -> complex:
    """Resolve a constant (non-position-dependent) coefficient to a scalar.

    Uses CoefficientEvaluator.resolve() which returns a float for constant
    coefficients or an ndarray for position-dependent ones (the latter
    should not reach this function).
    """
    resolved = coeff_eval.resolve(term, 0.0, eq_idx=eq_idx, term_idx=term_idx)
    if isinstance(resolved, np.ndarray):
        # Position-dependent — should not happen for constant-coeff path
        return complex(resolved.ravel()[0])
    return complex(resolved)


def _build_convolution_matrix(
    spec: EquationSystem,
    layout: StateLayout,
    grid: GridInfo,
    coeff_eval: CoefficientEvaluator,
    k_grid: list[NDArray[np.float64]],
    rfft_shape: tuple[int, ...],
) -> NDArray[np.complex128]:
    """Build full evolution matrix for position-dependent coefficient case.

    Position-dependent coefficients c(x) create convolution coupling in
    k-space: FFT[c(x)·u(x)] = ĉ * û (convolution). This couples different
    k-modes, producing a full (n_total x n_total) matrix where
    n_total = n_slots x n_modes.

    For localized c(x) (e.g. Gaussian B₀), the convolution kernel ĉ(q)
    decays exponentially, making the matrix effectively banded. The
    downstream ``_evolve_full_matrix`` exploits this by thresholding small
    entries and converting to sparse CSC format for faster expm_multiply.

    Reference: Burns et al. (2020), Phys. Rev. Research 2:023068.
    """
    n_slots = layout.num_slots
    n_modes = int(np.prod(rfft_shape))
    n_total = n_slots * n_modes

    # Evaluate Fourier multipliers on the k-grid (for constant terms)
    multiplier_cache: dict[str, NDArray[np.complex128]] = {}
    for eq in spec.equations:
        for term in eq.rhs_terms:
            op = term.operator
            if op not in multiplier_cache:
                mult_fn = _EXACT_MULTIPLIERS[op]
                mult_val = mult_fn(k_grid)
                mult_full = np.broadcast_to(mult_val, rfft_shape)
                multiplier_cache[op] = mult_full.ravel().astype(np.complex128)

    A = np.zeros((n_total, n_total), dtype=np.complex128)
    for _eq_idx, eq in enumerate(spec.equations):
        field_name = eq.field_name
        is_second_order = eq.time_derivative_order >= 2
        if is_second_order:
            field_slot = layout.field_slot_map[field_name]
            vel_slot = layout.velocity_slot_map[field_name]
            # dq/dt = v → diagonal identity coupling between field and velocity
            for m in range(n_modes):
                row = field_slot * n_modes + m
                col = vel_slot * n_modes + m
                A[row, col] = 1.0
            # dv/dt = Σ coeff(x) * operator(target_field)
            for _term_idx, term in enumerate(eq.rhs_terms):
                target_slot = layout.field_slot_map[term.field]
                mult = multiplier_cache[term.operator]
                if not term.position_dependent:
                    # Constant coefficient: diagonal in mode space
                    coeff = _resolve_constant_coeff(
                        term,
                        coeff_eval,
                        eq_idx=_eq_idx,
                        term_idx=_term_idx,
                    )
                    for m in range(n_modes):
                        row = vel_slot * n_modes + m
                        col = target_slot * n_modes + m
                        A[row, col] += coeff * mult[m]
                else:
                    # Position-dependent: convolution coupling
                    _add_convolution_coupling(
                        A,
                        vel_slot,
                        target_slot,
                        term,
                        coeff_eval,
                        mult,
                        grid,
                        rfft_shape,
                        n_modes,
                        eq_idx=_eq_idx,
                        term_idx=_term_idx,
                    )
        else:
            # First-order
            this_slot = layout.field_slot_map[field_name]
            for _term_idx, term in enumerate(eq.rhs_terms):
                target_slot = layout.field_slot_map[term.field]
                mult = multiplier_cache[term.operator]
                if not term.position_dependent:
                    coeff = _resolve_constant_coeff(
                        term,
                        coeff_eval,
                        eq_idx=_eq_idx,
                        term_idx=_term_idx,
                    )
                    for m in range(n_modes):
                        row = this_slot * n_modes + m
                        col = target_slot * n_modes + m
                        A[row, col] += coeff * mult[m]
                else:
                    _add_convolution_coupling(
                        A,
                        this_slot,
                        target_slot,
                        term,
                        coeff_eval,
                        mult,
                        grid,
                        rfft_shape,
                        n_modes,
                        eq_idx=_eq_idx,
                        term_idx=_term_idx,
                    )
    return A


def _add_convolution_coupling(
    A: NDArray[np.complex128],
    row_slot: int,
    col_slot: int,
    term: OperatorTerm,
    coeff_eval: CoefficientEvaluator,
    operator_mult: NDArray[np.complex128],
    grid: GridInfo,
    rfft_shape: tuple[int, ...],
    n_modes: int,
    *,
    eq_idx: int = -1,
    term_idx: int = -1,
) -> None:
    """Add convolution coupling from a position-dependent coefficient.

    The product c(x)·op(u(x)) in k-space becomes a convolution:

        FFT[c·op(u)]_k = Σ_k' ĉ(k-k') · mult(k') · û(k')

    This creates off-diagonal entries in the evolution matrix coupling
    different k-modes.
    """
    # Get the coefficient array on the spatial grid
    coeff_array = coeff_eval.resolve(term, 0.0, eq_idx=eq_idx, term_idx=term_idx)
    if isinstance(coeff_array, (int, float)):
        coeff_array = np.full(grid.shape, float(coeff_array))

    # For each pair of output mode m and input mode m',
    # the coupling is (1/N) * ĉ(m-m') * mult(m')
    # This is a Toeplitz-like structure in 1D.
    #
    # Build via outer product approach for efficiency:
    # We compute the full convolution matrix using FFT properties.
    #
    # For rfftn: the convolution of real functions in rfft space requires
    # care with the Hermitian symmetry. We use the identity:
    #     FFT[c·u]_k = (1/N) Σ_{k'} ĉ_{k-k'} · û_{k'}
    #
    # Build the convolution matrix C where C[k, k'] = (1/N) * ĉ_{k-k'}
    # using probe vectors (unit impulse per mode).
    for m_prime in range(n_modes):
        # Probe: unit impulse at mode m_prime
        probe_hat = np.zeros(n_modes, dtype=np.complex128)
        probe_hat[m_prime] = 1.0
        # Reconstruct to physical space, multiply by coefficient, FFT back
        probe_physical = np.fft.irfftn(
            probe_hat.reshape(rfft_shape),
            s=grid.shape,
            axes=list(range(len(grid.shape))),
        )
        product = coeff_array * probe_physical
        result_hat = np.fft.rfftn(product).ravel()
        # result_hat[m] = Σ_{k'} (1/N) ĉ_{m-k'} δ_{k',m'} = (1/N) ĉ_{m-m'}
        # multiplied by operator multiplier at m'
        row_start = row_slot * n_modes
        col = col_slot * n_modes + m_prime
        A[row_start : row_start + n_modes, col] += result_hat * operator_mult[m_prime]


# ---------------------------------------------------------------------------
# Block decomposition
# ---------------------------------------------------------------------------


def _find_independent_blocks(
    A: NDArray[np.complex128],
    threshold: float = 1e-14,
) -> list[list[int]]:
    """Find independent (decoupled) blocks in an evolution matrix.

    Analyzes the sparsity pattern of A: slots i and j are coupled if
    |A[i,j]| > threshold or |A[j,i]| > threshold. Returns a list of
    slot-index groups (connected components).

    This prevents degenerate-eigenvalue mixing when ``np.linalg.eig``
    processes a block-diagonal matrix with repeated eigenvalues across
    independent blocks — a common situation in symmetric multi-field
    theories (e.g. Gertsenshtein h₅↔a₁ + h₇↔a₂).
    """
    n = A.shape[0]

    # Union-find (path compression + union by rank)
    parent = list(range(n))
    rank = [0] * n

    def _find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path compression
            x = parent[x]
        return x

    def _union(x: int, y: int) -> None:
        rx, ry = _find(x), _find(y)
        if rx == ry:
            return
        if rank[rx] < rank[ry]:
            rx, ry = ry, rx
        parent[ry] = rx
        if rank[rx] == rank[ry]:
            rank[rx] += 1

    # Build coupling graph from matrix entries
    for i in range(n):
        for j in range(i + 1, n):
            if abs(A[i, j]) > threshold or abs(A[j, i]) > threshold:
                _union(i, j)

    # Group by root
    groups: dict[int, list[int]] = {}
    for i in range(n):
        root = _find(i)
        groups.setdefault(root, []).append(i)
    return list(groups.values())


# ---------------------------------------------------------------------------
# Eigendecomposition and time evolution
# ---------------------------------------------------------------------------


def _has_position_dependent_terms(spec: EquationSystem) -> bool:
    """Check if any RHS term has a position-dependent coefficient."""
    for eq in spec.equations:
        for term in eq.rhs_terms:
            if term.position_dependent:
                return True
    return False


def _warn_eigenvalue_growth(
    eigenvalues: NDArray[np.complex128],
    dt_total: float,
    context: str = "",
) -> None:
    """Warn if eigenvalues have positive real parts that could cause overflow."""
    import warnings  # noqa: PLC0415

    max_real = float(np.max(np.real(eigenvalues)))
    if max_real > 1e-10:
        max_growth = max_real * dt_total
        if max_growth > 30:  # exp(30) ≈ 1e13
            ctx = f" ({context})" if context else ""
            warnings.warn(
                f"Modal solver{ctx}: eigenvalues with positive real parts "
                f"(max Re(λ)={max_real:.3e}). Growth factor exp({max_growth:.1f}) "
                f"over Δt={dt_total:.1f} may cause overflow. "
                f"Consider --scheme cvode for numerical stability.",
                stacklevel=3,
            )


def _evolve_per_mode(
    A_modes: NDArray[np.complex128],
    y0_hat: NDArray[np.complex128],
    t_eval: NDArray[np.float64],
    layout: StateLayout,
    grid: GridInfo,
    snapshot_callback: Callable[[float, NDArray[np.float64]], None] | None,
    progress: SimulationProgress | None,
    *,
    return_fourier: bool = False,
    return_derivative_fourier: bool = False,
    B_modes: NDArray[np.complex128] | None = None,
) -> tuple[
    NDArray[np.float64],
    NDArray[np.float64],
    NDArray[np.complex128] | None,
    NDArray[np.complex128] | None,
]:
    """Evolve system with per-mode independent matrices (constant coefficients).

    A_modes has shape (n_modes, n_slots, n_slots).
    y0_hat has shape (n_slots, n_modes).

    Uses block-aware eigendecomposition: independent field blocks are
    detected and eigendecomposed separately to prevent degenerate-eigenvalue
    mixing. Blocks with all-zero initial conditions are skipped entirely.

    Ref: Golub & Van Loan (1996), Matrix Computations, §4.8.
    """
    n_slots = layout.num_slots
    n_pts = layout.num_points
    n_snapshots = len(t_eval)
    t0 = t_eval[0]
    dt_total = float(t_eval[-1] - t0)

    # Detect independent blocks from the first mode's matrix.
    # Block structure is k-independent for constant coefficients, so we only
    # need to analyze one representative mode (use max across a few modes for
    # robustness against accidental zeros at specific k).
    n_check = min(3, A_modes.shape[0])
    combined = np.max(np.abs(A_modes[:n_check]), axis=0)
    blocks = _find_independent_blocks(combined)

    # Pre-compute eigendecomposition for each active block
    block_data: list[
        tuple[
            list[int],  # slot indices
            NDArray[np.complex128],  # eigenvalues (n_modes, block_size)
            NDArray[np.complex128],  # V (n_modes, block_size, block_size)
            NDArray[np.complex128],  # y0_eigen (n_modes, block_size)
        ]
    ] = []
    for block_slots in blocks:
        # Extract IC for this block
        y0_block = y0_hat[block_slots, :]  # (block_size, n_modes)
        # Skip blocks with all-zero IC — output stays at zero
        if np.max(np.abs(y0_block)) < 1e-15:
            continue
        # Extract block sub-matrices: (n_modes, block_size, block_size)
        idx = np.array(block_slots)
        A_block = A_modes[:, idx[:, None], idx[None, :]]
        if B_modes is not None:
            # Generalized eigenvalue problem: B · d' = A · d
            # Uses QZ decomposition via scipy.linalg.eig(A, B).
            # Infinite eigenvalues (gauge DOF) are zeroed — they don't evolve.
            # Ref: Golub & Van Loan (2013), Matrix Computations §7.7.6
            import scipy.linalg as sla  # type: ignore[import-untyped]  # noqa: PLC0415

            B_block = B_modes[:, idx[:, None], idx[None, :]]
            bs = len(block_slots)
            n_block_modes = A_block.shape[0]
            eig_vals = np.zeros((n_block_modes, bs), dtype=np.complex128)
            v_mat = np.zeros((n_block_modes, bs, bs), dtype=np.complex128)
            n_gauge_total = 0
            for m in range(A_block.shape[0]):
                eig_result = sla.eig(A_block[m], B_block[m], right=True)  # pyright: ignore[reportUnknownVariableType]
                ev_m = eig_result[0]  # pyright: ignore[reportUnknownVariableType]
                vr_m = eig_result[1]  # pyright: ignore[reportUnknownVariableType]
                # Filter infinite/very-large eigenvalues (gauge modes)
                gauge = ~np.isfinite(ev_m) | (np.abs(ev_m) > 1e12)  # pyright: ignore[reportUnknownArgumentType]
                ev_m[gauge] = 0.0  # gauge modes frozen at IC
                n_gauge_total += int(np.sum(gauge))
                eig_vals[m] = ev_m  # pyright: ignore[reportUnknownArgumentType]
                v_mat[m] = vr_m
            if n_gauge_total > 0:
                import logging as _log  # noqa: PLC0415

                _log.getLogger(__name__).info(
                    "Generalized eigenvalue: %d gauge modes zeroed across %d modes",
                    n_gauge_total,
                    A_block.shape[0],
                )
            v_inv = np.linalg.inv(v_mat)
        else:
            # Standard eigendecomposition (existing path)
            eig_vals, v_mat = np.linalg.eig(A_block)
            v_inv = np.linalg.inv(v_mat)
        # Warn about potential overflow
        _warn_eigenvalue_growth(eig_vals, dt_total, context="per-mode")
        # Transform IC to eigenbasis
        y0_eigen = np.einsum("mij,mj->mi", v_inv, y0_block.T)
        block_data.append((block_slots, eig_vals, v_mat, y0_eigen))

    # Evolve at each time point.
    # Pre-multiply V @ diag(y0_eigen) for each block so the inner loop only
    # needs element-wise exp + matrix-vector product, not a full einsum.
    block_evolved: list[
        tuple[
            list[int],  # slot indices
            NDArray[np.complex128],  # V_y0: V * y0_eigen, (n_modes, bs, bs)
            NDArray[np.complex128] | None,  # V_y0_deriv (n_modes, bs, bs) or None
            NDArray[np.complex128],  # eigenvalues (n_modes, bs)
        ]
    ] = []
    for block_slots, eig_vals, v_mat, y0_eigen in block_data:
        # V_y0[m, i, j] = v_mat[m, i, j] * y0_eigen[m, j]
        # so y(t) = V_y0 @ exp(λ*dt) is just a matvec
        V_y0 = v_mat * y0_eigen[:, np.newaxis, :]  # (n_modes, bs, bs)
        # V_y0_deriv[m, i, j] = V_y0[m, i, j] * λ[m, j]
        # so y'(t) = V_y0_deriv @ exp(λ*dt) gives exact time derivative
        V_y0_deriv = (
            V_y0 * eig_vals[:, np.newaxis, :] if return_derivative_fourier else None
        )
        block_evolved.append((block_slots, V_y0, V_y0_deriv, eig_vals))

    snapshots = np.zeros((n_snapshots, n_slots * n_pts))
    times = np.zeros(n_snapshots)
    n_modes = y0_hat.shape[1]

    # Optionally collect Fourier-space snapshots (avoids re-FFT in constraint
    # recovery — the Fourier data is already computed here).
    fourier_snaps: NDArray[np.complex128] | None = None
    if return_fourier:
        fourier_snaps = np.zeros(
            (n_snapshots, n_slots, n_modes),
            dtype=np.complex128,
        )

    # Optionally collect Fourier-space TIME DERIVATIVE snapshots.
    # d'(t) = V · diag(λ · exp(λt)) · y0_eigen — exact, no numerical diff.
    # Used for machine-precision constraint velocity: v_c = recovery · d'.
    deriv_fourier_snaps: NDArray[np.complex128] | None = None
    dy_hat_t: NDArray[np.complex128] | None = None
    if return_derivative_fourier:
        deriv_fourier_snaps = np.zeros(
            (n_snapshots, n_slots, n_modes),
            dtype=np.complex128,
        )
        dy_hat_t = np.zeros((n_slots, n_modes), dtype=np.complex128)

    # Pre-allocate buffer reused each timestep (avoids n_snapshots allocations)
    y_hat_t = np.zeros((n_slots, n_modes), dtype=np.complex128)
    for ti, t in enumerate(t_eval):
        dt = t - t0
        y_hat_t[:] = 0.0
        if dy_hat_t is not None:
            dy_hat_t[:] = 0.0
        for block_slots, V_y0, V_y0_deriv, eig_vals in block_evolved:
            # exp_lambda shape: (n_modes, block_size)
            exp_lambda = np.exp(eig_vals * dt)
            # y_evolved[m, i] = Σ_j V_y0[m, i, j] * exp(λ_j * dt)
            y_evolved = np.einsum("mij,mj->mi", V_y0, exp_lambda)
            y_hat_t[block_slots, :] = y_evolved.T
            # Exact time derivative: dy[m,i] = Σ_j V_y0_deriv[m,i,j] * exp(λ_j*dt)
            if V_y0_deriv is not None and dy_hat_t is not None:
                dy_evolved = np.einsum("mij,mj->mi", V_y0_deriv, exp_lambda)
                dy_hat_t[block_slots, :] = dy_evolved.T
        if fourier_snaps is not None:
            fourier_snaps[ti] = y_hat_t
        if deriv_fourier_snaps is not None and dy_hat_t is not None:
            deriv_fourier_snaps[ti] = dy_hat_t
        y_physical = _ifft_slots(y_hat_t, layout, grid)
        snapshots[ti] = y_physical
        times[ti] = t
        if snapshot_callback is not None:
            snapshot_callback(t, y_physical)
        if progress is not None:
            progress.update(t)

    return times, snapshots, fourier_snaps, deriv_fourier_snaps


def _evolve_full_matrix(
    A_full: NDArray[np.complex128],
    y0_hat: NDArray[np.complex128],
    t_eval: NDArray[np.float64],
    layout: StateLayout,
    grid: GridInfo,
    snapshot_callback: Callable[[float, NDArray[np.float64]], None] | None,
    progress: SimulationProgress | None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Evolve system with full coupled matrix (position-dependent coefficients).

    A_full has shape (n_total, n_total) where n_total = n_slots x n_modes.
    y0_hat has shape (n_slots, n_modes).

    Uses ``scipy.sparse.linalg.expm_multiply`` to compute exp(A·t)·y₀ at each
    output time without eigendecomposition. This is backward-stable for
    non-normal matrices (position-dependent gradient coupling creates
    non-normal convolution matrices whose eigenvalues have large real parts,
    but the true dynamics are bounded). The algorithm uses scaling +
    truncated Taylor series in matrix-vector products, avoiding individual
    exp(λ·t) overflow.

    **Why not eigendecomposition?** The original full-matrix
    eigendecomposition gave incorrect physics for localized Gertsenshtein
    (P=0.477 vs correct P=0.3437) because non-normal convolution matrices
    have eigenvalues with significant positive real parts despite
    conservative physics — individual exp(λ·t) overflow while exp(A·t)·y₀ is
    bounded (pseudospectral phenomenon; Trefethen & Embree 2005, Ch. 14).

    **Sparse optimization:** For localized coefficients (e.g. Gaussian B₀),
    the convolution kernel ĉ(q) decays exponentially, making the matrix
    effectively banded. Entries below a relative threshold (1e-14 x max|A|)
    are zeroed, and if density < 30% the matrix is converted to sparse CSC
    format. This accelerates expm_multiply's internal matrix-vector products.

    Ref: Al-Mohy & Higham (2011), "Computing the Action of the Matrix
    Exponential", SIAM J. Sci. Comput. 33(2):488-511.
    """
    import scipy.sparse  # noqa: PLC0415 # pyright: ignore[reportMissingTypeStubs]
    from scipy.sparse.linalg import (  # noqa: PLC0415 # pyright: ignore[reportMissingTypeStubs]
        expm_multiply,  # pyright: ignore[reportUnknownVariableType]
    )

    n_slots = layout.num_slots
    n_pts = layout.num_points
    n_modes = y0_hat.shape[1]
    n_snapshots = len(t_eval)

    # Flatten y0_hat to (n_total,) — slot-major order
    y0_flat = y0_hat.ravel()

    # --- Sparse matrix optimization ---
    # Position-dependent convolution matrices are effectively banded: for
    # Gaussian B₀(x) the kernel ĉ(q) decays exponentially, so most off-
    # diagonal entries are negligibly small. Thresholding and converting to
    # sparse CSC format accelerates expm_multiply's internal matrix-vector
    # products (the dominant cost) without affecting accuracy.
    abs_max = float(np.max(np.abs(A_full)))
    if abs_max > 0:
        threshold = abs_max * 1e-14
        A_work = A_full.copy()
        A_work[np.abs(A_work) < threshold] = 0.0
        density = np.count_nonzero(A_work) / A_work.size
        if density < 0.3:
            A_op: NDArray[np.complex128] | scipy.sparse.csc_array = (
                scipy.sparse.csc_array(A_work)
            )
        else:
            A_op = A_work
    else:
        A_op = A_full
        # NOTE(review): density is not read after this branch in this chunk —
        # presumably kept for symmetry/debugging; confirm before removing.
        density = 1.0

    snapshots = np.zeros((n_snapshots, n_slots * n_pts))
    times = np.zeros(n_snapshots)

    # Compute exp(A·t)·y₀ at all requested times using Krylov/Taylor method.
    # expm_multiply handles the scaling internally — no manual dt stepping.
    t0 = float(t_eval[0])
    t_end = float(t_eval[-1])
    if n_snapshots > 1 and t_end > t0:
        # Use expm_multiply's built-in multi-point evaluation
        y_all: NDArray[np.complex128] = np.asarray(
            expm_multiply(
                A_op,
                y0_flat,
                start=t0,
                stop=t_end,
                num=n_snapshots,
            ),
            dtype=np.complex128,
        )
        # y_all has shape (n_snapshots, n_total)
        for ti in range(n_snapshots):
            t = float(t_eval[ti])
            y_hat_t = y_all[ti].reshape(n_slots, n_modes)
            y_physical = _ifft_slots(y_hat_t, layout, grid)
            snapshots[ti] = y_physical
            times[ti] = t
            if snapshot_callback is not None:
                snapshot_callback(t, y_physical)
            if progress is not None:
                progress.update(t)
    else:
        # Single time point or t0 == t_end
        for ti, t in enumerate(t_eval):
            if t == t0:
                y_evolved = y0_flat.copy()
            else:
                y_evolved = np.asarray(
                    expm_multiply(A_op, y0_flat, start=t0, stop=float(t), num=2)[-1],
                    dtype=np.complex128,
                )
            y_hat_t = y_evolved.reshape(n_slots, n_modes)
            y_physical = _ifft_slots(y_hat_t, layout, grid)
            snapshots[ti] = y_physical
            times[ti] = t
            if snapshot_callback is not None:
                snapshot_callback(t, y_physical)
            if progress is not None:
                progress.update(t)

    return times, snapshots


# ---------------------------------------------------------------------------
# Public solver entry point
# ---------------------------------------------------------------------------
def solve_modal(
    spec: EquationSystem,
    grid: GridInfo,
    y0: np.ndarray,
    t_span: tuple[float, float],
    *,
    bc: BCSpec | None = None,
    parameters: dict[str, float] | None = None,
    rtol: float = DEFAULT_RTOL,
    atol: float = DEFAULT_ATOL,
    num_snapshots: int = 101,
    snapshot_callback: Callable[[float, np.ndarray], None] | None = None,
    progress: SimulationProgress | None = None,
) -> SolverResult:
    """Solve a TIDAL equation system using Fourier modal decomposition.

    Transforms the spatial grid to Fourier space, builds the per-mode or
    full evolution matrix, eigendecomposes, and evaluates the exact
    solution at each output time.

    Parameters
    ----------
    spec : EquationSystem
        Parsed equation specification (from JSON).
    grid : GridInfo
        Spatial grid (must be all-periodic).
    y0 : np.ndarray
        Initial state vector (flat).
    t_span : tuple[float, float]
        (t_start, t_end).
    bc : str or tuple, optional
        Boundary conditions (must be all-periodic).
    parameters : dict[str, float], optional
        Runtime parameter overrides for symbolic coefficients.
    rtol, atol : float
        Tolerances (unused for eigendecomposition; reserved for solve_ivp
        fallback with time-dependent coefficients).
    num_snapshots : int
        Number of output time points.
    snapshot_callback : callable, optional
        Called as ``callback(t, y)`` at each output time.
    progress : SimulationProgress, optional
        Progress tracker for tqdm display.

    Returns
    -------
    SolverResult
        Dict with keys: ``t``, ``y``, ``success``, ``message``.
    """
    from tidal.solver.coefficients import CoefficientEvaluator  # noqa: PLC0415

    layout = StateLayout.from_spec(spec, grid.num_points)
    coeff_eval = CoefficientEvaluator(spec, grid, parameters or {})

    # Detect constraint fields (zeroth-order-in-time equations).
    has_constraints = any(eq.time_derivative_order == 0 for eq in spec.equations)
    if not has_constraints:
        # NOTE(review): warning fires when NO zeroth-order equations exist —
        # confirm the intended polarity against warn_frozen_constraints's
        # contract (it may inspect `layout` for static fields itself).
        warn_frozen_constraints(layout, "modal")

    # Build time evaluation points (uniform spacing — _evolve_full_matrix's
    # multi-point expm_multiply path relies on this).
    t_eval = np.linspace(t_span[0], t_span[1], num_snapshots)

    # Build k-grid
    k_axes = _build_k_axes(grid)
    k_grid = _build_k_grid(k_axes)

    # Compute rfft output shape: last axis is halved (real FFT symmetry).
    rfft_shape_list = list(grid.shape)
    rfft_shape_list[-1] = grid.shape[-1] // 2 + 1
    rfft_shape = tuple(rfft_shape_list)

    # FFT initial conditions
    y0_hat = _fft_slots(y0, layout, grid)

    # Zero the Nyquist mode(s) in the IC. The rfft Nyquist bin (last mode
    # in each dimension) must be real for real-valued fields. The modal
    # evolution matrix has complex entries (from gradient coupling ik),
    # which creates imaginary components at the Nyquist bin. irfft
    # silently drops these, causing energy non-conservation proportional
    # to the Nyquist power. Zeroing the Nyquist IC prevents this entirely.
    # This is standard practice in spectral methods — the Nyquist mode
    # aliases with its conjugate and cannot represent physical content.
    # Ref: Boyd (2001), Chebyshev & Fourier Spectral Methods, §11.5.
    for _dim_idx, n in enumerate(grid.shape):
        if n % 2 == 0:  # Nyquist mode exists only for even N
            nyq_mode = n // 2  # last rfft bin
            if len(grid.shape) == 1:
                y0_hat[:, nyq_mode] = 0.0
            else:
                # Multi-D: zero along the last-axis Nyquist slice
                # NOTE(review): y0_hat appears to be 2-D (n_slots, n_modes)
                # here, so `[:, ..., rfft_last]` reduces to `[:, rfft_last]`
                # and zeroes only a single flat mode, not the full last-axis
                # Nyquist slice; the loop also repeats the identical write for
                # every even dimension. Confirm multi-D behavior — zeroing may
                # need y0_hat.reshape(n_slots, *rfft_shape) with a per-axis
                # slice.
                rfft_last = grid.shape[-1] // 2
                y0_hat[:, ..., rfft_last] = 0.0

    has_pos_dep = _has_position_dependent_terms(spec)
    has_time_ops = _has_time_derivative_operators(spec)

    # Determine which matrix builder to use
    use_generalized = has_time_ops and not has_pos_dep
    use_constraint = has_constraints and not has_pos_dep and not use_generalized

    B_lhs_modes: NDArray[np.complex128] | None = None  # set by generalized path
    constraint_vel_arrays: dict[
        str, NDArray[np.float64]
    ] = {}  # populated by Schur path

    if use_generalized or use_constraint:
        # Both paths produce: A_reduced, recovery, constraint names, slot mapping
        if use_generalized:
            # Generalized mass-matrix system: M·ẍ = K·x + D·ẋ + J·x⃛
            # Returns A_rhs and optional B_lhs for generalized eigenvalue.
            # Ref: Golub & Van Loan (2013), Matrix Computations §7.7
            (
                A_reduced,
                B_lhs_modes,  # CRITICAL: was _B_lhs_modes (discarded) — #177
                recovery_matrix,
                _v_recovery_gen,  # unused — constraint vel from eigendata
                c_names,
                orig_to_reduced,
            ) = _build_generalized_evolution_matrices(
                spec, layout, grid, coeff_eval, k_grid, rfft_shape,
            )
        else:
            # Constraint elimination via Fourier Schur complement
            # Ref: Hairer & Wanner (1996), Solving ODEs II, Ch. VII
            (
                A_reduced,
                recovery_matrix,
                _v_recovery_matrix,  # recovery @ A_reduced (exact constraint velocities)
                c_names,
                orig_to_reduced,
            ) = _build_constraint_eliminated_matrices(
                spec, layout, grid, coeff_eval, k_grid, rfft_shape,
            )

        n_dyn = A_reduced.shape[1]
        n_modes = y0_hat.shape[1]
        n_pts = layout.num_points

        # Extract dynamical IC in reduced ordering
        y0_hat_dyn = np.zeros((n_dyn, n_modes), dtype=np.complex128)
        for orig_si, red_pos in orig_to_reduced.items():
            y0_hat_dyn[red_pos] = y0_hat[orig_si]

        # Build a reduced StateLayout for eigendecomposition
        sorted_orig = sorted(orig_to_reduced.keys())
        red_slots = tuple(layout.slots[si] for si in sorted_orig)
        red_field_map: dict[str, int] = {}
        red_vel_map: dict[str, int] = {}
        for new_i, si in enumerate(sorted_orig):
            s = layout.slots[si]
            if s.kind == "field":
                red_field_map[s.field_name] = new_i
            elif s.kind == "velocity":
                red_vel_map[s.field_name] = new_i
        dyn_layout = StateLayout(
            slots=red_slots,
            num_points=n_pts,
            field_slot_map=red_field_map,
            velocity_slot_map=red_vel_map,
            dynamical_fields=layout.dynamical_fields,
        )

        # Evolve dynamical fields (return Fourier + derivative data)
        times, dyn_snapshots, dyn_fourier, dyn_deriv_fourier = _evolve_per_mode(
            A_reduced,
            y0_hat_dyn,
            t_eval,
            dyn_layout,
            grid,
            None,  # callback handled below with full state
            progress,
            return_fourier=True,
            return_derivative_fourier=True,  # for exact constraint velocities
            B_modes=B_lhs_modes,  # generalized eigenvalue if vel coupling
        )

        # Reconstruct full state (including constraints) at each snapshot
        n_full = layout.num_slots * n_pts
        snapshots = np.zeros((len(t_eval), n_full))
        assert dyn_fourier is not None  # guaranteed by return_fourier=True

        # Populate constraint velocity arrays: exact ∂_t(c) from eigendata.
        # "Constraint" is a solver concept (algebraic evolution), not a physics
        # statement — these fields have physically meaningful velocities.
        # v_c(t) = recovery · d'(t), where d'(t) is computed from eigendata
        # inside _evolve_per_mode (V·diag(λ·exp(λt))·y0_eigen — exact).
        for c_name in c_names:
            constraint_vel_arrays[c_name] = np.zeros((len(t_eval), *grid.shape))

        for ti in range(len(t_eval)):
            dyn_phys = dyn_snapshots[ti]
            # Use Fourier data directly (already computed in _evolve_per_mode)
            y_hat_dyn_t = dyn_fourier[ti]  # (n_dyn, n_modes)
            # Recover constraint fields: c_hat = recovery @ d_hat
            c_hat = np.einsum("mcj,jm->cm", recovery_matrix, y_hat_dyn_t)
            # Recover constraint velocities: v_c_hat = recovery @ d'_hat
            # d'_hat comes from eigendata — exact, no numerical differentiation.
            assert dyn_deriv_fourier is not None
            dy_hat_dyn_t = dyn_deriv_fourier[ti]  # (n_dyn, n_modes)
            v_c_hat = np.einsum("mcj,jm->cm", recovery_matrix, dy_hat_dyn_t)

            # Assemble full physical state: copy dynamical slots back to their
            # original positions, then fill constraint slots from c_hat.
            full_state = np.zeros(n_full)
            for orig_si, red_pos in orig_to_reduced.items():
                full_state[orig_si * n_pts : (orig_si + 1) * n_pts] = dyn_phys[
                    red_pos * n_pts : (red_pos + 1) * n_pts
                ]
            for ci, c_name in enumerate(c_names):
                c_slot = layout.field_slot_map[c_name]
                c_phys = np.fft.irfftn(
                    c_hat[ci].reshape(rfft_shape),
                    s=grid.shape,
                    axes=list(range(len(grid.shape))),
                ).ravel()
                full_state[c_slot * n_pts : (c_slot + 1) * n_pts] = np.real(
                    c_phys,
                )
                # Store exact constraint velocity (from eigendata d')
                v_c_phys = np.fft.irfftn(
                    v_c_hat[ci].reshape(rfft_shape),
                    s=grid.shape,
                    axes=list(range(len(grid.shape))),
                )
                constraint_vel_arrays[c_name][ti] = np.real(v_c_phys)

            snapshots[ti] = full_state
            if snapshot_callback is not None:
                snapshot_callback(t_eval[ti], full_state)

        n_c = len(c_names)
        if use_generalized:
            method_desc = (
                f"per-mode eigendecomposition with generalized Schur elimination "
                f"({n_c} constraints, {n_dyn} dynamical slots, mass-matrix)"
            )
        else:
            method_desc = (
                f"per-mode eigendecomposition with Schur constraint elimination "
                f"({n_c} constraints, {n_dyn} dynamical slots)"
            )
    elif not has_pos_dep:
        # All-constant coefficients: per-mode independent evolution
        A_modes = _build_per_mode_matrices(
            spec, layout, grid, coeff_eval, k_grid, rfft_shape,
        )
        times, snapshots, _, _ = _evolve_per_mode(
            A_modes, y0_hat, t_eval, layout, grid, snapshot_callback, progress,
        )
        method_desc = "per-mode eigendecomposition (constant coefficients)"
    else:
        # Position-dependent coefficients: full convolution matrix
        A_full = _build_convolution_matrix(
            spec, layout, grid, coeff_eval, k_grid, rfft_shape,
        )
        times, snapshots = _evolve_full_matrix(
            A_full, y0_hat, t_eval, layout, grid, snapshot_callback, progress,
        )
        n_total = A_full.shape[0]
        method_desc = f"expm_multiply ({n_total}x{n_total}, position-dependent)"

    if progress is not None:
        progress.finish()

    result: SolverResult = {
        "t": times,
        "y": snapshots,
        "success": True,
        "message": f"Modal solver completed ({method_desc})",
    }
    # Attach constraint velocity arrays (exact ∂_t for constraint fields).
    # These are computed from v_recovery = recovery @ A_reduced inside the
    # Schur elimination path. For generalized eigenvalue or non-constraint
    # systems, constraint_vel_arrays is empty.
    if constraint_vel_arrays:
        result["constraint_velocities"] = constraint_vel_arrays  # type: ignore[typeddict-unknown-key]
    return result