"""Pre-solve algebraic constraints to produce consistent initial conditions.
Before IDA starts, constraint equations (time_order=0) must be satisfied.
For example, Gauss's law ``laplacian(A_0) = -div(pi)`` requires A_0 to be
non-trivially solved when the source (pi fields) is nonzero.
Three-tier solver hierarchy (automatically selected):
Tier 1 — **FFT** (O(N log N)): All axes periodic, all self-coefficients
constant, all operators have known Fourier multipliers. Uses modified
wavenumbers ``k_mod = (2/dx)*sin(k*dx/2)`` for exact FD-consistency.
Tier 2 — **Operator probing → sparse matrix** (O(N²) build, O(N) solve):
Universal fallback. Applies ``apply_operator()`` to unit vectors e_j to
build the operator matrix column by column. Automatically handles
position-dependent coefficients, arbitrary BCs, and unknown/future operators.
Tier 3 — **Iterative** (future): CG/GMRES for very large grids where
direct factorization is impractical.
Design choices:
- Tier 2 probing reuses the **exact same** ``apply_operator()`` code path as
the simulation, guaranteeing mathematical consistency.
- Position-dependent coefficients (NDArrays from ``CoefficientEvaluator``)
multiply element-wise in probing, yielding correct spatially-varying matrix
entries. This is critical for curved-spacetime or background-field
constraints where translational symmetry is broken.
- FFT eligibility is checked conservatively: any position-dependent self-coeff,
non-periodic axis, or unknown multiplier → fall back to Tier 2.
- Zero-mode handling for singular operators (pure Poisson/periodic): enforce
zero mean on the solution by setting ``u_hat[0,...,0] = 0``, with a
compatibility check that the source has zero mean.
References
----------
- Modified wavenumbers for finite differences: Lele, J. Comp. Phys., 1992.
- Spectral constraint projection: Dedalus (Burns et al., PRR 2020).
"""
from __future__ import annotations
import warnings
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
import numpy as np
from tidal.solver._defaults import SECOND_ORDER
from tidal.solver._scipy_types import SparseMatrix, lil_matrix, sparse_solve
from tidal.solver.operators import BCSpec, apply_operator, is_periodic_bc
# Numerical tolerance thresholds (shared by the FFT solvers below)
_SINGULAR_TOL = 1e-14  # Below this, a Fourier multiplier is treated as singular
_COMPAT_TOL = 1e-10  # Source projection tolerance for compatibility check
if TYPE_CHECKING:
from numpy.typing import NDArray
from tidal.solver.coefficients import CoefficientEvaluator
from tidal.solver.fields import FieldSet
from tidal.solver.grid import GridInfo
from tidal.solver.state import StateLayout
from tidal.symbolic.json_loader import (
ConstraintSolverConfig,
EquationSystem,
OperatorTerm,
)
# ---------------------------------------------------------------------------
# Temporal operator detection
# ---------------------------------------------------------------------------
# Operators with time_order >= 2 (accelerations, jerks) are not part of the IC
# state vector and must be skipped during constraint IC solving. The exact set
# mirrors modal._TIME_OPERATORS with time_order >= 2.
_ACCEL_AND_HIGHER_OPS: frozenset[str] = frozenset(
{
"d2_t",
"d3_t",
*(f"mixed_T2_S{s}{ax}" for s in (1, 2) for ax in "xyz"),
*(f"mixed_T3_S{s}{ax}" for s in (1,) for ax in "xyz"),
}
)
# ---------------------------------------------------------------------------
# Term classification
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class _ConstraintTerms:
    """Classified terms for one constraint equation.
    ``self_terms`` act on the constraint field itself (form the linear
    operator L). ``source_terms`` act on other fields (form the RHS b).
    The equation is: L[u] + b = 0 → L[u] = -b → solve for u.
    """
    # Name of the field u this constraint is solved for.
    field_name: str
    # (coefficient, operator_name) pairs acting on the constraint field;
    # together they assemble the linear operator L.
    self_terms: list[tuple[float | NDArray[np.float64], str]]
    # (coefficient, operator_name, source_field_name) triples acting on
    # other fields; together they assemble the source RHS b.
    source_terms: Sequence[tuple[float | NDArray[np.float64], str, str]]
    # Index of the originating equation in the EquationSystem.
    eq_idx: int
    # True when any self-term coefficient is an NDArray (position-dependent);
    # this disqualifies the Tier-1 FFT path in _select_method.
    has_position_dependent_self: bool
    # Per-equation solver configuration (method override, tolerances, ...).
    config: ConstraintSolverConfig
def _classify_terms(  # noqa: PLR0913, PLR0917
    eq_idx: int,
    rhs_terms: tuple[OperatorTerm, ...],
    constraint_field: str,
    coeff_eval: CoefficientEvaluator,
    t: float,
    config: ConstraintSolverConfig,
) -> _ConstraintTerms:
    """Split the RHS terms of one equation into self-operator and source terms.
    Terms acting on ``constraint_field`` build the linear operator; all
    others feed the source. Pure time derivatives are remapped, and
    acceleration/jerk operators are dropped entirely.
    """
    self_terms: list[tuple[float | NDArray[np.float64], str]] = []
    source_terms: list[tuple[float | NDArray[np.float64], str, str]] = []
    pos_dep = False
    for term_idx, term in enumerate(rhs_terms):
        coeff = coeff_eval.resolve(term, t, eq_idx=eq_idx, term_idx=term_idx)
        op = term.operator
        if op == "first_derivative_t":
            # dX/dt is the velocity slot v_X — not a spatial operator, so
            # treat it as an identity source on the velocity field.
            source_terms.append((coeff, "identity", f"v_{term.field}"))
            continue
        if op in _ACCEL_AND_HIGHER_OPS:
            # Accelerations/jerks are not part of the IC state vector.
            continue
        if term.field == constraint_field:
            self_terms.append((coeff, op))
            pos_dep = pos_dep or isinstance(coeff, np.ndarray)
        else:
            source_terms.append((coeff, op, term.field))
    return _ConstraintTerms(
        field_name=constraint_field,
        self_terms=self_terms,
        source_terms=source_terms,
        eq_idx=eq_idx,
        has_position_dependent_self=pos_dep,
        config=config,
    )
