# Source code for tidal.symbolic.reduction

"""Plane-wave dimensional reduction for JSON equation specs.

After the Wolfram pipeline exports a JSON spec with the original higher-dimensional
metadata (dimension, coordinates, operator names), this module transforms it into a
clean reduced-dimension spec suitable for efficient simulation.

For example, a 3+1D theory reduced along z produces a 1+1D spec with:
- ``dimension: 2``, ``coordinates: ["t", "x"]``, ``signature: [-1, 1]``
- Operators remapped: ``laplacian_z → laplacian_x``, ``gradient_z → gradient_x``
- Killed-axis operators (``laplacian_x``, ``gradient_y``, etc.) removed
- ``coordinate_dependent`` arrays and ``coefficient_symbolic`` strings updated
- Provenance metadata recording the reduction

Curved-coordinate support:
- Coefficients depending only on the surviving coordinate are preserved and remapped
- Volume element is kept if it depends only on the surviving coordinate
- If any surviving coefficient or volume element references a killed coordinate,
  ``ValueError`` is raised (incompatible reduction)
"""

from __future__ import annotations

import copy
import re
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
    from collections.abc import Mapping

# ---------------------------------------------------------------------------
# Plane-wave dimensional reduction
# ---------------------------------------------------------------------------

# Zero-argument coordinate reference in a Wolfram expression string, e.g.
# x[], y[] or z[]. Only the single-letter names x, y, z, w, v, u and t are
# recognised; whitespace is tolerated before and inside the brackets, and the
# coordinate letter is captured as group 1.
_COORD_REF_RE = re.compile(r"\b([xyzwvut])\s*\[\s*\]")

# Name prefixes that identify axis-specific operators (the axis name follows
# the prefix, as in "laplacian_z" or "cross_derivative_xy").
_AXIS_OP_PREFIXES = ("laplacian_", "gradient_", "cross_derivative_")


def _build_operator_remap(
    propagation_axis: str,
    spatial_coords: list[str],
) -> dict[str, str | None]:
    """Build a mapping from original operator names to reduced operator names.

    The propagation axis operators are remapped to the first spatial axis ("x").
    Killed-axis operators map to ``None`` (remove the term).
    Dimension-agnostic operators (``identity``, ``laplacian``, ``first_derivative_t``,
    ``biharmonic``) are preserved.

    Parameters
    ----------
    propagation_axis : str
        The surviving spatial axis (e.g., "z").
    spatial_coords : list[str]
        All spatial coordinates (e.g., ["x", "y", "z"]).

    Returns
    -------
    dict[str, str | None]
        Mapping from original operator name to new name (or None to remove).
    """
    remap: dict[str, str | None] = {}

    for axis in spatial_coords:
        if axis == propagation_axis:
            remap[f"laplacian_{axis}"] = "laplacian_x"
            remap[f"gradient_{axis}"] = "gradient_x"
        else:
            remap[f"laplacian_{axis}"] = None
            remap[f"gradient_{axis}"] = None

    # All cross derivatives involve at least one killed axis → remove
    for i, a in enumerate(spatial_coords):
        for b in spatial_coords[i + 1 :]:
            remap[f"cross_derivative_{a}{b}"] = None

    # Dimension-agnostic operators preserved as-is
    remap.update(
        {
            op: op
            for op in (
                "identity",
                "laplacian",
                "first_derivative_t",
                "biharmonic",
                "time_derivative",
            )
        }
    )

    return remap


def _remap_coord_string(expr: str, propagation_axis: str) -> str:
    """Remap coordinate references in a Wolfram expression string.

    Replaces ``{prop_axis}[]`` with ``x[]`` if the propagation axis
    is not already "x".

    Parameters
    ----------
    expr : str
        Wolfram expression string (e.g., ``"2/z[]"``).
    propagation_axis : str
        The surviving spatial axis.

    Returns
    -------
    str
        Expression with coordinate references remapped.
    """
    if propagation_axis == "x":
        return expr
    return re.sub(
        rf"\b{re.escape(propagation_axis)}\s*\[\s*\]",
        "x[]",
        expr,
    )


