Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap FFT for specific array structures in Exponax #4

Merged
merged 15 commits into from
Sep 2, 2024
7 changes: 7 additions & 0 deletions docs/api/utilities/spectral.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Fourier-spectral utilities

::: exponax.fft

---

::: exponax.ifft
6 changes: 5 additions & 1 deletion exponax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from . import _metrics as metrics
from . import _poisson as poisson
from . import _spectral as spectral
from . import etdrk, ic, nonlin_fun, normalized, reaction, stepper, viz
from ._base_stepper import BaseStepper
from ._forced_stepper import ForcedStepper
from ._repeated_stepper import RepeatedStepper
from ._spectral import derivative, make_incompressible
from ._spectral import derivative, fft, ifft, make_incompressible
from ._utils import (
build_ic_set,
make_grid,
Expand All @@ -22,6 +23,8 @@
"poisson",
"RepeatedStepper",
"derivative",
"fft",
"ifft",
"make_incompressible",
"make_grid",
"rollout",
Expand All @@ -37,4 +40,5 @@
"reaction",
"stepper",
"viz",
"spectral",
]
12 changes: 6 additions & 6 deletions exponax/_base_stepper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from abc import ABC, abstractmethod

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Complex, Float

from ._spectral import (
build_derivative_operator,
space_indices,
fft,
ifft,
spatial_shape,
wavenumber_shape,
)
Expand Down Expand Up @@ -202,12 +202,12 @@ def step(self, u: Float[Array, "C ... N"]) -> Float[Array, "C ... N"]:
**Returns:**
- `u_next`: The state vector after one step, shape `(C, ..., N,)`.
"""
u_hat = jnp.fft.rfftn(u, axes=space_indices(self.num_spatial_dims))
u_hat = fft(u, num_spatial_dims=self.num_spatial_dims)
u_next_hat = self.step_fourier(u_hat)
u_next = jnp.fft.irfftn(
u_next = ifft(
u_next_hat,
s=spatial_shape(self.num_spatial_dims, self.num_points),
axes=space_indices(self.num_spatial_dims),
num_spatial_dims=self.num_spatial_dims,
num_points=self.num_points,
)
return u_next

Expand Down
6 changes: 3 additions & 3 deletions exponax/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
from jaxtyping import Array, Float

from ._spectral import low_pass_filter_mask, space_indices
from ._spectral import fft, low_pass_filter_mask


def _MSE(
Expand Down Expand Up @@ -505,8 +505,8 @@ def _fourier_nRMSE(

mask = jnp.invert(low_mask) & high_mask

u_pred_fft = jnp.fft.rfftn(u_pred, axes=space_indices(num_spatial_dims))
u_ref_fft = jnp.fft.rfftn(u_ref, axes=space_indices(num_spatial_dims))
u_pred_fft = fft(u_pred, num_spatial_dims=num_spatial_dims)
u_ref_fft = fft(u_ref, num_spatial_dims=num_spatial_dims)

# The FFT incurse rounding errors around the machine precision that can be
# noticeable in the nRMSE. We will zero out the values that are smaller than
Expand Down
11 changes: 6 additions & 5 deletions exponax/_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from ._spectral import (
build_derivative_operator,
build_laplace_operator,
space_indices,
fft,
ifft,
spatial_shape,
)

Expand Down Expand Up @@ -90,12 +91,12 @@ def step(
**Returns:**
- `u`: The solution.
"""
f_hat = jnp.fft.rfftn(f, axes=space_indices(self.num_spatial_dims))
f_hat = fft(f, num_spatial_dims=self.num_spatial_dims)
u_hat = self.step_fourier(f_hat)
u = jnp.fft.irfftn(
u = ifft(
u_hat,
axes=space_indices(self.num_spatial_dims),
s=spatial_shape(self.num_spatial_dims, self.num_points),
num_spatial_dims=self.num_spatial_dims,
num_points=self.num_points,
)
return u

Expand Down
Loading
Loading