# ---------------------------------------------------------------------------
# Method selection
# ---------------------------------------------------------------------------
def _lap_axis(k: NDArray[np.float64], h: float) -> NDArray[np.float64]:
    """Modified-wavenumber Laplacian symbol along one axis.
    Returns the Fourier symbol of the FD second-derivative stencil for
    the currently configured ``fd_order`` from ``operators.py``:
    - Order 2: ``-(2/h²)(1 - cos(kh))``
    - Order 4: ``(-cos(2kh)/6 + 8cos(kh)/3 - 5/2) / h²``
    - Order 6: ``(cos(3kh)/45 - 3cos(2kh)/10 + 3cos(kh) - 49/18) / h²``
    Using the stencil symbol (rather than the exact -k²) keeps the FFT
    constraint solver on the same dispersion relation as the FD spatial
    operators.
    Reference: Fornberg (1988), Mathematics of Computation 51(184).
    """
    from tidal.solver.operators import get_fd_order  # noqa: PLC0415
    kh = k * h
    inv_h2 = 1.0 / (h * h)
    fd_order = get_fd_order()
    if fd_order == 2:  # noqa: PLR2004
        # Symbol of the 3-point stencil [1, -2, 1] / h²
        return -(2.0 * inv_h2) * (1.0 - np.cos(kh))
    if fd_order == 4:  # noqa: PLR2004
        # Symbol of the 5-point stencil [-1/12, 4/3, -5/2, 4/3, -1/12] / h²
        return inv_h2 * (-np.cos(2 * kh) / 6.0 + 8.0 * np.cos(kh) / 3.0 - 5.0 / 2.0)
    # Any other value falls through to order 6:
    # symbol of [1/90, -3/20, 3/2, -49/18, 3/2, -3/20, 1/90] / h²
    return inv_h2 * (
        np.cos(3 * kh) / 45.0
        - 3.0 * np.cos(2 * kh) / 10.0
        + 3.0 * np.cos(kh)
        - 49.0 / 18.0
    )
def _grad_axis(k: NDArray[np.float64], h: float) -> NDArray[np.complex128]:
    """Modified-wavenumber gradient symbol along one axis.
    Returns the Fourier symbol of the FD first-derivative stencil for
    the currently configured ``fd_order``:
    - Order 2: ``i sin(kh) / h``
    - Order 4: ``i (8sin(kh) - sin(2kh)) / (6h)``
    - Order 6: ``i (45sin(kh) - 9sin(2kh) + sin(3kh)) / (30h)``
    Reference: Fornberg (1988), Mathematics of Computation 51(184).
    """
    from tidal.solver.operators import get_fd_order  # noqa: PLC0415
    kh = k * h
    inv_h = 1.0 / h
    fd_order = get_fd_order()
    if fd_order == 2:  # noqa: PLR2004
        # Symbol of [-1/2, 0, 1/2] / h
        sym = 1j * np.sin(kh) * inv_h
    elif fd_order == 4:  # noqa: PLR2004
        # Symbol of [1/12, -2/3, 0, 2/3, -1/12] / h
        sym = 1j * inv_h * (8.0 * np.sin(kh) - np.sin(2 * kh)) / 6.0
    else:
        # Any other value: order 6,
        # symbol of [-1/60, 3/20, -3/4, 0, 3/4, -3/20, 1/60] / h
        sym = (
            1j
            * inv_h
            * (45.0 * np.sin(kh) - 9.0 * np.sin(2 * kh) + np.sin(3 * kh))
            / 30.0
        )
    return np.asarray(sym, dtype=np.complex128)
# Type for Fourier multiplier functions: (kvecs, dx) -> NDArray
_MultiplierFn = Callable[
[list[np.ndarray], tuple[float, ...]],
np.ndarray,
]
def _biharmonic_mult(
    kvecs: list[np.ndarray],
    dx: tuple[float, ...],
) -> np.ndarray:
    """Fourier multiplier of the biharmonic operator, i.e. (laplacian)²."""
    # Accumulate the full Laplacian symbol axis by axis, then square it.
    laplacian_sym: np.ndarray = _lap_axis(kvecs[0], dx[0])
    for axis in range(1, len(kvecs)):
        laplacian_sym = laplacian_sym + _lap_axis(kvecs[axis], dx[axis])
    return laplacian_sym * laplacian_sym
def _cross_deriv(
    kvecs: list[np.ndarray],
    dx: tuple[float, ...],
    a: int,
    b: int,
) -> np.ndarray:
    """Fourier multiplier of the mixed second derivative d²/(dx_a dx_b).
    Product of the two first-derivative symbols; at second order this is
    (i·sin(k_a h_a)/h_a)·(i·sin(k_b h_b)/h_b)
      = -sin(k_a h_a)·sin(k_b h_b)/(h_a h_b).
    """
    sym_a = _grad_axis(kvecs[a], dx[a])
    sym_b = _grad_axis(kvecs[b], dx[b])
    return sym_a * sym_b
# Dispatch table: operator name → Fourier-multiplier builder. An operator is
# FFT-eligible (Tier 1) only if it appears here; _select_method falls back to
# Tier-2 matrix probing otherwise.
_MULTIPLIERS: dict[str, _MultiplierFn] = {
    "identity": lambda kv, _dx: np.ones_like(kv[0]),
    # Full Laplacian: sum of per-axis modified-wavenumber symbols.
    "laplacian": lambda kv, dx: sum(  # type: ignore[return-value,arg-type]
        _lap_axis(kv[i], dx[i]) for i in range(len(kv))
    ),
    # Single-axis second derivatives.
    "laplacian_x": lambda kv, dx: _lap_axis(kv[0], dx[0]),
    "laplacian_y": lambda kv, dx: _lap_axis(kv[1], dx[1]),
    "laplacian_z": lambda kv, dx: _lap_axis(kv[2], dx[2]),
    # Single-axis first derivatives (complex symbols).
    "gradient_x": lambda kv, dx: _grad_axis(kv[0], dx[0]),
    "gradient_y": lambda kv, dx: _grad_axis(kv[1], dx[1]),
    "gradient_z": lambda kv, dx: _grad_axis(kv[2], dx[2]),
    # Mixed second derivatives over axis pairs (0=x, 1=y, 2=z).
    "cross_derivative_xy": lambda kv, dx: _cross_deriv(kv, dx, 0, 1),
    "cross_derivative_xz": lambda kv, dx: _cross_deriv(kv, dx, 0, 2),
    "cross_derivative_yz": lambda kv, dx: _cross_deriv(kv, dx, 1, 2),
    "biharmonic": _biharmonic_mult,
}
def _select_method(
    terms: _ConstraintTerms,
    grid: GridInfo,
    bc: BCSpec | None,
) -> str:
    """Pick the solver tier for one constraint: ``"fft"`` or ``"matrix"``.
    An explicit ``config.method`` (anything other than "auto"/"poisson" —
    "poisson" is presumably a legacy alias for auto; confirm) wins
    outright. Otherwise FFT is chosen only when every axis is periodic,
    no self-coefficient is position-dependent, and every self-operator
    has a known Fourier multiplier.
    """
    method = terms.config.method
    if method not in {"auto", "poisson"}:
        # Explicit user override.
        return method
    # Periodicity: an explicit bc spec takes precedence over grid.periodic.
    if bc is None:
        periodic = all(grid.periodic)
    else:
        bc_tuple = (bc,) * grid.ndim if isinstance(bc, str) else tuple(bc)
        periodic = all(is_periodic_bc(item) for item in bc_tuple)
    if not periodic:
        return "matrix"
    if terms.has_position_dependent_self:
        return "matrix"
    if any(op_name not in _MULTIPLIERS for _, op_name in terms.self_terms):
        return "matrix"
    return "fft"
