Source code for tidal.solver.grid

"""GridInfo — lightweight Cartesian grid descriptor for TIDAL.

Replaces py-pde's CartesianGrid with a minimal frozen dataclass that provides
only what TIDAL needs: grid spacing for FD stencils, cell coordinates for
position-dependent coefficient evaluation, and periodicity flags for BC dispatch.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from tidal.solver.operators import AxisBCSpec


[docs] @dataclass(frozen=True) class GridInfo: """Immutable Cartesian grid descriptor. Parameters ---------- bounds : tuple of (lo, hi) pairs Domain bounds per spatial axis, e.g. ``((0, 10), (0, 10))``. shape : tuple of int Number of grid cells per axis, e.g. ``(64, 64)``. periodic : tuple of bool Whether each axis is periodic, e.g. ``(True, True)``. bc : tuple of str, optional Legacy per-axis BC type strings (e.g. ``("neumann", "periodic")``). axis_bcs : tuple of AxisBCSpec, optional Structured per-axis BC specs with per-side values and Robin support. Takes precedence over ``bc`` when set. Examples -------- >>> g = GridInfo(bounds=((0, 10),), shape=(64,), periodic=(True,)) >>> g.ndim 1 >>> g.dx (0.15625,) >>> g.num_points 64 """ bounds: tuple[tuple[float, float], ...] shape: tuple[int, ...] periodic: tuple[bool, ...] bc: tuple[str, ...] | None = None axis_bcs: tuple[AxisBCSpec, ...] | None = None _VALID_BC = frozenset({"periodic", "neumann", "dirichlet", "robin"}) def __post_init__(self) -> None: # noqa: C901, PLR0912 if len(self.bounds) != len(self.shape): msg = ( f"bounds has {len(self.bounds)} axes but shape has " f"{len(self.shape)} — must match" ) raise ValueError(msg) if len(self.bounds) != len(self.periodic): msg = ( f"bounds has {len(self.bounds)} axes but periodic has " f"{len(self.periodic)} — must match" ) raise ValueError(msg) for i, (lo, hi) in enumerate(self.bounds): if hi <= lo: msg = f"bounds[{i}] = ({lo}, {hi}): upper must exceed lower" raise ValueError(msg) for i, n in enumerate(self.shape): if n < 2: # noqa: PLR2004 msg = f"shape[{i}] = {n}: need at least 2 cells per axis" raise ValueError(msg) if self.bc is not None: if len(self.bc) != len(self.bounds): msg = ( f"bc has {len(self.bc)} entries but grid has " f"{len(self.bounds)} axes — must match" ) raise ValueError(msg) for i, b in enumerate(self.bc): if b not in self._VALID_BC: msg = f"bc[{i}] = {b!r}: must be one of {sorted(self._VALID_BC)}" raise ValueError(msg) is_periodic_bc = b == "periodic" if is_periodic_bc != self.periodic[i]: msg = ( f"bc[{i}] = {b!r} but periodic[{i}] = {self.periodic[i]} " f"— these must be consistent" ) raise ValueError(msg) if self.axis_bcs is not None: if len(self.axis_bcs) != len(self.bounds): msg = ( f"axis_bcs has {len(self.axis_bcs)} entries but grid has " f"{len(self.bounds)} axes — must match" ) raise ValueError(msg) for i, abc in enumerate(self.axis_bcs): if abc.periodic != self.periodic[i]: msg = ( f"axis_bcs[{i}].periodic = {abc.periodic} but " f"periodic[{i}] = {self.periodic[i]} — must be consistent" ) raise ValueError(msg) @cached_property def ndim(self) -> int: """Number of spatial dimensions.""" return len(self.shape) @cached_property def num_points(self) -> int: """Total number of grid cells (product of shape).""" return math.prod(self.shape) @cached_property def effective_bc(self) -> tuple[str, ...] | tuple[AxisBCSpec, ...]: """Per-axis BC specs. Returns ``AxisBCSpec`` tuple if ``axis_bcs`` is set (structured BCs with per-side values), otherwise falls back to string tuple from ``bc`` or infers from ``periodic``. """ if self.axis_bcs is not None: return self.axis_bcs if self.bc is not None: return self.bc return tuple("periodic" if p else "neumann" for p in self.periodic) @cached_property def bc_types(self) -> tuple[str, ...]: """Per-axis BC type as simple strings. Always returns ``tuple[str, ...]`` — one of ``"periodic"``, ``"neumann"``, ``"dirichlet"``, or ``"robin"`` per axis. For structured ``AxisBCSpec`` entries, uses the low-side kind as the representative type (matching the solver's ghost cell convention). This is the canonical form stored in metadata and used by the energy module for BC-aware gradient computation. """ if self.axis_bcs is not None: result: list[str] = [] for abc in self.axis_bcs: if abc.periodic: result.append("periodic") elif abc.low is not None: result.append(abc.low.kind) else: result.append("neumann") return tuple(result) if self.bc is not None: return self.bc return tuple("periodic" if p else "neumann" for p in self.periodic) @cached_property def dx(self) -> tuple[float, ...]: """Grid spacing per axis (cell-centred: dx = (hi - lo) / N).""" return tuple( (hi - lo) / n for (lo, hi), n in zip(self.bounds, self.shape, strict=False) )
[docs] def axes_coords(self, axis: int) -> np.ndarray: """1-D array of cell centres along *axis*. Returns shape ``(shape[axis],)`` — a single 1-D vector, not a broadcasted grid. Useful for building per-axis coordinate arrays without the overhead of full meshgrid. Raises ------ IndexError If *axis* is out of range. """ if axis < 0 or axis >= self.ndim: msg = f"axis {axis} out of range for {self.ndim}-D grid" raise IndexError(msg) lo, hi = self.bounds[axis] n = self.shape[axis] d = (hi - lo) / n return np.linspace(lo + d / 2, hi - d / 2, n)
@cached_property def cell_coords(self) -> np.ndarray: """Cell-centre coordinates, shape ``(*shape, ndim)``. For a 1D grid with bounds (0, 10) and shape (4,), the cell centres are at [1.25, 3.75, 6.25, 8.75] (centred in each cell). Unlike py-pde's ``CartesianGrid.cell_coords`` (which returns a list of arrays with inconsistent shapes), this always returns a single ndarray with the coordinate dimension as the last axis. Cached on first access (GridInfo is immutable). """ grids = np.meshgrid( *(self.axes_coords(i) for i in range(self.ndim)), indexing="ij", ) return np.stack(grids, axis=-1)
[docs] def coord_arrays(self) -> tuple[np.ndarray, ...]: """Broadcasted coordinate arrays, one per axis. Returns ``ndim`` arrays, each of shape ``self.shape``, containing the coordinate value along that axis at every grid point. Equivalent to ``np.meshgrid(*1d_axes, indexing="ij")``. Cached on first access (GridInfo is immutable). """ try: return self._cached_coord_arrays # type: ignore[attr-defined] except AttributeError: result = tuple( np.meshgrid( *(self.axes_coords(i) for i in range(self.ndim)), indexing="ij", ) ) object.__setattr__(self, "_cached_coord_arrays", result) # noqa: PLC2801 return result