# Source code for tidal.measurement._dispersion

"""Dispersion relation extraction from simulation data.

Extracts the dispersion relation omega(k) by performing a temporal FFT on
the spatial Fourier coefficients phi_hat(k, t) for each spatial mode.  The
resulting 2D power spectrum S(k, omega) reveals how energy is distributed
across wavenumber and frequency, with peak detection extracting the dominant
frequency per k-bin to form the dispersion curve.

The key insight is to use complex spatial FFT coefficients (not amplitudes
|phi_hat|) to avoid frequency doubling artifacts.  The temporal FFT of the
complex coefficient phi_hat(k, t) correctly identifies the oscillation
frequency omega(k).

For a group of fields the group spectral power is summed:

    S_group(k, omega) = sum_i |FFT(phi_i)(k, omega)|^2

This is rotationally covariant within the field group and reveals the actual
physical mode structure without depending on a particular component choice.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np

from tidal.measurement._spectral import (
    _build_k_grid,  # pyright: ignore[reportPrivateUsage]
    _radial_bin,  # pyright: ignore[reportPrivateUsage]
)
from tidal.measurement._utils import (
    _check_no_position_dependent_terms,  # pyright: ignore[reportPrivateUsage]
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from numpy.typing import NDArray

    from tidal.measurement._io import SimulationData


# Minimum number of time snapshots required before dispersion extraction is
# attempted (enforced by ``_validate_timestep``).
_MIN_SNAPSHOTS = 3


@dataclass(frozen=True)
class DispersionResult:
    """Immutable container for a dispersion relation measured from simulation output.

    Attributes
    ----------
    wavenumbers : ndarray, shape ``(n_modes,)``
        Radially binned wavenumber magnitudes ``|k|``.
    frequencies : ndarray, shape ``(n_freq,)``
        Angular frequencies ``omega`` (rad/time); the DC bin is not included.
    power : ndarray, shape ``(n_modes, n_freq)``
        Spectral power ``S(k, omega) = sum_i |A_hat_i(k, omega)|^2``, summed
        over every field in the group.
    peak_frequencies : ndarray, shape ``(n_modes,)``
        Angular frequency of the strongest spectral peak in each k-bin;
        ``0.0`` for modes judged inactive.
    peak_powers : ndarray, shape ``(n_modes,)``
        Spectral power at that peak frequency for each k-bin.
    field_name : str
        The field name, or the comma-joined names of the field group, that
        this dispersion relation describes.
    rayleigh_resolution : float
        ``2*pi/T`` -- the smallest frequency separation that can be resolved.
    """

    # Binned |k| value of each radial shell.
    wavenumbers: NDArray[np.float64]
    # Angular frequency axis of ``power``.
    frequencies: NDArray[np.float64]
    # 2-D spectrum indexed as [k-bin, frequency-bin].
    power: NDArray[np.float64]
    # Dominant omega per k-bin (0.0 when the mode is inactive).
    peak_frequencies: NDArray[np.float64]
    # Spectral power at the dominant omega per k-bin.
    peak_powers: NDArray[np.float64]
    # Field (or comma-joined group) the result belongs to.
    field_name: str
    # Frequency resolution limit 2*pi/T.
    rayleigh_resolution: float


def _validate_timestep(data: SimulationData) -> None: """Check snapshot count and uniform timestep. Raises ------ ValueError If fewer than 3 snapshots or non-uniform timestep. """ if data.n_snapshots < _MIN_SNAPSHOTS: msg = ( f"Need at least {_MIN_SNAPSHOTS} snapshots for dispersion " f"extraction, got {data.n_snapshots}" ) raise ValueError(msg) times = data.times dt = float(times[1] - times[0]) diffs = np.diff(times) if not np.allclose(diffs, dt, rtol=1e-6): msg = ( "Non-uniform timestep -- temporal FFT requires uniform sampling " f"(dt range: {float(diffs.min()):.6g} to {float(diffs.max()):.6g})" ) raise ValueError(msg) def _spacetime_fft( field_snapshots: NDArray[np.float64], dt: float, ) -> tuple[NDArray[np.float64], NDArray[np.complex128], NDArray[np.float64]]: """Spatial + temporal FFT pipeline. Returns (angular_freqs, spatial_fft, spacetime_power) where spacetime_power has shape ``(n_freq, *rfft_shape)`` and spatial_fft has shape ``(n_snapshots, *rfft_shape)``. """ n_snapshots = field_snapshots.shape[0] # Spatial rfftn per snapshot -> complex coefficients spatial_fft = np.stack( [np.fft.rfftn(field_snapshots[t]) for t in range(n_snapshots)] ) # Temporal fft (complex input -> must use fft, not rfft). # Keep only strictly positive frequencies (skip DC and negative Nyquist). 
full_fft = np.fft.fft(spatial_fft, axis=0) n_pos = n_snapshots // 2 spacetime_power = np.abs(full_fft[1:n_pos]) ** 2 # Angular frequencies for the positive temporal bins raw_t_freqs = np.fft.fftfreq(n_snapshots, d=dt) angular_freqs = np.asarray( 2.0 * np.pi * raw_t_freqs[1:n_pos], dtype=np.float64, ) return angular_freqs, spatial_fft, spacetime_power def _bin_and_detect( # noqa: PLR0913, PLR0917 angular_freqs: NDArray[np.float64], spatial_amp: NDArray[np.float64], spacetime_power: NDArray[np.float64], grid_shape: tuple[int, ...], grid_spacing: tuple[float, ...], min_amplitude: float, ) -> tuple[ NDArray[np.float64], NDArray[np.float64], NDArray[np.float64], NDArray[np.float64] ]: """Radially bin S(k, omega) and detect peaks. Parameters ---------- spatial_amp : ndarray, shape ``(n_snapshots, *rfft_shape)`` Spatial amplitude array (absolute values, already real). For a single field this is ``|spatial_fft|``; for a group it is the element-wise sum of ``|spatial_fft_i|`` over all fields. Returns (wavenumbers, power, peak_frequencies, peak_powers). 
""" _k_grid, k_mag = _build_k_grid(grid_shape, grid_spacing) # Establish bin structure wn_ref, _ = _radial_bin(k_mag, np.zeros_like(k_mag), grid_spacing, grid_shape) n_modes = len(wn_ref) n_freq = len(angular_freqs) power = np.zeros((n_modes, n_freq), dtype=np.float64) for fi in range(n_freq): _, binned = _radial_bin(k_mag, spacetime_power[fi], grid_spacing, grid_shape) power[: len(binned), fi] = binned # Max spatial amplitude per k-bin for activity detection max_amp = np.zeros(n_modes, dtype=np.float64) for t in range(spatial_amp.shape[0]): _, binned_amp = _radial_bin(k_mag, spatial_amp[t], grid_spacing, grid_shape) max_amp = np.maximum(max_amp, binned_amp[:n_modes]) # Peak detection per k-bin peak_frequencies = np.zeros(n_modes, dtype=np.float64) peak_powers = np.zeros(n_modes, dtype=np.float64) for ki in range(n_modes): if max_amp[ki] < min_amplitude: continue mode_power = power[ki, :] if np.max(mode_power) > 0.0: peak_idx = int(np.argmax(mode_power)) peak_frequencies[ki] = angular_freqs[peak_idx] peak_powers[ki] = mode_power[peak_idx] return wn_ref, power, peak_frequencies, peak_powers
def compute_dispersion(  # noqa: PLR0914
    data: SimulationData,
    field_names: str | Sequence[str],
    *,
    min_amplitude: float = 1e-12,
) -> DispersionResult:
    """Extract dispersion relation omega(k) from simulation output.

    Algorithm
    ---------
    1. For each field in *field_names*, compute spatial ``rfftn`` per snapshot
       to get complex Fourier coefficients ``phi_hat_i(k, t)``.
    2. For each spatial mode ``k``, temporal FFT of the **complex** coefficient
       gives ``S_i(k, omega)``.
    3. Sum spectral power across all fields:
       ``S_group(k, omega) = sum_i S_i(k, omega)``.
    4. Radially bin ``S_group(k, omega)`` into ``|k|`` shells.
    5. Peak detection: ``argmax(S_group)`` per k-bin extracts ``omega(k)``.
    6. Modes with combined max amplitude below threshold are inactive.

    Using complex coefficients (not ``|phi_hat|``) avoids frequency doubling
    artifacts from taking the absolute value before FFT. Summing power over a
    field group makes the measurement rotationally covariant within the group
    (no dependence on which single component is selected).

    Parameters
    ----------
    data : SimulationData
        Simulation output with time-resolved field snapshots.
    field_names : str or sequence of str
        Field or group of fields to extract the dispersion relation for. All
        fields must be dynamical (``time_derivative_order >= 2``); constraint
        fields raise ``ValueError``. All fields in a group must have the same
        snapshot array shape.
    min_amplitude : float, optional
        Modes with combined max ``|phi_hat(k, t)|`` below this threshold are
        treated as inactive. Default ``1e-12``.

    Returns
    -------
    DispersionResult

    Raises
    ------
    ValueError
        If *field_names* is empty, any field is unknown or has no equation in
        the spec, is a constraint field (``time_derivative_order < 2``), the
        fields of a group have mismatched snapshot shapes, there are fewer
        than 3 snapshots, the timestep is non-uniform, or any equation term is
        position-dependent (uniform medium required).
    """
    _check_no_position_dependent_terms(data, "Dispersion relation omega(k)")

    # Normalize to list (a bare str is one field, not a sequence of chars).
    names: list[str] = (
        [field_names] if isinstance(field_names, str) else list(field_names)
    )
    if not names:
        msg = "field_names must not be empty"
        raise ValueError(msg)

    # Validate all field names and check they are dynamical.
    for fname in names:
        if fname not in data.spec.component_names:
            msg = f"Field '{fname}' not in spec fields: {data.spec.component_names}"
            raise ValueError(msg)
        # Default of None: a bare next() would leak StopIteration to callers.
        eq = next((e for e in data.spec.equations if e.field_name == fname), None)
        if eq is None:
            msg = f"Field '{fname}' has no equation in the spec"
            raise ValueError(msg)
        if eq.time_derivative_order < 2:  # noqa: PLR2004
            msg = (
                f"Field '{fname}' is a constraint (time_derivative_order="
                f"{eq.time_derivative_order}). Dispersion requires a dynamical "
                f"field (time_derivative_order >= 2)."
            )
            raise ValueError(msg)

    _validate_timestep(data)
    dt = float(data.times[1] - data.times[0])
    rayleigh = 2.0 * np.pi / float(data.times[-1] - data.times[0])

    # First field: seed the accumulators.
    first_snapshots = data.fields[names[0]]
    angular_freqs, first_sfft, spacetime_power_total = _spacetime_fft(
        first_snapshots, dt
    )
    spatial_amp_sum = np.abs(first_sfft)
    grid_shape = first_snapshots.shape[1:]

    # Remaining fields: add their spectral power and amplitude.
    for fname in names[1:]:
        field_snapshots = data.fields[fname]
        # Reject mismatches explicitly; in-place += could otherwise broadcast a
        # size-1 axis silently and corrupt the group spectrum.
        if field_snapshots.shape != first_snapshots.shape:
            msg = (
                f"Field '{fname}' has snapshot shape {field_snapshots.shape}, "
                f"expected {first_snapshots.shape} (same as '{names[0]}')"
            )
            raise ValueError(msg)
        _, sfft, power = _spacetime_fft(field_snapshots, dt)
        spacetime_power_total += power
        spatial_amp_sum += np.abs(sfft)

    wn, power_binned, peak_freqs, peak_pow = _bin_and_detect(
        angular_freqs,
        spatial_amp_sum,
        spacetime_power_total,
        grid_shape,
        data.grid_spacing,
        min_amplitude,
    )

    return DispersionResult(
        wavenumbers=wn,
        frequencies=angular_freqs,
        power=power_binned,
        peak_frequencies=peak_freqs,
        peak_powers=peak_pow,
        field_name=", ".join(names),
        rayleigh_resolution=rayleigh,
    )