# ---------------------------------------------------------------------------
# Field name resolution
# ---------------------------------------------------------------------------
def _build_name_map(spec: EquationSystem) -> dict[str, str]:
    """Map JSON field references onto FieldSet slot names.
    Every field maps to itself; equations of second (or higher) time
    order additionally get a ``v_{field}`` velocity entry mapping to
    itself, matching the ``StateLayout`` slot-naming convention.
    """
    mapping: dict[str, str] = {}
    for eq in spec.equations:
        mapping[eq.field_name] = eq.field_name
        if eq.time_derivative_order >= SECOND_ORDER:
            velocity = f"v_{eq.field_name}"
            mapping[velocity] = velocity
    return mapping
# ---------------------------------------------------------------------------
# Connected component decomposition
# ---------------------------------------------------------------------------
def _find_connected_components(
groups: list[_ConstraintTerms],
name_map: dict[str, str] | None = None,
) -> list[list[_ConstraintTerms]]:
"""Decompose constraint groups into connected components.
Builds an undirected graph where constraint fields are nodes and
cross-constraint source references are edges. Returns a list of
connected components, each being a list of ``_ConstraintTerms``.
This prevents false coupling: e.g., if h_0 is independent of
h_1 ↔ h_2, they are solved as separate systems rather than one
singular block.
"""
constraint_names = {g.field_name for g in groups}
# Union-Find
parent: dict[str, str] = {g.field_name: g.field_name for g in groups}
def find(x: str) -> str:
while parent[x] != x:
parent[x] = parent[parent[x]]
x = parent[x]
return x
def union(a: str, b: str) -> None:
ra, rb = find(a), find(b)
if ra != rb:
parent[ra] = rb
for g in groups:
for _, _, field_name in g.source_terms:
resolved = field_name
if name_map:
resolved = name_map.get(field_name, field_name)
if resolved in constraint_names:
union(g.field_name, resolved)
# Collect components (preserve insertion order)
components: dict[str, list[_ConstraintTerms]] = {}
for g in groups:
root = find(g.field_name)
components.setdefault(root, []).append(g)
return list(components.values())
# ---------------------------------------------------------------------------
# Source evaluation
# ---------------------------------------------------------------------------
def _evaluate_source(
    source_terms: Sequence[tuple[float | NDArray[np.float64], str, str]],
    fields: FieldSet,
    grid: GridInfo,
    bc: BCSpec | None,
    name_map: dict[str, str] | None = None,
) -> NDArray[np.float64]:
    """Accumulate the source RHS b = Σ_i coeff_i · operator_i(field_i).
    Field references are resolved through ``name_map`` when supplied; a
    reference with no backing FieldSet slot contributes zero.
    """
    total: NDArray[np.float64] = np.zeros(grid.shape)
    for coeff, op_name, raw_name in source_terms:
        slot = name_map.get(raw_name, raw_name) if name_map else raw_name
        data = fields[slot] if slot in fields else np.zeros(grid.shape)
        total += coeff * apply_operator(op_name, data, grid, bc)
    return total
# ---------------------------------------------------------------------------
# Tier 1: FFT solver
# ---------------------------------------------------------------------------
def _build_wavenumber_grids(
grid: GridInfo,
) -> tuple[list[Any], tuple[float, ...]]:
"""Build wavenumber grids for FFT-based solving."""
k_1d = [
2.0 * np.pi * np.fft.fftfreq(grid.shape[ax], d=grid.dx[ax])
for ax in range(grid.ndim)
]
kvecs = list(np.meshgrid(*k_1d, indexing="ij")) if grid.ndim > 1 else [k_1d[0]]
return kvecs, grid.dx
def _fft_solve_single(
    terms: _ConstraintTerms,
    grid: GridInfo,
    source_rhs: NDArray[np.float64],
) -> NDArray[np.float64]:
    """Tier-1 FFT solve of one constraint: u = F⁻¹[ -F[b] / m(k) ].
    ``m(k)`` is the combined Fourier multiplier of all self-terms. Modes
    with ``|m| < _SINGULAR_TOL`` lie in the operator's null space: the
    source is checked for (near-)zero projection there, and the solution
    is gauged to zero at those modes.
    Raises
    ------
    ValueError
        If the operator is singular and the source has nonzero projection
        onto the corresponding mode.
    """
    kvecs, dx = _build_wavenumber_grids(grid)
    # Accumulate the combined multiplier m(k) = Σ coeff · symbol(op).
    multiplier: NDArray[np.complex128] = np.zeros(grid.shape, dtype=np.complex128)
    for coeff, op_name in terms.self_terms:
        multiplier += coeff * _MULTIPLIERS[op_name](kvecs, dx)
    source_hat = np.fft.fftn(source_rhs)
    singular_mask = np.abs(multiplier) < _SINGULAR_TOL
    if not np.any(singular_mask):
        # Regular operator: plain spectral division.
        return np.fft.ifftn(-source_hat / multiplier).real.astype(np.float64)
    # Singular modes present: compatibility check, then zero-mode gauge.
    projected = np.abs(source_hat[singular_mask])
    max_source = float(np.max(np.abs(source_hat))) + 1e-30
    max_incompatible = float(np.max(projected)) if projected.size > 0 else 0.0
    # Relative tolerance with an absolute floor: high-order FD stencils can
    # leave ~1e-15 DC components from rounding, which must not trip this.
    compat_threshold = max(_COMPAT_TOL * max_source, 1e-12)
    if max_incompatible > compat_threshold:
        msg = (
            f"Constraint for '{terms.field_name}' is incompatible: "
            f"source has nonzero projection (max={max_incompatible:.4g}) "
            f"onto the null space of the self-operator. "
            f"Check that source terms have zero mean for "
            f"Poisson-type constraints."
        )
        raise ValueError(msg)
    n_singular = int(np.sum(singular_mask))
    warnings.warn(
        f"FFT constraint pre-solve: field '{terms.field_name}' has "
        f"{n_singular} singular mode(s) in Fourier space (null space "
        f"of operator). Setting u_hat = 0 at these modes (zero-mean "
        f"gauge). Solution is unique up to these modes.",
        UserWarning,
        stacklevel=3,
    )
    # Divide everywhere except the null modes, which are pinned to zero.
    safe_mult = np.where(singular_mask, 1.0, multiplier)
    u_hat = np.where(singular_mask, 0.0, -source_hat / safe_mult)
    return np.fft.ifftn(u_hat).real.astype(np.float64)
