diff --git a/docs/api/utilities/spectral.md b/docs/api/utilities/spectral.md new file mode 100644 index 0000000..e90e9b5 --- /dev/null +++ b/docs/api/utilities/spectral.md @@ -0,0 +1,7 @@ +# Fourier-spectral utilities + +::: exponax.fft + +--- + +::: exponax.ifft \ No newline at end of file diff --git a/exponax/__init__.py b/exponax/__init__.py index 5c2b5f4..804849f 100644 --- a/exponax/__init__.py +++ b/exponax/__init__.py @@ -1,10 +1,11 @@ from . import _metrics as metrics from . import _poisson as poisson +from . import _spectral as spectral from . import etdrk, ic, nonlin_fun, normalized, reaction, stepper, viz from ._base_stepper import BaseStepper from ._forced_stepper import ForcedStepper from ._repeated_stepper import RepeatedStepper -from ._spectral import derivative, make_incompressible +from ._spectral import derivative, fft, ifft, make_incompressible from ._utils import ( build_ic_set, make_grid, @@ -22,6 +23,8 @@ "poisson", "RepeatedStepper", "derivative", + "fft", + "ifft", "make_incompressible", "make_grid", "rollout", @@ -37,4 +40,5 @@ "reaction", "stepper", "viz", + "spectral", ] diff --git a/exponax/_base_stepper.py b/exponax/_base_stepper.py index 54683bf..c41a8a2 100644 --- a/exponax/_base_stepper.py +++ b/exponax/_base_stepper.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod import equinox as eqx -import jax.numpy as jnp from jaxtyping import Array, Complex, Float from ._spectral import ( build_derivative_operator, - space_indices, + fft, + ifft, spatial_shape, wavenumber_shape, ) @@ -202,12 +202,12 @@ def step(self, u: Float[Array, "C ... N"]) -> Float[Array, "C ... N"]: **Returns:** - `u_next`: The state vector after one step, shape `(C, ..., N,)`. """ - u_hat = jnp.fft.rfftn(u, axes=space_indices(self.num_spatial_dims)) + u_hat = fft(u, num_spatial_dims=self.num_spatial_dims) u_next_hat = self.step_fourier(u_hat) - u_next = jnp.fft.irfftn( + u_next = ifft( u_next_hat, - s=spatial_shape(self.num_spatial_dims, self.num_points), - axes=space_indices(self.num_spatial_dims), + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, ) return u_next diff --git a/exponax/_metrics.py b/exponax/_metrics.py index 0b709eb..ff17eae 100644 --- a/exponax/_metrics.py +++ b/exponax/_metrics.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from jaxtyping import Array, Float -from ._spectral import low_pass_filter_mask, space_indices +from ._spectral import fft, low_pass_filter_mask def _MSE( @@ -505,8 +505,8 @@ def _fourier_nRMSE( mask = jnp.invert(low_mask) & high_mask - u_pred_fft = jnp.fft.rfftn(u_pred, axes=space_indices(num_spatial_dims)) - u_ref_fft = jnp.fft.rfftn(u_ref, axes=space_indices(num_spatial_dims)) + u_pred_fft = fft(u_pred, num_spatial_dims=num_spatial_dims) + u_ref_fft = fft(u_ref, num_spatial_dims=num_spatial_dims) # The FFT incurse rounding errors around the machine precision that can be # noticeable in the nRMSE. We will zero out the values that are smaller than diff --git a/exponax/_poisson.py b/exponax/_poisson.py index 74f6d51..19168ff 100644 --- a/exponax/_poisson.py +++ b/exponax/_poisson.py @@ -5,7 +5,8 @@ from ._spectral import ( build_derivative_operator, build_laplace_operator, - space_indices, + fft, + ifft, spatial_shape, ) @@ -90,12 +91,12 @@ def step( **Returns:** - `u`: The solution. """ - f_hat = jnp.fft.rfftn(f, axes=space_indices(self.num_spatial_dims)) + f_hat = fft(f, num_spatial_dims=self.num_spatial_dims) u_hat = self.step_fourier(f_hat) - u = jnp.fft.irfftn( + u = ifft( u_hat, - axes=space_indices(self.num_spatial_dims), - s=spatial_shape(self.num_spatial_dims, self.num_points), + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, ) return u diff --git a/exponax/_spectral.py b/exponax/_spectral.py index fd8a635..c814646 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Union +from typing import Optional, TypeVar, Union import jax.numpy as jnp from jaxtyping import Array, Bool, Complex, Float @@ -71,113 +71,6 @@ def build_scaled_wavenumbers( return scale * wavenumbers -def derivative( - field: Float[Array, "C ... N"], - domain_extent: float, - *, - order: int = 1, - indexing: str = "ij", -) -> Union[Float[Array, "C D ... (N//2)+1"], Float[Array, "D ... (N//2)+1"]]: - """ - Perform the spectral derivative of a field. In higher dimensions, this - defaults to the gradient (the collection of all partial derivatives). In 1d, - the resulting channel dimension holds the derivative. If the function is - called with an d-dimensional field which has 1 channel, the result will be a - d-dimensional field with d channels (one per partial derivative). If the - field originally had C channels, the result will be a matrix field with C - rows and d columns. - - Note that applying this operator twice will produce issues at the Nyquist if - the number of degrees of freedom N is even. For this, consider also using - the order option. - - **Arguments:** - - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be - `1` for a scalar field or `D` for a vector field. - - `L`: The domain extent. - - `order`: The order of the derivative. Default is `1`. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. - - **Returns:** - - `field_der`: The derivative of the field, shape `(C, D, ..., - (N//2)+1)` or `(D, ..., (N//2)+1)`. - """ - channel_shape = field.shape[0] - spatial_shape = field.shape[1:] - D = len(spatial_shape) - N = spatial_shape[0] - derivative_operator = build_derivative_operator( - D, domain_extent, N, indexing=indexing - ) - # # I decided to not use this fix - - # # Required for even N, no effect for odd N - # derivative_operator_fixed = ( - # derivative_operator * nyquist_filter_mask(D, N) - # ) - derivative_operator_fixed = derivative_operator**order - - field_hat = jnp.fft.rfftn(field, axes=space_indices(D)) - if channel_shape == 1: - # Do not introduce another channel axis - field_der_hat = derivative_operator_fixed * field_hat - else: - # Create a "derivative axis" right after the channel axis - field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...] - - field_der = jnp.fft.irfftn(field_der_hat, s=spatial_shape, axes=space_indices(D)) - - return field_der - - -def make_incompressible( - field: Float[Array, "D ... N"], - *, - indexing: str = "ij", -): - channel_shape = field.shape[0] - spatial_shape = field.shape[1:] - num_spatial_dims = len(spatial_shape) - if channel_shape != num_spatial_dims: - raise ValueError( - f"Expected the number of channels to be {num_spatial_dims}, got {channel_shape}." - ) - num_points = spatial_shape[0] - - derivative_operator = build_derivative_operator( - num_spatial_dims, 1.0, num_points, indexing=indexing - ) # domain_extent does not matter because it will cancel out - - incompressible_field_hat = jnp.fft.rfftn( - field, axes=space_indices(num_spatial_dims) - ) - - divergence = jnp.sum( - derivative_operator * incompressible_field_hat, axis=0, keepdims=True - ) - - laplace_operator = build_laplace_operator(derivative_operator) - - inv_laplace_operator = jnp.where( - laplace_operator == 0, - 1.0, - 1.0 / laplace_operator, - ) - - pseudo_pressure = -inv_laplace_operator * divergence - - pseudo_pressure_garadient = derivative_operator * pseudo_pressure - - incompressible_field_hat = incompressible_field_hat - pseudo_pressure_garadient - - incompressible_field = jnp.fft.irfftn( - incompressible_field_hat, s=spatial_shape, axes=space_indices(num_spatial_dims) - ) - - return incompressible_field - - def build_derivative_operator( num_spatial_dims: int, domain_extent: float, @@ -470,3 +363,203 @@ def build_scaling_array( ) return scaling + + +def fft( + field: Float[Array, "C ... N"], + *, + num_spatial_dims: Optional[int] = None, +) -> Complex[Array, "C ... (N//2)+1"]: + """ + Perform a **real-valued** FFT of a field. This function is designed for + states in `Exponax` with a leading channel axis and then one, two, or three + following spatial axes, **each of the same length** N. + + Only accepts real-valued input fields and performs a real-valued FFT. Hence, + the last axis of the returned field is of length N//2+1. + + **Arguments:** + - `field`: The field to transform, shape `(C, ..., N,)`. + - `num_spatial_dims`: The number of spatial dimensions, i.e., how many + spatial axes follow the channel axis. Can be inferred from the array + if it follows the Exponax convention. For example, it is not allowed + to have a leading batch axis, in such a case use `jax.vmap` on this + function. + + **Returns:** + - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`. + + !!! info + Internally uses `jax.numpy.fft.rfftn` with the default settings for the + `norm` argument with `norm="backward"`. This means that the forward FFT + (this function) does not apply any normalization to the result, only the + [`exponax.ifft`][] function applies normalization. + """ + if num_spatial_dims is None: + num_spatial_dims = field.ndim - 1 + + return jnp.fft.rfftn(field, axes=space_indices(num_spatial_dims)) + + +def ifft( + field_hat: Complex[Array, "C ... (N//2)+1"], + *, + num_spatial_dims: Optional[int] = None, + num_points: Optional[int] = None, +) -> Float[Array, "C ... N"]: + """ + Perform the inverse **real-valued** FFT of a field. This is the inverse + operation of `fft`. This function is designed for states in `Exponax` with a + leading channel axis and then one, two, or three following spatial axes. In + state space all spatial axes have the same length N (here called + `num_points`). + + Requires a complex-valued field in Fourier space with the last axis of + length N//2+1. + + The number of points (N, or `num_points`) must be provided if the number of + spatial dimensions is 1. Otherwise, it can be inferred from the shape of the + field. + + **Arguments:** + - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`. + - `num_spatial_dims`: The number of spatial dimensions, i.e., how many + spatial axes follow the channel axis. Can be inferred from the array + if it follows the Exponax convention. For example, it is not allowed + to have a leading batch axis, in such a case use `jax.vmap` on this + function. + - `num_points`: The number of points in each spatial dimension. Can be + inferred if `num_spatial_dims` >= 2 + + **Returns:** + - `field`: The transformed field, shape `(C, ..., N,)`. + + !!! info + Internally uses `jax.numpy.fft.irfftn` with the default settings for the + `norm` argument with `norm="backward"`. This means that the forward FFT + [`exponax.fft`][] function does not apply any normalization to the + input, only the inverse FFT (this function) applies normalization. + Hence, if you want to define a state in Fourier space and inversely + transform it, consider using [`exponax.spectral.build_scaling_array`][] + to correctly scale the complex values before transforming them back. + """ + if num_spatial_dims is None: + num_spatial_dims = field_hat.ndim - 1 + + if num_points is None: + if num_spatial_dims >= 2: + num_points = field_hat.shape[-2] + else: + raise ValueError("num_points must be provided if num_spatial_dims == 1.") + return jnp.fft.irfftn( + field_hat, + s=spatial_shape(num_spatial_dims, num_points), + axes=space_indices(num_spatial_dims), + ) + + +def derivative( + field: Float[Array, "C ... N"], + domain_extent: float, + *, + order: int = 1, + indexing: str = "ij", +) -> Union[Float[Array, "C D ... (N//2)+1"], Float[Array, "D ... (N//2)+1"]]: + """ + Perform the spectral derivative of a field. In higher dimensions, this + defaults to the gradient (the collection of all partial derivatives). In 1d, + the resulting channel dimension holds the derivative. If the function is + called with an d-dimensional field which has 1 channel, the result will be a + d-dimensional field with d channels (one per partial derivative). If the + field originally had C channels, the result will be a matrix field with C + rows and d columns. + + Note that applying this operator twice will produce issues at the Nyquist if + the number of degrees of freedom N is even. For this, consider also using + the order option. + + **Arguments:** + - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be + `1` for a scalar field or `D` for a vector field. + - `L`: The domain extent. + - `order`: The order of the derivative. Default is `1`. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `field_der`: The derivative of the field, shape `(C, D, ..., + (N//2)+1)` or `(D, ..., (N//2)+1)`. + """ + channel_shape = field.shape[0] + spatial_shape = field.shape[1:] + D = len(spatial_shape) + N = spatial_shape[0] + derivative_operator = build_derivative_operator( + D, domain_extent, N, indexing=indexing + ) + # # I decided to not use this fix + + # # Required for even N, no effect for odd N + # derivative_operator_fixed = ( + # derivative_operator * nyquist_filter_mask(D, N) + # ) + derivative_operator_fixed = derivative_operator**order + + field_hat = fft(field, num_spatial_dims=D) + if channel_shape == 1: + # Do not introduce another channel axis + field_der_hat = derivative_operator_fixed * field_hat + else: + # Create a "derivative axis" right after the channel axis + field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...] + + field_der = ifft(field_der_hat, num_spatial_dims=D, num_points=N) + + return field_der + + +def make_incompressible( + field: Float[Array, "D ... N"], + *, + indexing: str = "ij", +): + channel_shape = field.shape[0] + spatial_shape = field.shape[1:] + num_spatial_dims = len(spatial_shape) + if channel_shape != num_spatial_dims: + raise ValueError( + f"Expected the number of channels to be {num_spatial_dims}, got {channel_shape}." + ) + num_points = spatial_shape[0] + + derivative_operator = build_derivative_operator( + num_spatial_dims, 1.0, num_points, indexing=indexing + ) # domain_extent does not matter because it will cancel out + + incompressible_field_hat = fft(field, num_spatial_dims=num_spatial_dims) + + divergence = jnp.sum( + derivative_operator * incompressible_field_hat, axis=0, keepdims=True + ) + + laplace_operator = build_laplace_operator(derivative_operator) + + inv_laplace_operator = jnp.where( + laplace_operator == 0, + 1.0, + 1.0 / laplace_operator, + ) + + pseudo_pressure = -inv_laplace_operator * divergence + + pseudo_pressure_garadient = derivative_operator * pseudo_pressure + + incompressible_field_hat = incompressible_field_hat - pseudo_pressure_garadient + + incompressible_field = ifft( + incompressible_field_hat, + num_spatial_dims=num_spatial_dims, + num_points=num_points, + ) + + return incompressible_field diff --git a/exponax/ic/_gaussian_random_field.py b/exponax/ic/_gaussian_random_field.py index eb48c22..6fd1435 100644 --- a/exponax/ic/_gaussian_random_field.py +++ b/exponax/ic/_gaussian_random_field.py @@ -5,8 +5,7 @@ from .._spectral import ( build_scaled_wavenumbers, build_scaling_array, - space_indices, - spatial_shape, + ifft, wavenumber_shape, ) from ._base_ic import BaseRandomICGenerator @@ -82,11 +81,7 @@ def __call__( noise = noise * build_scaling_array(self.num_spatial_dims, num_points) - ic = jnp.fft.irfftn( - noise, - s=spatial_shape(self.num_spatial_dims, num_points), - axes=space_indices(self.num_spatial_dims), - ) + ic = ifft(noise, num_spatial_dims=self.num_spatial_dims, num_points=num_points) if self.zero_mean: ic = ic - jnp.mean(ic) diff --git a/exponax/ic/_truncated_fourier_series.py b/exponax/ic/_truncated_fourier_series.py index c93c303..c200a2f 100644 --- a/exponax/ic/_truncated_fourier_series.py +++ b/exponax/ic/_truncated_fourier_series.py @@ -4,9 +4,8 @@ from .._spectral import ( build_scaling_array, + ifft, low_pass_filter_mask, - space_indices, - spatial_shape, wavenumber_shape, ) from ._base_ic import BaseRandomICGenerator @@ -131,10 +130,10 @@ def __call__( self.num_spatial_dims, num_points ) - u = jnp.fft.irfftn( + u = ifft( fourier_noise, - s=spatial_shape(self.num_spatial_dims, num_points), - axes=space_indices(self.num_spatial_dims), + num_spatial_dims=self.num_spatial_dims, + num_points=num_points, ) if self.std_one: diff --git a/exponax/nonlin_fun/_base.py b/exponax/nonlin_fun/_base.py index 3f4ad03..d04d992 100644 --- a/exponax/nonlin_fun/_base.py +++ b/exponax/nonlin_fun/_base.py @@ -2,10 +2,9 @@ from typing import Optional import equinox as eqx -import jax.numpy as jnp from jaxtyping import Array, Bool, Complex, Float -from .._spectral import low_pass_filter_mask, space_indices, spatial_shape +from .._spectral import fft, ifft, low_pass_filter_mask class BaseNonlinearFun(eqx.Module, ABC): @@ -45,13 +44,11 @@ def dealias( return self.dealiasing_mask * u_hat def fft(self, u: Float[Array, "C ... N"]) -> Complex[Array, "C ... (N//2)+1"]: - return jnp.fft.rfftn(u, axes=space_indices(self.num_spatial_dims)) + return fft(u, num_spatial_dims=self.num_spatial_dims) def ifft(self, u_hat: Complex[Array, "C ... (N//2)+1"]) -> Float[Array, "C ... N"]: - return jnp.fft.irfftn( - u_hat, - s=spatial_shape(self.num_spatial_dims, self.num_points), - axes=space_indices(self.num_spatial_dims), + return ifft( + u_hat, num_spatial_dims=self.num_spatial_dims, num_points=self.num_points ) @abstractmethod diff --git a/mkdocs.yml b/mkdocs.yml index fd9e29c..63badb6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -157,6 +157,7 @@ nav: - Rollout & Repeat: 'api/utilities/rollout_and_repeat.md' - Grid Generation: 'api/utilities/grid_generation.md' - Derivatives: 'api/utilities/derivatives.md' + - Spectral: 'api/utilities/spectral.md' - Normalized & Difficulty: 'api/utilities/normalized_and_difficulty.md' - Metrics: - MSE-based: 'api/utilities/metrics/mse_based.md'