def _check_killed_coord_refs(
    expr: str,
    killed: set[str],
    context: str,
) -> None:
    """Raise if expression references any killed coordinate.

    Parameters
    ----------
    expr : str
        Wolfram expression string to check.
    killed : set[str]
        Set of killed coordinate names.
    context : str
        Description of where this expression appears (for error messages).

    Raises
    ------
    ValueError
        If ``expr`` references any coordinate in ``killed``.
    """
    for match in _COORD_REF_RE.finditer(expr):
        coord = match.group(1)
        if coord in killed:
            msg = (
                f"Incompatible plane-wave reduction: {context} "
                f"references killed coordinate '{coord}[]' in expression: {expr}"
            )
            raise ValueError(msg)


def _remap_operator(op: str, remap: Mapping[str, str | None]) -> str | None:
    """Look up an operator in the remap table.

    Returns the new name, ``None`` (remove term), or the original name
    if not in the table (passthrough for unknown operators).
    """
    if op in remap:
        return remap[op]
    for prefix in _AXIS_OP_PREFIXES:
        if op.startswith(prefix):
            return None
    return op


def _remap_coord_deps(
    deps: list[str],
    propagation_axis: str,
    killed: set[str],
    context: str,
) -> list[str]:
    """Remap a ``coordinate_dependent`` array.

    Raises
    ------
    ValueError
        If any entry in ``deps`` is a killed coordinate.
    """
    for dep in deps:
        if dep in killed:
            msg = (
                f"Incompatible plane-wave reduction: {context} "
                f"has coordinate_dependent entry '{dep}' which is "
                f"a killed coordinate"
            )
            raise ValueError(msg)
    return [
        "x" if (dep == propagation_axis and propagation_axis != "x") else dep
        for dep in deps
    ]


def _transform_term(
    term: dict[str, Any],
    remap: dict[str, str | None],
    propagation_axis: str,
    killed: set[str],
) -> dict[str, Any] | None:
    """Reduce one RHS term; ``None`` means the term is dropped.

    The term is dropped when its operator remaps to ``None``. Otherwise a
    deep copy is returned with the operator renamed and, when present, its
    ``coefficient_symbolic`` string and ``coordinate_dependent`` list checked
    for killed coordinates and remapped onto ``x``. The input ``term`` is
    left untouched.
    """
    original_op = term.get("operator", "")
    mapped_op = _remap_operator(original_op, remap)
    if mapped_op is None:
        return None

    # Error messages refer to the operator name as it appeared in the input.
    where = f"term operator={original_op}"

    reduced = copy.deepcopy(term)
    reduced["operator"] = mapped_op

    if "coefficient_symbolic" in reduced:
        coefficient = reduced["coefficient_symbolic"]
        _check_killed_coord_refs(coefficient, killed, where)
        reduced["coefficient_symbolic"] = _remap_coord_string(
            coefficient,
            propagation_axis,
        )

    if "coordinate_dependent" in reduced:
        reduced["coordinate_dependent"] = _remap_coord_deps(
            reduced["coordinate_dependent"],
            propagation_axis,
            killed,
            where,
        )

    return reduced


def _transform_hamiltonian_term(
    term: dict[str, Any],
    remap: dict[str, str | None],
    propagation_axis: str,
    killed: set[str],
) -> dict[str, Any] | None:
    """Reduce one Hamiltonian bilinear term; ``None`` means it is dropped.

    The term is dropped as soon as the operator of either ``factor_a`` or
    ``factor_b`` remaps to ``None``. Otherwise a deep copy is returned with
    both factor operators renamed and, when present, the term's
    ``coefficient_symbolic`` and ``coordinate_dependent`` entries checked for
    killed coordinates and remapped onto ``x``.
    """
    factor_keys = ("factor_a", "factor_b")

    # Decide whether the term survives before paying for a deep copy.
    if any(
        _remap_operator(term.get(key, {}).get("operator", ""), remap) is None
        for key in factor_keys
    ):
        return None

    reduced = copy.deepcopy(term)

    for key in factor_keys:
        side = reduced[key]
        side["operator"] = _remap_operator(side["operator"], remap)

    if "coefficient_symbolic" in reduced:
        coefficient = reduced["coefficient_symbolic"]
        _check_killed_coord_refs(coefficient, killed, "hamiltonian_term coefficient")
        reduced["coefficient_symbolic"] = _remap_coord_string(
            coefficient,
            propagation_axis,
        )

    if "coordinate_dependent" in reduced:
        reduced["coordinate_dependent"] = _remap_coord_deps(
            reduced["coordinate_dependent"],
            propagation_axis,
            killed,
            "hamiltonian_term",
        )

    return reduced