def _fft_solve_coupled(  # noqa: PLR0914
    groups: list[_ConstraintTerms],
    grid: GridInfo,
    fields: FieldSet,
    bc: BCSpec | None,
    name_map: dict[str, str] | None = None,
) -> dict[str, NDArray[np.float64]]:
    """Solve coupled constraints via block FFT.
    At each wavenumber k, assembles the n_c×n_c multiplier matrix M(k)
    and solves M(k)·u_hat(k) = rhs_hat(k) for all k at once through a
    batched ``np.linalg.solve``.
    Raises
    ------
    ValueError
        If the coupled system is singular and the source has nonzero mean.
    """
    n_c = len(groups)
    kvecs, dx = _build_wavenumber_grids(grid)
    constraint_names = [g.field_name for g in groups]
    name_to_idx = {name: i for i, name in enumerate(constraint_names)}
    # Evaluate sources from non-constraint fields
    # NOTE(review): cross-constraint references are detected on the raw field
    # name here (no name_map resolution), unlike _find_connected_components.
    # With the identity map built by _build_name_map the two are equivalent —
    # confirm if name_map ever becomes non-trivial.
    sources: list[NDArray[np.float64]] = []
    for group in groups:
        non_constraint = [
            (c, op, f) for c, op, f in group.source_terms if f not in name_to_idx
        ]
        sources.append(_evaluate_source(non_constraint, fields, grid, bc, name_map))
    # Build multiplier matrix and RHS in Fourier space
    m_hat = np.zeros((*grid.shape, n_c, n_c), dtype=np.complex128)
    rhs_hat = np.zeros((*grid.shape, n_c), dtype=np.complex128)
    for i, group in enumerate(groups):
        # Self-terms → diagonal
        for coeff, op_name in group.self_terms:
            fn = _MULTIPLIERS[op_name]
            m_hat[..., i, i] += coeff * fn(kvecs, dx)
        # Cross-constraint terms → off-diagonal
        for coeff, op_name, field_name in group.source_terms:
            if field_name in name_to_idx:
                j = name_to_idx[field_name]
                fn = _MULTIPLIERS[op_name]
                m_hat[..., i, j] += coeff * fn(kvecs, dx)
        rhs_hat[..., i] = -np.fft.fftn(sources[i])
    # Handle singular zero-mode
    zero_idx = tuple(0 for _ in range(grid.ndim))
    m_zero = m_hat[zero_idx]
    if abs(np.linalg.det(m_zero)) < _SINGULAR_TOL:
        rhs_zero = rhs_hat[zero_idx]
        if float(np.max(np.abs(rhs_zero))) > _COMPAT_TOL:
            msg = (
                "Coupled constraint system is singular at zero wavenumber "
                "and source has nonzero mean. Check compatibility."
            )
            raise ValueError(msg)
        coupled_names = ", ".join(constraint_names)
        warnings.warn(
            f"FFT coupled constraint pre-solve: system "
            f"[{coupled_names}] is singular at zero wavenumber (null "
            f"space of operator). Setting zero-mode to identity/zero "
            f"(zero-mean gauge). Solution is unique up to constants.",
            UserWarning,
            stacklevel=3,
        )
        # Zero-mean gauge: decouple the k=0 block and zero its RHS.
        m_hat[zero_idx] = np.eye(n_c)
        rhs_hat[zero_idx] = 0.0
    # np.linalg.solve dispatches to the matrix path (m,m),(m,n)->(m,n)
    # when ndim >= 2, but rhs_hat is (..., n_c) — a vector per wavenumber.
    # Add a trailing dim so numpy sees (..., n_c, 1), then squeeze it out.
    u_hat = np.linalg.solve(m_hat, rhs_hat[..., np.newaxis])[..., 0]
    results: dict[str, NDArray[np.float64]] = {}
    for i, name in enumerate(constraint_names):
        results[name] = np.fft.ifftn(u_hat[..., i]).real.astype(np.float64)
    return results
# ---------------------------------------------------------------------------
# Tier 2: Operator probing → sparse matrix
# ---------------------------------------------------------------------------
def _probe_operator_matrix(
    self_terms: list[tuple[float | NDArray[np.float64], str]],
    grid: GridInfo,
    bc: BCSpec | None,
) -> SparseMatrix:
    """Assemble the self-operator as a sparse matrix via unit-vector probing.
    Column j is obtained by applying the combined operator to the basis
    vector e_j through the exact same ``apply_operator()`` path the
    simulation uses, so the matrix is FD-consistent by construction and
    position-dependent (NDArray) coefficients multiply element-wise.
    """
    # Fast path: one term with a scalar coefficient → shared prebuilt matrix.
    if len(self_terms) == 1:
        coeff, op_name = self_terms[0]
        if np.ndim(coeff) == 0:
            from tidal.solver.analytical_jacobian import (  # noqa: PLC0415
                build_operator_matrix,
            )
            base = build_operator_matrix(op_name, grid, bc)
            scale = float(coeff)
            return base if scale == 1.0 else scale * base
    # General path: probe column by column (multiple terms and/or
    # position-dependent coefficients).
    n = grid.num_points
    probed = lil_matrix((n, n))
    for col_idx in range(n):
        basis: NDArray[np.float64] = np.zeros(grid.shape)
        basis.flat[col_idx] = 1.0
        column: NDArray[np.float64] = np.zeros(grid.shape)
        for coeff, op_name in self_terms:
            column += coeff * apply_operator(op_name, basis, grid, bc)
        flat = column.ravel()
        for row_idx in np.nonzero(flat)[0]:
            probed[row_idx, col_idx] = flat[row_idx]
    return probed.tocsc()
def _matrix_solve(
    op_matrix: SparseMatrix,
    source_rhs: NDArray[np.float64],
    grid_shape: tuple[int, ...],
) -> NDArray[np.float64]:
    """Direct sparse solve of op_matrix @ u = -source_rhs, reshaped to the grid."""
    flat_rhs = -source_rhs.ravel()
    solution = sparse_solve(op_matrix, flat_rhs)
    return solution.reshape(grid_shape)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def pre_solve_constraints(  # noqa: PLR0913
    spec: EquationSystem,
    grid: GridInfo,
    y0: NDArray[np.float64],
    *,
    bc: BCSpec | None = None,
    parameters: dict[str, float] | None = None,
    t: float = 0.0,
) -> NDArray[np.float64]:
    """Solve algebraic constraints to produce consistent initial conditions.
    Only processes equations where ``time_derivative_order == 0`` AND
    ``constraint_solver.enabled == True``. Returns a copy of y0 with
    constraint field slots overwritten by the solution (or y0 itself,
    uncopied, when there is nothing to solve).
    Parameters
    ----------
    spec : EquationSystem
        Parsed JSON equation specification.
    grid : GridInfo
        Spatial grid.
    y0 : NDArray[np.float64]
        Initial state vector (flat, from StateLayout).
    bc : str or tuple of str, optional
        Boundary conditions for spatial operators.
    parameters : dict[str, float], optional
        Runtime parameter overrides for symbolic coefficients.
    t : float
        Time at which to evaluate coefficients (usually t_span[0]).
    Returns
    -------
    NDArray[np.float64]
        Updated y0 with constraint fields solved.
    Warns
    -----
    UserWarning
        When the FFT solver encounters singular modes (null space of the
        operator) and regularizes by setting ``u_hat = 0`` at those modes.
        This is a numerical gauge choice (zero-mean). To disable
        automatic constraint solving for a field, set
        ``constraint_solver.enabled = false`` in the JSON spec.
    """
    # Imported locally (see noqa) — presumably to avoid a circular import
    # at module load time; confirm against the package layout.
    from tidal.solver.coefficients import CoefficientEvaluator  # noqa: PLC0415
    from tidal.solver.fields import FieldSet  # noqa: PLC0415
    from tidal.solver.state import StateLayout  # noqa: PLC0415
    # Find enabled constraint equations
    constraint_eqs = [
        (i, eq)
        for i, eq in enumerate(spec.equations)
        if eq.time_derivative_order == 0 and eq.constraint_solver.enabled
    ]
    if not constraint_eqs:
        # Nothing to solve — hand the caller's array back unchanged.
        return y0
    layout = StateLayout.from_spec(spec, grid.num_points)
    coeff_eval = CoefficientEvaluator(spec, grid, parameters)
    coeff_eval.begin_timestep(t)
    # Copy y0 BEFORE creating FieldSet — FieldSet wraps its backing array
    # with zero-copy views, so mutations via fields[name]=... propagate back.
    y0_out = y0.copy()
    fields = FieldSet.from_flat(layout, grid.shape, y0_out)
    # Build name map: JSON field references (e.g. "v_A_1") → FieldSet slot names
    name_map = _build_name_map(spec)
    # Classify terms for each constraint
    groups: list[_ConstraintTerms] = []
    for eq_idx, eq in constraint_eqs:
        terms = _classify_terms(
            eq_idx,
            eq.rhs_terms,
            eq.field_name,
            coeff_eval,
            t,
            eq.constraint_solver,
        )
        if not terms.self_terms:
            # The labeled field doesn't appear in this equation at all.
            # This is a compatibility constraint on *other* fields (e.g.,
            # the Fierz-Pauli Hamiltonian constraint constrains h_3/h_5
            # but doesn't involve h_0). Skip — nothing to solve here.
            continue
        groups.append(terms)
    # Decompose into connected components to avoid false coupling
    # (e.g., massive_gravity: h_1↔h_2 coupled, h_0 skipped)
    components = _find_connected_components(groups, name_map)
    for component in components:
        if len(component) == 1:
            _solve_independent(component, grid, bc, fields, layout, y0_out, name_map)
        else:
            _solve_coupled(component, grid, bc, fields, layout, y0_out, name_map)
    return y0_out
def _solve_independent(  # noqa: PLR0913, PLR0917
    groups: list[_ConstraintTerms],
    grid: GridInfo,
    bc: BCSpec | None,
    fields: FieldSet,
    layout: StateLayout,
    y0: NDArray[np.float64],
    name_map: dict[str, str] | None = None,
) -> None:
    """Solve each decoupled constraint in isolation, writing results in place.
    For every constraint: evaluate the source b, pick Tier 1 (FFT) or
    Tier 2 (probed sparse matrix), solve L[u] = -b, then write u into
    both the flat state vector slot and the FieldSet view.
    """
    n = grid.num_points
    for terms in groups:
        rhs = _evaluate_source(terms.source_terms, fields, grid, bc, name_map)
        if _select_method(terms, grid, bc) == "fft":
            solution = _fft_solve_single(terms, grid, rhs)
        else:
            operator = _probe_operator_matrix(terms.self_terms, grid, bc)
            solution = _matrix_solve(operator, rhs, grid.shape)
        slot = layout.field_slot_map[terms.field_name]
        y0[slot * n : (slot + 1) * n] = solution.ravel()
        fields[terms.field_name] = solution