def _reduce_equations(
    result: dict[str, Any],
    remap: dict[str, str | None],
    propagation_axis: str,
    killed: set[str],
) -> tuple[set[str], list[str]]:
    """Transform equations and return (surviving_fields, eliminated_fields)."""
    original_fields: set[str] = {eq["field"] for eq in result.get("equations", [])}
    new_equations: list[dict[str, Any]] = []

    for eq in result.get("equations", []):
        terms = eq.get("rhs", {}).get("terms", [])
        new_terms = [
            t
            for term in terms
            if (t := _transform_term(term, remap, propagation_axis, killed)) is not None
        ]
        if not new_terms:
            continue
        new_eq = copy.deepcopy(eq)
        new_eq["rhs"]["terms"] = new_terms
        new_equations.append(new_eq)

    surviving: set[str] = {eq["field"] for eq in new_equations}
    result["equations"] = new_equations

    # Filter and reindex fields
    result["fields"] = [f for f in result.get("fields", []) if f["name"] in surviving]
    for idx, field in enumerate(result["fields"]):
        field["index"] = idx

    return surviving, sorted(original_fields - surviving)


def _reduce_hamiltonian(
    canonical: dict[str, Any],
    surviving_fields: set[str],
    remap: dict[str, str | None],
    propagation_axis: str,
    killed: set[str],
) -> None:
    """Filter and remap hamiltonian_terms in-place."""
    if "hamiltonian_terms" not in canonical:
        return
    new_terms: list[dict[str, Any]] = []
    for hterm in canonical["hamiltonian_terms"]:
        fa_field = hterm.get("factor_a", {}).get("field", "")
        fb_field = hterm.get("factor_b", {}).get("field", "")
        if fa_field not in surviving_fields or fb_field not in surviving_fields:
            continue
        transformed = _transform_hamiltonian_term(
            hterm,
            remap,
            propagation_axis,
            killed,
        )
        if transformed is not None:
            new_terms.append(transformed)
    canonical["hamiltonian_terms"] = new_terms


def _reduce_volume_element(
    canonical: dict[str, Any],
    propagation_axis: str,
    killed: set[str],
) -> None:
    """Handle volume element: remap or error if incompatible.

    Raises
    ------
    ValueError
        If volume element references killed coordinates.
    """
    if "volume_element" not in canonical:
        return

    vol_expr: str = canonical["volume_element"]
    for match in _COORD_REF_RE.finditer(vol_expr):
        if match.group(1) in killed:
            msg = (
                f"Volume element '{vol_expr}' references killed coordinate(s). "
                f"The Wolfram pipeline's factored volume element computation "
                f"should have handled this. The reduction may be incompatible "
                f"with this coordinate system for energy measurement."
            )
            raise ValueError(msg)

    new_vol = _remap_coord_string(vol_expr, propagation_axis)
    if new_vol.strip() in {"1", "1.", "1.0"}:
        del canonical["volume_element"]
    else:
        canonical["volume_element"] = new_vol