def _solve_coupled(  # noqa: PLR0912, PLR0913, PLR0914, PLR0917, C901
    groups: list[_ConstraintTerms],
    grid: GridInfo,
    bc: BCSpec | None,
    fields: FieldSet,
    layout: StateLayout,
    y0: NDArray[np.float64],
    name_map: dict[str, str] | None = None,
) -> None:
    """Solve coupled constraints (e.g., coupled Proca A_0 ↔ B_0).
    Two paths: a block-FFT direct solve when every constraint in the
    component is FFT-eligible, otherwise Gauss-Seidel iteration that
    re-solves each constraint in turn until the largest field change
    drops below the tightest configured tolerance.
    """
    n = grid.num_points
    # Check if FFT path is available for ALL coupled constraints.
    # Must also verify that cross-constraint source operators have FFT
    # multipliers, since _select_method only checks self-terms.
    constraint_names = {g.field_name for g in groups}
    all_fft = all(_select_method(g, grid, bc) == "fft" for g in groups)
    if all_fft:
        for g in groups:
            for _, op_name, field_name in g.source_terms:
                if field_name in constraint_names and op_name not in _MULTIPLIERS:
                    all_fft = False
                    # Breaks only this inner loop; remaining groups are still
                    # scanned, but all_fft stays False — harmless.
                    break
    if all_fft:
        solutions = _fft_solve_coupled(groups, grid, fields, bc, name_map)
        for name, sol in solutions.items():
            slot_idx = layout.field_slot_map[name]
            y0[slot_idx * n : (slot_idx + 1) * n] = sol.ravel()
            fields[name] = sol
    else:
        # Gauss-Seidel iteration
        max_iter = max(g.config.max_iterations for g in groups)
        tol = min(g.config.tolerance for g in groups)
        # Pre-compute operator matrices and methods (loop-invariant)
        methods: dict[str, str] = {}
        op_matrices: dict[str, SparseMatrix] = {}
        for terms in groups:
            method = _select_method(terms, grid, bc)
            methods[terms.field_name] = method
            if method != "fft":
                op_matrices[terms.field_name] = _probe_operator_matrix(
                    terms.self_terms, grid, bc
                )
        for _iteration in range(max_iter):
            max_change = 0.0
            for terms in groups:
                # Re-evaluate with the freshest field values (Gauss-Seidel:
                # updates within a sweep are immediately visible).
                source = _evaluate_source(
                    terms.source_terms, fields, grid, bc, name_map
                )
                method = methods[terms.field_name]
                if method == "fft":
                    solution = _fft_solve_single(terms, grid, source)
                else:
                    op_mat = op_matrices[terms.field_name]
                    solution = _matrix_solve(op_mat, source, grid.shape)
                # Convergence is measured in the max norm vs. the prior sweep.
                old = (
                    fields[terms.field_name]
                    if terms.field_name in fields
                    else np.zeros(grid.shape)
                )
                change = float(np.max(np.abs(solution - old)))
                max_change = max(max_change, change)
                slot_idx = layout.field_slot_map[terms.field_name]
                y0[slot_idx * n : (slot_idx + 1) * n] = solution.ravel()
                fields[terms.field_name] = solution
            if max_change < tol:
                break
# ---------------------------------------------------------------------------
# Unified constraint IC solver
# ---------------------------------------------------------------------------
# Tolerance for considering a field "zero" (uninitialized)
_FIELD_ZERO_TOL = 1e-30
# Tolerance for constraint verification residual
_IC_RESIDUAL_TOL = 1e-10
# Maximum number of propagation iterations
_MAX_PROPAGATION_ITERATIONS = 20
def _is_field_zero(data: np.ndarray) -> bool:
"""Check if a field array is effectively zero (uninitialized)."""
return float(np.max(np.abs(data))) < _FIELD_ZERO_TOL
def _solve_for_target(
    terms: _ConstraintTerms,
    grid: GridInfo,
    bc: BCSpec | None,
    source_rhs: np.ndarray,
) -> np.ndarray:
    """Solve L[target] = -source_rhs for a single target field.
    Thin dispatcher over the Tier-1 (FFT) and Tier-2 (probed sparse
    matrix) solvers — same machinery as ``_solve_independent``.
    """
    if _select_method(terms, grid, bc) == "fft":
        return _fft_solve_single(terms, grid, source_rhs)
    operator = _probe_operator_matrix(terms.self_terms, grid, bc)
    return _matrix_solve(operator, source_rhs, grid.shape)
def _find_target_field(  # noqa: PLR0913, PLR0917
    eq_idx: int,
    rhs_terms: tuple[OperatorTerm, ...],
    eq_field_name: str,  # noqa: ARG001 — kept for caller consistency
    determined: set[str],
    coeff_eval: CoefficientEvaluator,
    t: float,
    config: ConstraintSolverConfig,
) -> tuple[str | None, _ConstraintTerms | None]:
    """Identify which field to solve for in a constraint equation.
    Returns ``(target_field_name, classified_terms)`` or ``(None, None)``
    when the constraint cannot be solved in this iteration.
    Strategy: among all referenced fields (ignoring acceleration/jerk
    operators), exactly ONE must be free (undetermined) — that field is
    the target, and it is solvable only if it appears under an operator
    (i.e. has self-terms after classification):
    - 0 free fields → all determined, verification only.
    - 1 free field  → solvable; handles standard constraints
      (``laplacian(A_0) + source = 0``), subsidiary constraints
      (``gradient(h_4) + gradient(h_7) = 0`` with h_7 free), and cascade
      ordering (wait until earlier fields are determined).
    - 2+ free fields → underdetermined, skip this iteration.
    """
    referenced = {
        term.field
        for term in rhs_terms
        if term.operator not in _ACCEL_AND_HIGHER_OPS
    }
    free = [name for name in referenced if name not in determined]
    if len(free) != 1:
        # Either nothing left to solve or too many unknowns for one equation.
        return None, None
    target = free[0]
    terms = _classify_terms(
        eq_idx,
        rhs_terms,
        target,
        coeff_eval,
        t,
        config,
    )
    # Solvable only if the target actually appears under an operator (L ≠ 0).
    if terms.self_terms:
        return target, terms
    return None, None
def ensure_consistent_ic( # noqa: PLR0912, PLR0913, PLR0914, PLR0915, C901
    spec: EquationSystem,
    grid: GridInfo,
    y0: NDArray[np.float64],
    *,
    bc: BCSpec | None = None,
    parameters: dict[str, float] | None = None,
    t: float = 0.0,
    strict: bool = True,
) -> NDArray[np.float64]:
    """Unified constraint IC solver.
    Given user-supplied initial conditions, iteratively solves ALL
    constraint equations (``time_derivative_order == 0``) to produce
    consistent initial data. Handles three cases uniformly:
    1. **Standard constraints** (field has self-terms): solve for the
       constraint field. Applies to all constraint equations, not just
       ``constraint_solver.enabled=True``.
    2. **Subsidiary constraints** (no self-terms, one free source field):
       solve for the free dynamical field. E.g., transverse gauge
       constraint ``gradient_x(h_4) + gradient_x(h_7) = 0`` determines
       ``h_7 = -h_4`` when ``h_4`` is user-initialized.
    3. **Verification** (all referenced fields determined): check that
       the constraint is satisfied; error or warn if not.
    The algorithm iterates until no more fields can be determined.
    Each iteration may unlock new solvable constraints (cascade).
    Parameters
    ----------
    spec : EquationSystem
        Parsed JSON equation specification.
    grid : GridInfo
        Spatial grid.
    y0 : NDArray[np.float64]
        Initial state vector (flat, from StateLayout).
    bc : str or tuple of str, optional
        Boundary conditions for spatial operators.
    parameters : dict[str, float], optional
        Runtime parameter overrides for symbolic coefficients.
    t : float
        Time at which to evaluate coefficients.
    strict : bool
        If True (default), raise ``ValueError`` when a constraint is
        violated and cannot be solved. If False, issue a warning.
    Returns
    -------
    NDArray[np.float64]
        Updated y0 with all solvable constraint fields determined.
    Raises
    ------
    ValueError
        If ``strict=True`` and any constraint is violated after solving.
    """
    # Local imports to avoid import cycles at module load time.
    from tidal.solver.coefficients import CoefficientEvaluator # noqa: PLC0415
    from tidal.solver.fields import FieldSet # noqa: PLC0415
    from tidal.solver.state import StateLayout # noqa: PLC0415
    # Collect ALL constraint equations
    constraint_eqs = [
        (i, eq) for i, eq in enumerate(spec.equations) if eq.time_derivative_order == 0
    ]
    if not constraint_eqs:
        # Nothing to solve or verify; return the input unchanged.
        return y0
    layout = StateLayout.from_spec(spec, grid.num_points)
    coeff_eval = CoefficientEvaluator(spec, grid, parameters)
    # Prime the evaluator so per-timestep caches are valid at time t.
    coeff_eval.begin_timestep(t)
    n = grid.num_points
    # Work on a copy: the caller's y0 is never mutated.
    y0_out = y0.copy()
    # NOTE(review): `fields` is built from y0_out; whether it holds views or
    # copies of the flat vector is not visible here — the code below writes
    # BOTH the y0_out slice and fields[...] defensively, so either works.
    fields = FieldSet.from_flat(layout, grid.shape, y0_out)
    name_map = _build_name_map(spec)
    # Track which fields are "determined" (non-zero in y0).
    # Heuristic: an all-zero field is treated as "not yet initialized" and
    # therefore free to be solved for; velocities are tracked with a "v_"
    # prefix to keep a single flat namespace.
    determined: set[str] = set()
    for name, slot_idx in layout.field_slot_map.items():
        data = y0_out[slot_idx * n : (slot_idx + 1) * n]
        if not _is_field_zero(data):
            determined.add(name)
    for name, slot_idx in layout.velocity_slot_map.items():
        data = y0_out[slot_idx * n : (slot_idx + 1) * n]
        if not _is_field_zero(data):
            determined.add(f"v_{name}")
    # Warn when constraint equations lack explicit constraint_solver config.
    # Without it, the constraint field won't be solved via Phase 1 (FFT /
    # Gauss-Seidel) and may stall in Phase 2 iterative propagation.
    unconfigured = [
        eq.field_name
        for _i, eq in constraint_eqs
        if not eq.constraint_solver.enabled
        and any(term.field == eq.field_name for term in eq.rhs_terms)
    ]
    if unconfigured:
        warnings.warn(
            f"Constraint equation(s) {unconfigured} have self-terms but no "
            f"constraint_solver config (enabled=false). These will NOT be "
            f"solved via Phase 1 (FFT/direct). Add [constraint_solver] to "
            f"theory.toml and re-derive, or manually add "
            f'"constraint_solver": {{"enabled": true}} to the JSON.',
            UserWarning,
            stacklevel=2,
        )
    # Phase 1: Handle enabled constraints with coupled detection
    # (preserves existing coupled-constraint behavior for EM, Chern-Simons)
    enabled_eqs = [(i, eq) for i, eq in constraint_eqs if eq.constraint_solver.enabled]
    if enabled_eqs:
        enabled_groups: list[_ConstraintTerms] = []
        for eq_idx, eq in enabled_eqs:
            terms = _classify_terms(
                eq_idx,
                eq.rhs_terms,
                eq.field_name,
                coeff_eval,
                t,
                eq.constraint_solver,
            )
            if not terms.self_terms:
                # The labeled field doesn't appear in this equation.
                # This is a compatibility constraint on other fields
                # (e.g., Fierz-Pauli Hamiltonian constraint). Skip —
                # nothing to solve. Phase 3 verification will check
                # whether the constraint is satisfied by the IC.
                continue
            enabled_groups.append(terms)
        # Decompose into connected components to avoid false coupling
        # (equations sharing no fields are solved independently).
        components = _find_connected_components(enabled_groups, name_map)
        for component in components:
            if len(component) == 1:
                _solve_independent(
                    component,
                    grid,
                    bc,
                    fields,
                    layout,
                    y0_out,
                    name_map,
                )
            else:
                _solve_coupled(
                    component,
                    grid,
                    bc,
                    fields,
                    layout,
                    y0_out,
                    name_map,
                )
        # Only groups that were actually appended (i.e. had self-terms and
        # were solved above) are marked determined; skipped compatibility
        # constraints are not in enabled_groups.
        determined.update(g.field_name for g in enabled_groups)
    # Phase 2: Iterative propagation for remaining constraints
    # Each pass may determine a new field, which can unlock further
    # constraints on the next pass (cascade). Bounded by
    # _MAX_PROPAGATION_ITERATIONS to guarantee termination.
    remaining_eqs = [
        (i, eq) for i, eq in constraint_eqs if not eq.constraint_solver.enabled
    ]
    for _iteration in range(_MAX_PROPAGATION_ITERATIONS):
        progress = False
        for eq_idx, eq in remaining_eqs:
            target, terms = _find_target_field(
                eq_idx,
                eq.rhs_terms,
                eq.field_name,
                determined,
                coeff_eval,
                t,
                eq.constraint_solver,
            )
            if target is None or terms is None:
                # Either zero or 2+ undetermined fields: not solvable yet.
                continue
            source = _evaluate_source(
                terms.source_terms,
                fields,
                grid,
                bc,
                name_map,
            )
            solution = _solve_for_target(terms, grid, bc, source)
            # Write solution into y0 and FieldSet
            if target in layout.field_slot_map:
                slot_idx = layout.field_slot_map[target]
                y0_out[slot_idx * n : (slot_idx + 1) * n] = solution.ravel()
                fields[target] = solution
            elif target.startswith("v_"):
                # Velocity target: strip the prefix to find the state slot.
                vel_field = target.removeprefix("v_")
                if vel_field in layout.velocity_slot_map:
                    slot_idx = layout.velocity_slot_map[vel_field]
                    y0_out[slot_idx * n : (slot_idx + 1) * n] = solution.ravel()
                    fields[target] = solution
            # NOTE(review): if `target` matches neither branch above (or the
            # v_ field lacks a velocity slot), nothing is written yet the
            # target is still marked determined and a "solved" warning is
            # emitted below. Presumably _find_target_field only returns
            # targets with valid slots — confirm against its contract.
            determined.add(target)
            progress = True
            eq_desc = eq.field_name
            if target != eq.field_name:
                # Solved a field other than the labeled one (case 2 in the
                # docstring): flag it as a subsidiary solve in the warning.
                eq_desc = f"{eq.field_name} (subsidiary)"
            warnings.warn(
                f"Constraint IC propagation: solved '{target}' "
                f"from {eq_desc} constraint equation.",
                UserWarning,
                stacklevel=2,
            )
        if not progress:
            # Fixed point reached: no constraint became solvable this pass.
            break
    # Phase 2 stall diagnostics: when propagation stalls with unsatisfied
    # constraints, identify which fields are undetermined to guide the user.
    stall_free_fields: dict[str, set[str]] = {}
    for _eq_idx, eq in remaining_eqs:
        mapped_refs: set[str] = set()
        for term in eq.rhs_terms:
            if term.operator in _ACCEL_AND_HIGHER_OPS:
                # Acceleration/jerk terms cannot constrain initial data here.
                continue
            ref = term.field
            if term.operator == "first_derivative_t":
                # A first time derivative references the velocity slot.
                ref = f"v_{ref}"
            if name_map and ref in name_map:
                ref = name_map[ref]
            mapped_refs.add(ref)
        free = sorted(f for f in mapped_refs if f not in determined)
        if len(free) >= 2: # noqa: PLR2004
            # 2+ undetermined fields means this equation is underdetermined
            # on its own — exactly the condition that stalls propagation.
            stall_free_fields[eq.field_name] = set(free)
    if stall_free_fields:
        # Count how often each free field appears across stuck constraints
        field_counts: dict[str, int] = {}
        for free_set in stall_free_fields.values():
            for f in free_set:
                field_counts[f] = field_counts.get(f, 0) + 1
        lines = [
            "Constraint propagation stalled — these constraints have "
            "multiple undetermined fields:"
        ]
        for eq_name, free_set in sorted(stall_free_fields.items()):
            lines.append(
                f"  {eq_name}: {len(free_set)} undetermined "
                f"[{', '.join(sorted(free_set))}]"
            )
        # Highlight high-impact fields (appear in multiple constraints)
        shared = {f: c for f, c in field_counts.items() if c > 1}
        if shared:
            top = cast(
                "list[str]",
                sorted(shared, key=shared.get, reverse=True), # type: ignore[arg-type]
            )
            hints = [f"{f} (in {shared[f]} constraints)" for f in top[:3]]
            lines.append(
                f"Setting IC for: {', '.join(hints)} may resolve multiple constraints."
            )
        warnings.warn("\n".join(lines), UserWarning, stacklevel=2)
    # Phase 3: Final verification of ALL constraint equations
    # (including those handled in Phases 1-2; residuals are re-evaluated
    # from the updated fields.)
    violations: list[tuple[str, float, list[str]]] = []
    for eq_idx, eq in constraint_eqs:
        # Skip verification for equations containing acceleration/jerk
        # operators — the residual is incomplete without those terms.
        # (The genexp variable `t` shadows the time parameter only inside
        # this expression; the outer `t` is unaffected.)
        if any(t.operator in _ACCEL_AND_HIGHER_OPS for t in eq.rhs_terms):
            continue
        rhs = np.zeros(grid.shape)
        for term_idx, term in enumerate(eq.rhs_terms):
            coeff = coeff_eval.resolve(
                term,
                t,
                eq_idx=eq_idx,
                term_idx=term_idx,
            )
            # first_derivative_t(X) → identity(v_X)
            op_name = term.operator
            field_ref = term.field
            if op_name == "first_derivative_t":
                op_name = "identity"
                field_ref = f"v_{field_ref}"
            if name_map and field_ref in name_map:
                field_ref = name_map[field_ref]
            # Fields absent from the FieldSet contribute zero to the residual.
            data = fields[field_ref] if field_ref in fields else np.zeros(grid.shape)
            operated = apply_operator(op_name, data, grid, bc)
            rhs += coeff * operated
        # Constraint residual: the equation RHS should be ~0 everywhere.
        max_res = float(np.max(np.abs(rhs)))
        if max_res > _IC_RESIDUAL_TOL:
            involved = sorted({term.field for term in eq.rhs_terms})
            violations.append((eq.field_name, max_res, involved))
    if violations:
        lines = ["Initial data does not satisfy constraint equation(s):"]
        for field_name, max_res, involved in violations:
            free_info = ""
            if field_name in stall_free_fields:
                # Cross-reference the stall diagnostics so the user sees
                # WHICH fields need initial data, not just the residual.
                free = sorted(stall_free_fields[field_name])
                free_info = f" — undetermined: [{', '.join(free)}]"
            lines.append(
                f"  {field_name}: max|residual| = {max_res:.2e} "
                f"(involves [{', '.join(involved)}]){free_info}"
            )
        lines.append(
            "For physical consistency, choose initial conditions that "
            "jointly satisfy all constraint equations, or use "
            "--allow-inconsistent-ic to proceed with a warning."
        )
        msg = "\n".join(lines)
        if strict:
            raise ValueError(msg)
        # Non-strict mode: report the violation but return the (partially
        # solved) initial data anyway.
        warnings.warn(msg, UserWarning, stacklevel=2)
    return y0_out