def _filter_coupling_matrices(
    result: dict[str, Any],
    surviving_indices: list[int],
) -> None:
    """Remove rows/columns for eliminated fields from coupling matrices.

    Parameters
    ----------
    result : dict
        Spec dict containing a ``"coupling"`` section.
    surviving_indices : list[int]
        Original field indices that survive (in order).
    """
    coupling: dict[str, Any] = result.get("coupling", {})
    if not coupling:
        return

    for matrix_key in (
        "mass_matrix",
        "coupling_matrix",
        "mass_matrix_symbolic",
        "coupling_matrix_symbolic",
    ):
        if matrix_key not in coupling:
            continue
        raw = coupling[matrix_key]
        if not isinstance(raw, list) or not raw:
            continue
        matrix = cast("list[list[Any]]", raw)
        new_matrix: list[list[Any]] = []
        for i in surviving_indices:
            if i < len(matrix):
                row = matrix[i]
                new_row: list[Any] = [row[j] for j in surviving_indices if j < len(row)]
                new_matrix.append(new_row)
        coupling[matrix_key] = new_matrix


def reduce_spec(
    spec_data: dict[str, Any],
    reduction_config: dict[str, Any],
) -> dict[str, Any]:
    """Apply plane-wave dimensional reduction to a JSON equation spec.

    Transforms a higher-dimensional spec into a clean 1+1D spec by:

    1. Remapping operator names (``laplacian_z → laplacian_x``)
    2. Removing terms with killed-axis operators
    3. Updating spacetime metadata (dimension, signature, coordinates)
    4. Remapping coordinate references in coefficient expressions
    5. Handling volume element (keep if surviving-coord-only, error if not)
    6. Adding provenance metadata

    The work is done on a deep copy; ``spec_data`` itself is not modified.

    Parameters
    ----------
    spec_data : dict
        Parsed JSON spec (as loaded by ``json.loads``).
    reduction_config : dict
        The ``[reduction]`` section from the TOML config. The keys
        ``"propagation_axis"`` and ``"type"`` are read.

    Returns
    -------
    dict
        Transformed spec with reduced dimension.

    Raises
    ------
    ValueError
        If the propagation axis is not one of the spec's spatial
        coordinates, or if a surviving coefficient, ``coordinate_dependent``
        entry or the volume element references a killed coordinate.
    """
    propagation_axis: str = reduction_config["propagation_axis"]

    result = copy.deepcopy(spec_data)
    spacetime: dict[str, Any] = result["spacetime"]
    coords: list[str] = spacetime["coordinates"]
    spatial_coords = [c for c in coords if c != "t"]

    # Without this check a mistyped axis would silently kill every
    # axis-specific term instead of reporting the configuration error.
    if propagation_axis not in spatial_coords:
        msg = (
            f"Cannot reduce along propagation axis {propagation_axis!r}: "
            f"it is not one of the spatial coordinates {spatial_coords}"
        )
        raise ValueError(msg)

    killed = {c for c in spatial_coords if c != propagation_axis}
    remap = _build_operator_remap(propagation_axis, spatial_coords)

    surviving_fields, eliminated_fields = _reduce_equations(
        result,
        remap,
        propagation_axis,
        killed,
    )

    canonical: dict[str, Any] = result.get("canonical", {})
    _reduce_hamiltonian(canonical, surviving_fields, remap, propagation_axis, killed)
    _reduce_volume_element(canonical, propagation_axis, killed)

    # The reduced spec always lives in 1+1 dimensions on (t, x).
    spacetime["dimension"] = 2
    spacetime["signature"] = [-1, 1]
    spacetime["coordinates"] = ["t", "x"]

    # Matrix rows/columns are selected by the position of each surviving
    # field's equation in the ORIGINAL spec.
    surviving_indices = [
        i
        for i, eq in enumerate(spec_data.get("equations", []))
        if str(eq.get("field", "")) in surviving_fields
    ]
    _filter_coupling_matrices(result, surviving_indices)

    # Provenance: record what was reduced and which fields were lost.
    metadata: dict[str, Any] = result.get("metadata", {})
    metadata["reduction"] = {
        "type": reduction_config["type"],
        "original_dimension": spec_data["spacetime"]["dimension"],
        "propagation_axis": propagation_axis,
        "eliminated_fields": eliminated_fields,
    }
    result["metadata"] = metadata

    return result