From bea4079ffabf69976088308846ecb458d80ffc6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20K=C3=B6hler?= <27728103+Ceyron@users.noreply.github.com> Date: Tue, 3 Sep 2024 12:46:59 +0200 Subject: [PATCH] Consistent docs (#26) * Add documentation to base * Add documentation to etdrk core * Adapt to Equinox style * Change to Equinox style * Change to Equinox style * Change to Equinox style * Fix typo * Consistent documentation * Improve nonlin fun base documentation * Change to Equinox style * Improve docs for convection * Improve docs for gradient norm * Improve docs for combined general nonlinearity * Fix docstring * Improve and fix documentation * Fix broken link * Adapt to Equinox style * Improve documentation * Add clarification * better wording * Fix convection difficulty computation * Improve doc of generic gradient norm stepper * Add docstring * Add documentation * Add documentation * Add documentation --- exponax/_base_stepper.py | 52 ++- exponax/_forced_stepper.py | 32 +- exponax/_poisson.py | 36 +- exponax/_repeated_stepper.py | 32 +- exponax/_utils.py | 180 +++++----- exponax/etdrk/_base_etdrk.py | 45 +++ exponax/etdrk/_etdrk_0.py | 25 +- exponax/etdrk/_etdrk_1.py | 38 +++ exponax/etdrk/_etdrk_2.py | 46 +++ exponax/etdrk/_etdrk_3.py | 49 +++ exponax/etdrk/_etdrk_4.py | 51 +++ exponax/etdrk/_utils.py | 11 +- exponax/ic/_base_ic.py | 25 +- exponax/ic/_clamping.py | 8 +- exponax/ic/_diffused_noise.py | 24 +- exponax/ic/_discontinuities.py | 43 +-- exponax/ic/_gaussian_blob.py | 38 ++- exponax/ic/_gaussian_random_field.py | 23 +- exponax/ic/_multi_channel.py | 37 ++- exponax/ic/_scaled.py | 5 +- exponax/ic/_sine_waves_1d.py | 61 ++-- exponax/ic/_truncated_fourier_series.py | 35 +- exponax/nonlin_fun/_base.py | 77 ++++- exponax/nonlin_fun/_convection.py | 79 ++++- exponax/nonlin_fun/_general_nonlinear.py | 50 ++- exponax/nonlin_fun/_gradient_norm.py | 53 +-- exponax/nonlin_fun/_vorticity_convection.py | 76 +++-- exponax/nonlin_fun/_zero.py | 11 +- exponax/stepper/_advection.py | 2 +- exponax/stepper/_advection_diffusion.py | 2 +- exponax/stepper/_diffusion.py | 2 +- exponax/stepper/_dispersion.py | 2 +- exponax/stepper/_hyper_diffusion.py | 2 +- exponax/stepper/_kuramoto_sivashinsky.py | 2 +- exponax/stepper/generic/_convection.py | 202 +++++++++--- exponax/stepper/generic/_gradient_norm.py | 198 ++++++++--- exponax/stepper/generic/_nonlinear.py | 229 ++++++++++++- exponax/stepper/generic/_polynomial.py | 180 +++++++++- exponax/stepper/generic/_utils.py | 307 +++++++++++++++++- .../stepper/generic/_vorticity_convection.py | 58 ++++ 40 files changed, 1999 insertions(+), 429 deletions(-) diff --git a/exponax/_base_stepper.py b/exponax/_base_stepper.py index c41a8a2..60d1de2 100644 --- a/exponax/_base_stepper.py +++ b/exponax/_base_stepper.py @@ -164,12 +164,14 @@ def _build_linear_operator( Assemble the L operator in Fourier space. **Arguments:** - - `derivative_operator`: The derivative operator, shape `( D, ..., - N//2+1 )`. The ellipsis are (D-1) axis of size N (**not** of size - N//2+1). + + - `derivative_operator`: The derivative operator, shape `( D, ..., + N//2+1 )`. The ellipsis are (D-1) axis of size N (**not** of size + N//2+1). **Returns:** - - `L`: The linear operator, shape `( C, ..., N//2+1 )`. + + - `L`: The linear operator, shape `( C, ..., N//2+1 )`. """ pass @@ -183,12 +185,15 @@ def _build_nonlinear_fun( transforms to Fourier space, and evaluates derivatives there. **Arguments:** - - `derivative_operator`: The derivative operator, shape `( D, ..., N//2+1 )`. + + - `derivative_operator`: The derivative operator, shape `( D, ..., + N//2+1 )`. **Returns:** - - `nonlinear_fun`: A function that evaluates the nonlinearities in - time space, transforms to Fourier space, and evaluates the - derivatives there. Should be a subclass of `BaseNonlinearFun`. + + - `nonlinear_fun`: A function that evaluates the nonlinearities in + time space, transforms to Fourier space, and evaluates the + derivatives there. Should be a subclass of `BaseNonlinearFun`. """ pass @@ -197,10 +202,12 @@ def step(self, u: Float[Array, "C ... N"]) -> Float[Array, "C ... N"]: Perform one step of the time integration. **Arguments:** - - `u`: The state vector, shape `(C, ..., N,)`. + + - `u`: The state vector, shape `(C, ..., N,)`. **Returns:** - - `u_next`: The state vector after one step, shape `(C, ..., N,)`. + + - `u_next`: The state vector after one step, shape `(C, ..., N,)`. """ u_hat = fft(u, num_spatial_dims=self.num_spatial_dims) u_next_hat = self.step_fourier(u_hat) @@ -220,11 +227,13 @@ def step_fourier( transforms. **Arguments:** - - `u_hat`: The (real) Fourier transform of the state vector + + - `u_hat`: The (real) Fourier transform of the state vector **Returns:** - - `u_next_hat`: The (real) Fourier transform of the state vector - after one step + + - `u_next_hat`: The (real) Fourier transform of the state vector + after one step """ return self._integrator.step_fourier(u_hat) @@ -233,7 +242,22 @@ def __call__( u: Float[Array, "C ... N"], ) -> Float[Array, "C ... N"]: """ - Performs a check + Perform one step of the time integration for a single state. + + **Arguments:** + + - `u`: The state vector, shape `(C, ..., N,)`. + + **Returns:** + + - `u_next`: The state vector after one step, shape `(C, ..., N,)`. + + !!! tip + Use this call method together with `exponax.rollout` to efficiently + produce temporal trajectories. + + !!! info + For batched operation, use `jax.vmap` on this function. """ expected_shape = (self.num_channels,) + spatial_shape( self.num_spatial_dims, self.num_points diff --git a/exponax/_forced_stepper.py b/exponax/_forced_stepper.py index 149b339..81d684b 100644 --- a/exponax/_forced_stepper.py +++ b/exponax/_forced_stepper.py @@ -33,7 +33,8 @@ def __init__( transient integrators to forced problems. **Arguments**: - - `stepper`: The stepper to be transformed. + + - `stepper`: The stepper to be transformed. """ self.stepper = stepper @@ -49,11 +50,13 @@ def step( The forcing term `f` is assumed to be evaluated on the same grid as `u`. **Arguments**: - - `u`: The current state. - - `f`: The forcing term. + + - `u`: The current state. + - `f`: The forcing term. **Returns**: - - `u_next`: The state after one time step. + + - `u_next`: The state after one time step. """ u_with_force = u + self.stepper.dt * f return self.stepper.step(u_with_force) @@ -71,11 +74,13 @@ def step_fourier( `u_hat`. **Arguments**: - - `u_hat`: The current state in Fourier space. - - `f_hat`: The forcing term in Fourier space. + + - `u_hat`: The current state in Fourier space. + - `f_hat`: The forcing term in Fourier space. **Returns**: - - `u_next_hat`: The state after one time step in Fourier space. + + - `u_next_hat`: The state after one time step in Fourier space. """ u_hat_with_force = u_hat + self.stepper.dt * f_hat return self.stepper.step_fourier(u_hat_with_force) @@ -91,12 +96,13 @@ def __call__( The forcing term `f` is assumed to be evaluated on the same grid as `u`. - **Arguments**: - - `u`: The current state. - - `f`: The forcing term. + **Arguments:** - **Returns**: - - `u_next`: The state after one time step. - """ + - `u`: The current state. + - `f`: The forcing term. + **Returns:** + + - `u_next`: The state after one time step. + """ return self.step(u, f) diff --git a/exponax/_poisson.py b/exponax/_poisson.py index 19168ff..2c9b9d4 100644 --- a/exponax/_poisson.py +++ b/exponax/_poisson.py @@ -41,11 +41,12 @@ def __init__( It is included for completion. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `domain_extent`: The extent of the domain. - - `num_points`: The number of points in each spatial dimension. - - `order`: The order of the Poisson equation. Defaults to 2. You can - also set `order=4` for the biharmonic equation. + + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The extent of the domain. + - `num_points`: The number of points in each spatial dimension. + - `order`: The order of the Poisson equation. Defaults to 2. You can + also set `order=4` for the biharmonic equation. """ self.num_spatial_dims = num_spatial_dims self.domain_extent = domain_extent @@ -71,10 +72,12 @@ def step_fourier( Solve the Poisson equation in Fourier space. **Arguments:** - - `f_hat`: The Fourier transform of the right hand side. + + - `f_hat`: The Fourier transform of the right hand side. **Returns:** - - `u_hat`: The Fourier transform of the solution. + + - `u_hat`: The Fourier transform of the solution. """ return -self._inv_operator * f_hat @@ -83,13 +86,15 @@ def step( f: Float[Array, "C ... N"], ) -> Float[Array, "C ... N"]: """ - Solve the Poisson equation in real space. + Solve the Poisson equation in state space. **Arguments:** - - `f`: The right hand side. + + - `f`: The right hand side. **Returns:** - - `u`: The solution. + + - `u`: The solution. """ f_hat = fft(f, num_spatial_dims=self.num_spatial_dims) u_hat = self.step_fourier(f_hat) @@ -104,6 +109,17 @@ def __call__( self, f: Float[Array, "C ... N"], ) -> Float[Array, "C ... N"]: + """ + Solve the Poisson equation in state space. + + **Arguments:** + + - `f`: The right hand side. + + **Returns:** + + - `u`: The solution. + """ if f.shape[1:] != spatial_shape(self.num_spatial_dims, self.num_points): raise ValueError( f"Shape of f[1:] is {f.shape[1:]} but should be {spatial_shape(self.num_spatial_dims, self.num_points)}" diff --git a/exponax/_repeated_stepper.py b/exponax/_repeated_stepper.py index eb7045f..4914622 100644 --- a/exponax/_repeated_stepper.py +++ b/exponax/_repeated_stepper.py @@ -33,8 +33,9 @@ def __init__( time step of X/Y and then wrap it in a RepeatedStepper with num_sub_steps=Y. **Arguments:** - - `stepper`: The stepper to repeat. - - `num_sub_steps`: The number of substeps to perform. + + - `stepper`: The stepper to repeat. + - `num_sub_steps`: The number of substeps to perform. """ self.stepper = stepper self.num_sub_steps = num_sub_steps @@ -52,8 +53,16 @@ def step( u: Float[Array, "C ... N"], ) -> Float[Array, "C ... N"]: """ - Step the PDE forward in time by self.num_sub_steps time steps given the + Step the PDE forward in time by `self.num_sub_steps` time steps given the current state `u`. + + **Arguments:** + + - `u`: The current state. + + **Returns:** + + - `u_next`: The state after `self.num_sub_steps` time steps. """ return repeat(self.stepper.step, self.num_sub_steps)(u) @@ -64,6 +73,15 @@ def step_fourier( """ Step the PDE forward in time by self.num_sub_steps time steps given the current state `u_hat` in real-valued Fourier space. + + **Arguments:** + + - `u_hat`: The current state in Fourier space. + + **Returns:** + + - `u_next_hat`: The state after `self.num_sub_steps` time steps in Fourier + space. """ return repeat(self.stepper.step_fourier, self.num_sub_steps)(u_hat) @@ -74,5 +92,13 @@ def __call__( """ Step the PDE forward in time by self.num_sub_steps time steps given the current state `u`. + + **Arguments:** + + - `u`: The current state. + + **Returns:** + + - `u_next`: The state after `self.num_sub_steps` time steps. """ return repeat(self.stepper, self.num_sub_steps)(u) diff --git a/exponax/_utils.py b/exponax/_utils.py index 3de27c5..1d587cd 100644 --- a/exponax/_utils.py +++ b/exponax/_utils.py @@ -17,28 +17,31 @@ def make_grid( indexing: str = "ij", ) -> Float[Array, "D ... N"]: """ - Return a grid in the spatial domain. A grid in d dimensions is an array of - shape (d,) + (num_points,)*d with the first axis representing all coordiate - inidices. + Return a grid in the spatial domain. A grid in D dimensions is an array of + shape (D,) + (num_points,)*D with the leading axis representing all + coordiate inidices. Notice, that if `num_spatial_dims = 1`, the returned array has a singleton dimension in the first axis, i.e., the shape is `(1, num_points)`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `domain_extent`: The extent of the domain in each spatial dimension. - - `num_points`: The number of points in each spatial dimension. - - `full`: Whether to include the right boundary point in the grid. - Default: `False`. The right point is redundant for periodic boundary - conditions and is not considered a degree of freedom. Use this - option, for example, if you need a full grid for plotting. - - `zero_centered`: Whether to center the grid around zero. Default: - `False`. By default the grid considers a domain of (0, - domain_extent)^(num_spatial_dims). - - `indexing`: The indexing convention to use. Default: `'ij'`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The extent of the domain in each spatial dimension. + - `num_points`: The number of points in each spatial dimension. + - `full`: Whether to include the right boundary point in the grid. + Default: `False`. The right point is redundant for periodic boundary + conditions and is not considered a degree of freedom. Use this option, + for example, if you need a full grid for plotting. + - `zero_centered`: Whether to center the grid around zero. Default: + `False`. By default the grid considers a domain of (0, + domain_extent)^(num_spatial_dims). + - `indexing`: The indexing convention to use. Default: `'ij'`. **Returns:** - - `grid`: The grid in the spatial domain. Shape: `(num_spatial_dims, ..., num_points)`. + + - `grid`: The grid in the spatial domain. Shape: `(num_spatial_dims, + ..., num_points)`. """ if full: grid_1d = jnp.linspace(0, domain_extent, num_points + 1, endpoint=True) @@ -59,18 +62,23 @@ def make_grid( return grid -def wrap_bc(u): +def wrap_bc(u: Float[Array, "C N"]) -> Float[Array, "C N+1"]: """ Wraps the periodic boundary conditions around the array `u`. This can be used to plot the solution of a periodic problem on the full - interval [0, L] by plotting `wrap_bc(u)` instead of `u`. + interval [0, L] by plotting `wrap_bc(u)` instead of `u`. Consider using + `exponax.make_grid` with the `full=True` option to create a full grid. Note + that all routines in `exponax.viz` already correctly wrap the boundary + conditions. - **Parameters:** - - `u`: The array to wrap, shape `(N,)`. + **Arguments:** + + - `u`: The array to wrap, shape `(C, N,)`. **Returns:** - - `u_wrapped`: The wrapped array, shape `(N + 1,)`. + + - `u_wrapped`: The wrapped array, shape `(C, N + 1,)`. """ _, *spatial_shape = u.shape num_spatial_dims = len(spatial_shape) @@ -98,33 +106,32 @@ def rollout( a force/control or additional metadata (like physical parameters, or time for non-autonomous systems). - Args: - - `stepper_fn`: The time stepper to transform. If `takes_aux = False` - (default), expected signature is `u_next = stepper_fn(u)`, else - `u_next = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees - of identical structure, in the easiest case just arrays of same - shape. - - `n`: The number of time steps to rollout the trajectory into the - future. If `include_init = False` (default) produces the `n` steps - into the future. - - `include_init`: Whether to include the initial condition in the - trajectory. If `True`, the arrays in the returning PyTree have shape - `(n + 1, ...)`, else `(n, ...)`. Default: `False`. - - `takes_aux`: Whether the stepper function takes an additional PyTree - as second argument. - - `constant_aux`: Whether the auxiliary input is constant over the - trajectory. If `True`, the auxiliary input is repeated `n` times, - otherwise the leading axis in the PyTree arrays has to be of length - `n`. - - Returns: - - `rollout_stepper_fn`: A function that takes an initial condition `u_0` - and an auxiliary input `aux` (if `takes_aux = True`) and produces - the trajectory by autoregressively applying the stepper `n` times. - If `include_init = True`, the trajectory has shape `(n + 1, ...)`, - else `(n, ...)`. Returns a PyTree of the same structure as the - initial condition, but with an additional leading axis of length - `n`. + **Arguments:** + + - `stepper_fn`: The time stepper to transform. If `takes_aux = False` + (default), expected signature is `u_next = stepper_fn(u)`, else `u_next + = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees of identical + structure, in the easiest case just arrays of same shape. + - `n`: The number of time steps to rollout the trajectory into the + future. If `include_init = False` (default) produces the `n` steps into + the future. + - `include_init`: Whether to include the initial condition in the + trajectory. If `True`, the arrays in the returning PyTree have shape `(n + + 1, ...)`, else `(n, ...)`. Default: `False`. + - `takes_aux`: Whether the stepper function takes an additional PyTree + as second argument. + - `constant_aux`: Whether the auxiliary input is constant over the + trajectory. If `True`, the auxiliary input is repeated `n` times, + otherwise the leading axis in the PyTree arrays has to be of length `n`. + + **Returns:** + + - `rollout_stepper_fn`: A function that takes an initial condition `u_0` + and an auxiliary input `aux` (if `takes_aux = True`) and produces the + trajectory by autoregressively applying the stepper `n` times. If + `include_init = True`, the trajectory has shape `(n + 1, ...)`, else + `(n, ...)`. Returns a PyTree of the same structure as the initial + condition, but with an additional leading axis of length `n`. """ if takes_aux: @@ -196,26 +203,25 @@ def repeat( a force/control or additional metadata (like physical parameters, or time for non-autonomous systems). - Args: - - `stepper_fn`: The time stepper to transform. If `takes_aux = False` - (default), expected signature is `u_next = stepper_fn(u)`, else - `u_next = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees - of identical structure, in the easiest case just arrays of same - shape. - - `n`: The number of times to apply the stepper. - - `takes_aux`: Whether the stepper function takes an additional PyTree - as second argument. - - `constant_aux`: Whether the auxiliary input is constant over the - trajectory. If `True`, the auxiliary input is repeated `n` times, - otherwise the leading axis in the PyTree arrays has to be of length - `n`. - - Returns: - - `repeated_stepper_fn`: A function that takes an initial condition - `u_0` and an auxiliary input `aux` (if `takes_aux = True`) and - produces the final state by autoregressively applying the stepper - `n` times. Returns a PyTree of the same structure as the initial - condition. + **Arguments:** + + - `stepper_fn`: The time stepper to transform. If `takes_aux = False` + (default), expected signature is `u_next = stepper_fn(u)`, else `u_next + = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees of identical + structure, in the easiest case just arrays of same shape. + - `n`: The number of times to apply the stepper. + - `takes_aux`: Whether the stepper function takes an additional PyTree + as second argument. + - `constant_aux`: Whether the auxiliary input is constant over the + trajectory. If `True`, the auxiliary input is repeated `n` times, + otherwise the leading axis in the PyTree arrays has to be of length `n`. + + **Returns:** + + - `repeated_stepper_fn`: A function that takes an initial condition + `u_0` and an auxiliary input `aux` (if `takes_aux = True`) and produces + the final state by autoregressively applying the stepper `n` times. + Returns a PyTree of the same structure as the initial condition. """ if takes_aux: @@ -256,18 +262,22 @@ def stack_sub_trajectories( Slice a trajectory into subtrajectories of length `n` and stack them together. Useful for rollout training neural operators with temporal mixing. - !!! Note that this function can produce very large arrays. + !!! warning + This function can produce very large arrays, especially if `sub_le >> + 1`. **Arguments:** - - `trj`: The trajectory to slice. Expected shape: `(n_timesteps, ...)`. - - `sub_len`: The length of the subtrajectories. If you want to perform rollout - training with k steps, note that `n=k+1` to also have an initial - condition in the subtrajectories. + + - `trj`: The trajectory to slice. Expected shape: `(n_timesteps, ...)`. + - `sub_len`: The length of the subtrajectories. If you want to perform + rollout training with k steps, note that `n=k+1` to also have an initial + condition in the subtrajectories. **Returns:** - - `sub_trjs`: The stacked subtrajectories. Expected shape: `(n_stacks, n, ...)`. - `n_stacks` is the number of subtrajectories stacked together, i.e., - `n_timesteps - n + 1`. + + - `sub_trjs`: The stacked subtrajectories. Expected shape: `(n_stacks, + n, ...)`. `n_stacks` is the number of subtrajectories stacked together, + i.e., `n_timesteps - n + 1`. """ n_time_steps = [leaf.shape[0] for leaf in jtu.tree_leaves(trj)] @@ -303,26 +313,28 @@ def scan_fn(_, i): def build_ic_set( - ic_generator, + ic_generator: Callable[[int, PRNGKeyArray], Float[Array, "C ... N"]], *, num_points: int, num_samples: int, key: PRNGKeyArray, -) -> Float[Array, "S 1 ... N"]: +) -> Float[Array, "S C ... N"]: """ Generate a set of initial conditions by sampling from a given initial condition distribution and evaluating the function on the given grid. **Arguments:** - - `ic_generator`: A function that takes a PRNGKey and returns a - function that takes a grid and returns a sample from the initial - condition distribution. - - `num_samples`: The number of initial conditions to sample. - - `key`: The PRNGKey to use for sampling. + + - `ic_generator`: A function that takes a number of points and a PRNGKey + and returns an array representing the discrete state of an initial + condition. The shape of the returned array is `(C, ..., N)`. + - `num_samples`: The number of initial conditions to sample. + - `key`: The PRNGKey to use for sampling. **Returns:** - - `ic_set`: The set of initial conditions. Shape: `(S, 1, ..., N)`. - `S = num_samples`. + + - `ic_set`: The set of initial conditions. Shape: `(S, C, ..., N)`. + `S = num_samples`. """ def scan_fn(k, _): diff --git a/exponax/etdrk/_base_etdrk.py b/exponax/etdrk/_base_etdrk.py index 3e9f28b..63d1fee 100644 --- a/exponax/etdrk/_base_etdrk.py +++ b/exponax/etdrk/_base_etdrk.py @@ -21,6 +21,43 @@ def __init__( dt: float, linear_operator: Complex[Array, "E ... (N//2)+1"], ): + """ + Base class for exponential time differencing Runge-Kutta methods. + + **Arguments:** + + - `dt`: The time step size. + - `linear_operator`: The linear operator of the PDE. Must have a leading + channel axis, followed by one, two or three spatial axes whereas the + last axis must be of size `(N//2)+1` where `N` is the number of + dimensions in the former spatial axes. + + !!! Example + Below is an example how to get the linear operator for + the heat equation. + + ```python + import jax.numpy as jnp + import exponax as ex + + # Define the linear operator + N = 256 + L = 5.0 # The domain size + D = 1 # Being in 1D + + derivative_operator = 1j * ex.spectral.build_derivative_operator( + D, + L, + N, + ) + + print(derivative_operator.shape) # (1, (N//2)+1) + + nu = 0.01 # The diffusion coefficient + + linear_operator = nu * derivative_operator**2 + ``` + """ self.dt = dt self._exp_term = jnp.exp(self.dt * linear_operator) @@ -31,5 +68,13 @@ def step_fourier( ) -> Complex[Array, "C ... (N//2)+1"]: """ Advance the state in Fourier space. + + **Arguments:** + + - `u_hat`: The previous state in Fourier space. + + **Returns:** + + - The next state in Fourier space, i.e., `self.dt` time units later. """ pass diff --git a/exponax/etdrk/_etdrk_0.py b/exponax/etdrk/_etdrk_0.py index 06a0dea..30e94cd 100644 --- a/exponax/etdrk/_etdrk_0.py +++ b/exponax/etdrk/_etdrk_0.py @@ -4,9 +4,28 @@ class ETDRK0(BaseETDRK): - """ - Exactly solve a linear PDE in Fourier space - """ + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... (N//2)+1"], + ): + r""" + Exactly solve a linear PDE in Fourier space. + + $$ + \hat{u}_h^{[t+1]} = \exp(\hat{\mathcal{L}}_h \Delta t) \odot + \hat{u}_h^{[t]} + $$ + + **Arguments:** + + - `dt`: The time step size. + - `linear_operator`: The linear operator of the PDE. Must have a leading + channel axis, followed by one, two or three spatial axes whereas the + last axis must be of size `(N//2)+1` where `N` is the number of + dimensions in the former spatial axes. + """ + super().__init__(dt, linear_operator) def step_fourier( self, diff --git a/exponax/etdrk/_etdrk_1.py b/exponax/etdrk/_etdrk_1.py index de8215e..33d445d 100644 --- a/exponax/etdrk/_etdrk_1.py +++ b/exponax/etdrk/_etdrk_1.py @@ -19,6 +19,44 @@ def __init__( num_circle_points: int = 16, circle_radius: float = 1.0, ): + r""" + Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta + with a **first order approximation**. + + Adapted from Eq. (4) of [Cox and Matthews + (2002)](https://doi.org/10.1006/jcph.2002.6995): + + $$ + \hat{u}_h^{[t+1]} = \exp(\hat{\mathcal{L}}_h \Delta t) \odot + \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - + 1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) + $$ + + where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of + the nonlinear differential operator. + + **Arguments:** + + - `dt`: The time step size. + - `linear_operator`: The linear operator of the PDE. Must have a leading + channel axis, followed by one, two or three spatial axes whereas the + last axis must be of size `(N//2)+1` where `N` is the number of + dimensions in the former spatial axes. + - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the + nonlinear differential operator. + - `num_circle_points`: The number of points on the unit circle used to + approximate the numerically challenging coefficients. + - `circle_radius`: The radius of the circle used to approximate the + numerically challenging coefficients. + + !!! warning + The nonlinear function must take care of proper dealiasing. + + !!! note + The numerically stable evaluation of the coefficients follows + [Kassam and Trefethen + (2005)](https://doi.org/10.1137/S1064827502410633). + """ super().__init__(dt, linear_operator) self._nonlinear_fun = nonlinear_fun diff --git a/exponax/etdrk/_etdrk_2.py b/exponax/etdrk/_etdrk_2.py index 45a2123..f7b9d62 100644 --- a/exponax/etdrk/_etdrk_2.py +++ b/exponax/etdrk/_etdrk_2.py @@ -20,6 +20,52 @@ def __init__( num_circle_points: int = 16, circle_radius: float = 1.0, ): + r""" + Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta + with a **second order approximation**. + + Adopted from Eq. (22) of [Cox and Matthews + (2002)](https://doi.org/10.1006/jcph.2002.6995): + + $$ + \begin{aligned} + \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot + \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - + 1}{\hat{\mathcal{L}}_h} \odot + \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}). \\ \hat{u}_h^{[t+1]} &= + \hat{u}_h^* + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - 1 - + \hat{\mathcal{L}}_h \Delta t}{\hat{\mathcal{L}}_h^2 \Delta t} + \left( \hat{\mathcal{N}}_h(\hat{u}_h^*) - + \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right) + \end{aligned} + $$ + + where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of + the nonlinear differential operator. + + **Arguments:** + + - `dt`: The time step size. + - `linear_operator`: The linear operator of the PDE. Must have a leading + channel axis, followed by one, two or three spatial axes whereas the + last axis must be of size `(N//2)+1` where `N` is the number of + dimensions in the former spatial axes. + - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the + nonlinear differential operator. ! The operator must take care of + proper dealiasing. + - `num_circle_points`: The number of points on the unit circle used to + approximate the numerically challenging coefficients. + - `circle_radius`: The radius of the circle used to approximate the + numerically challenging coefficients. + + !!! warning + The nonlinear function must take care of proper dealiasing. + + !!! note + The numerically stable evaluation of the coefficients follows + [Kassam and Trefethen + (2005)](https://doi.org/10.1137/S1064827502410633). + """ super().__init__(dt, linear_operator) self._nonlinear_fun = nonlinear_fun diff --git a/exponax/etdrk/_etdrk_3.py b/exponax/etdrk/_etdrk_3.py index 8d1cade..5a8ab31 100644 --- a/exponax/etdrk/_etdrk_3.py +++ b/exponax/etdrk/_etdrk_3.py @@ -24,6 +24,55 @@ def __init__( num_circle_points: int = 16, circle_radius: float = 1.0, ): + r""" + Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta + with a **third order approximation**. + + Adapted from Eq. (23-25) of [Cox and Matthews + (2002)](https://doi.org/10.1006/jcph.2002.6995): + + $$ + \begin{aligned} + \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t/2) - 1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}). + \\ + \hat{u}_h^{**} &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - 1}{\hat{\mathcal{L}}_h} \odot \left( 2 \hat{\mathcal{N}}_h(\hat{u}_h^*) - \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right). + \\ + \hat{u}_h^{[t+1]} &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{[t]} + \\ + &+ \frac{-4 - \exp(\hat{\mathcal{L}}_h \Delta t) + \exp(\hat{\mathcal{L}}_h \Delta) \left( 4 - 3 \hat{\mathcal{L}}_h \Delta t + \left(\hat{\mathcal{L}}_h \Delta t\right)^2 \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}). + \\ + &+ 4 \frac{2 + \hat{\mathcal{L}}_h \Delta t + \exp(\hat{\mathcal{L}}_h \Delta t) \left( -2 + \hat{\mathcal{L}}_h \Delta t \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \hat{\mathcal{N}}_h(\hat{u}_h^*) + \\ + &+ \frac{-4 - 3 \hat{\mathcal{L}}_h \Delta t - \left( \hat{\mathcal{L}}_h \Delta t \right)^2 + \exp(\hat{\mathcal{L}}_h \Delta t) \left( 4 - \hat{\mathcal{L}}_h \Delta t \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{**}) + \end{aligned} + $$ + + where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of + the nonlinear differential operator. + + **Arguments:** + + - `dt`: The time step size. + - `linear_operator`: The linear operator of the PDE. Must have a leading + channel axis, followed by one, two or three spatial axes whereas the + last axis must be of size `(N//2)+1` where `N` is the number of + dimensions in the former spatial axes. + - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the + nonlinear differential operator. ! The operator must take care of + proper dealiasing. + - `num_circle_points`: The number of points on the unit circle used to + approximate the numerically challenging coefficients. + - `circle_radius`: The radius of the circle used to approximate the + numerically challenging coefficients. + + !!! warning + The nonlinear function must take care of proper dealiasing. + + !!! note + The numerically stable evaluation of the coefficients follows + [Kassam and Trefethen + (2005)](https://doi.org/10.1137/S1064827502410633). + """ super().__init__(dt, linear_operator) self._nonlinear_fun = nonlinear_fun self._half_exp_term = jnp.exp(0.5 * dt * linear_operator) diff --git a/exponax/etdrk/_etdrk_4.py b/exponax/etdrk/_etdrk_4.py index a5c7336..04b16ae 100644 --- a/exponax/etdrk/_etdrk_4.py +++ b/exponax/etdrk/_etdrk_4.py @@ -25,6 +25,57 @@ def __init__( num_circle_points: int = 16, circle_radius: float = 1.0, ): + r""" + Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta + with a **fourth order approximation**. + + Adapted from Eq. (26-29) of [Cox and Matthews + (2002)](https://doi.org/10.1006/jcph.2002.6995): + + $$ + \begin{aligned} + \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t/2) - 1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}). + \\ + \hat{u}_h^{**} &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t / 2) - 1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^*). + \\ + \hat{u}_h^{***} &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{*} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t/2) - 1}{\hat{\mathcal{L}}_h} \odot \left( 2 \hat{\mathcal{N}}_h(\hat{u}_h^{**}) - \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right). + \\ + \hat{u}_h^{[t+1]} &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{[t]} + \\ + &+ \frac{-4 - \hat{\mathcal{L}}_h \Delta t + \exp(\hat{\mathcal{L}}_h \Delta t) \left( 4 - 3 \hat{\mathcal{L}}_h \Delta t + \left(\hat{\mathcal{L}}_h \Delta t\right)^2 \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) + \\ + &+ 2 \frac{2 + \hat{\mathcal{L}}_h \Delta t + \exp(\hat{\mathcal{L}}_h \Delta t) \left( -2 + \hat{\mathcal{L}}_h \Delta t \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \left( \hat{\mathcal{N}}_h(\hat{u}_h^*) + \hat{\mathcal{N}}_h(\hat{u}_h^{**}) \right) + \\ + &+ \frac{-4 - 3 \hat{\mathcal{L}}_h \Delta t - \left( \hat{\mathcal{L}}_h \Delta t \right)^2 + \exp(\hat{\mathcal{L}}_h \Delta t) \left( 4 - \hat{\mathcal{L}}_h \Delta t \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{***}) + \end{aligned} + $$ + + where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of + the nonlinear differential operator. + + **Arguments:** + + - `dt`: The time step size. + - `linear_operator`: The linear operator of the PDE. Must have a leading + channel axis, followed by one, two or three spatial axes whereas the + last axis must be of size `(N//2)+1` where `N` is the number of + dimensions in the former spatial axes. + - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the + nonlinear differential operator. ! The operator must take care of + proper dealiasing. + - `num_circle_points`: The number of points on the unit circle used to + approximate the numerically challenging coefficients. + - `circle_radius`: The radius of the circle used to approximate the + numerically challenging coefficients. + + !!! warning + The nonlinear function must take care of proper dealiasing. + + !!! note + The numerically stable evaluation of the coefficients follows + [Kassam and Trefethen + (2005)](https://doi.org/10.1137/S1064827502410633). + """ super().__init__(dt, linear_operator) self._nonlinear_fun = nonlinear_fun self._half_exp_term = jnp.exp(0.5 * dt * linear_operator) diff --git a/exponax/etdrk/_utils.py b/exponax/etdrk/_utils.py index 909bf6a..ed0abb8 100644 --- a/exponax/etdrk/_utils.py +++ b/exponax/etdrk/_utils.py @@ -8,7 +8,16 @@ def roots_of_unity(M: int) -> Complex[Array, "M"]: """ - Return (complex-valued) array with M roots of unity. + Return (complex-valued) array with M roots of unity. Useful to perform + contour integrals in the complex plane. + + **Arguments:** + + - `M`: The number of roots of unity. + + **Returns:** + + - `roots`: The M roots of unity in an array of shape `(M,)`. """ # return jnp.exp(1j * jnp.pi * (jnp.arange(1, M+1) - 0.5) / M) return jnp.exp(2j * jnp.pi * (jnp.arange(1, M + 1) - 0.5) / M) diff --git a/exponax/ic/_base_ic.py b/exponax/ic/_base_ic.py index e0461a9..926f38d 100644 --- a/exponax/ic/_base_ic.py +++ b/exponax/ic/_base_ic.py @@ -13,10 +13,12 @@ def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]: Evaluate the initial condition. **Arguments**: - - `x`: The grid points. + + - `x`: The grid points. **Returns**: - - `u`: The initial condition evaluated at the grid points. + + - `u`: The initial condition evaluated at the grid points. """ pass @@ -30,11 +32,13 @@ def gen_ic_fun(self, *, key: PRNGKeyArray) -> BaseIC: Generate an initial condition function. **Arguments**: - - `key`: A jax random key. + + - `key`: A jax random key. **Returns**: - - `ic`: An initial condition function that can be evaluated at - degree of freedom locations. + + - `ic`: An initial condition function that can be evaluated at + degree of freedom locations. """ raise NotImplementedError( "This random ic generator cannot represent its initial condition as a function. Directly evaluate it." @@ -47,15 +51,16 @@ def __call__( key: PRNGKeyArray, ) -> Float[Array, "1 ... N"]: """ - Generate a random initial condition. + Generate a random initial condition on a grid with `num_points` points. **Arguments**: - - `num_points`: The number of grid points in each dimension. - - `key`: A jax random key. - - `indexing`: The indexing convention for the grid. + + - `num_points`: The number of grid points in each dimension. + - `key`: A jax random key. **Returns**: - - `u`: The initial condition evaluated at the grid points. + + - `u`: The initial condition evaluated at the grid points. """ ic_fun = self.gen_ic_fun(key=key) grid = make_grid( diff --git a/exponax/ic/_clamping.py b/exponax/ic/_clamping.py index 236ec6e..41beb9e 100644 --- a/exponax/ic/_clamping.py +++ b/exponax/ic/_clamping.py @@ -15,9 +15,13 @@ def __init__( A generator based on another generator that clamps the output to a given range. + Some dynamics (like the Fisher-KPP equation) require such initial + conditions. + **Arguments**: - - `ic_gen`: The initial condition generator to clamp. - - `limits`: The lower and upper limits of the clamping range. + + - `ic_gen`: The initial condition generator to clamp. + - `limits`: The lower and upper limits of the clamping range. """ self.ic_gen = ic_gen self.limits = limits diff --git a/exponax/ic/_diffused_noise.py b/exponax/ic/_diffused_noise.py index 7e1c466..710261b 100644 --- a/exponax/ic/_diffused_noise.py +++ b/exponax/ic/_diffused_noise.py @@ -31,20 +31,20 @@ def __init__( The original noise is drawn in state space with a uniform normal distribution. After the application of the diffusion operator, the - spectrum decays exponentially with a rate of `intensity`. + spectrum decays exponentially quadratic with a rate of `intensity`. **Arguments**: - - `num_spatial_dims`: The number of spatial dimensions `d`. - - `domain_extent`: The extent of the domain. Defaults to `1.0`. This - indirectly affects the intensity of the noise. It is best to - keep it at `1.0` and just adjust the `intensity` instead. - - `intensity`: The intensity of the noise. Defaults to `0.001`. - - `zero_mean`: Whether to zero the mean of the noise. Defaults to - `True`. - - `std_one`: Whether to normalize the noise to have a standard - deviation of one. Defaults to `False`. - - `max_one`: Whether to normalize the noise to the maximum absolute - value of one. Defaults to `False`. + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `domain_extent`: The extent of the domain. Defaults to `1.0`. This + indirectly affects the intensity of the noise. It is best to keep it + at `1.0` and just adjust the `intensity` instead. + - `intensity`: The intensity of the noise. Defaults to `0.001`. + - `zero_mean`: Whether to zero the mean of the noise. + - `std_one`: Whether to normalize the noise to have a standard + deviation of one. Defaults to `False`. + - `max_one`: Whether to normalize the noise to the maximum absolute + value of one. Defaults to `False`. """ if not zero_mean and std_one: raise ValueError("Cannot have `zero_mean=False` and `std_one=True`.") diff --git a/exponax/ic/_discontinuities.py b/exponax/ic/_discontinuities.py index 9af30e6..cc75bd1 100644 --- a/exponax/ic/_discontinuities.py +++ b/exponax/ic/_discontinuities.py @@ -42,14 +42,15 @@ def __init__( A state described by a collection of discontinuities. **Arguments**: - - `discontinuity_list`: A tuple of discontinuities. - - `zero_mean`: Whether the state should have zero mean. - - `std_one`: Whether to normalize the state to have a standard - deviation of one. Defaults to `False`. Only works if the offset - is zero. - - `max_one`: Whether to normalize the state to have the maximum - absolute value of one. Defaults to `False`. Only one of - `std_one` and `max_one` can be `True`. + + - `discontinuity_list`: A tuple of discontinuities. + - `zero_mean`: Whether the state should have zero mean. + - `std_one`: Whether to normalize the state to have a standard + deviation of one. Defaults to `False`. Only works if the offset is + zero. + - `max_one`: Whether to normalize the state to have the maximum + absolute value of one. Defaults to `False`. Only one of `std_one` + and `max_one` can be `True`. """ if not zero_mean and std_one: raise ValueError("Cannot have `zero_mean=False` and `std_one=True`.") @@ -102,17 +103,18 @@ def __init__( discontinuities. **Arguments**: - - `num_spatial_dims`: The number of spatial dimensions. - - `domain_extent`: The extent of the domain in each spatial direction. - - `num_discontinuities`: The number of discontinuities. - - `value_range`: The range of values for the discontinuities. - - `zero_mean`: Whether the state should have zero mean. - - `std_one`: Whether to normalize the state to have a standard - deviation of one. Defaults to `False`. Only works if the offset - is zero. - - `max_one`: Whether to normalize the state to have the maximum - absolute value of one. Defaults to `False`. Only one of - `std_one` and `max_one` can be `True`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The extent of the domain in each spatial direction. + - `num_discontinuities`: The number of discontinuities. + - `value_range`: The range of values for the discontinuities. + - `zero_mean`: Whether the state should have zero mean. + - `std_one`: Whether to normalize the state to have a standard + deviation of one. Defaults to `False`. Only works if the offset is + zero. + - `max_one`: Whether to normalize the state to have the maximum + absolute value of one. Defaults to `False`. Only one of `std_one` + and `max_one` can be `True`. """ if not zero_mean and std_one: raise ValueError("Cannot have `zero_mean=False` and `std_one=True`.") @@ -129,6 +131,9 @@ def __init__( self.max_one = max_one def gen_one_ic_fn(self, *, key: PRNGKeyArray) -> Discontinuity: + """ + Generates a single discontinuity. + """ lower_limits = [] upper_limits = [] for i in range(self.num_spatial_dims): diff --git a/exponax/ic/_gaussian_blob.py b/exponax/ic/_gaussian_blob.py index e82b703..65911c4 100644 --- a/exponax/ic/_gaussian_blob.py +++ b/exponax/ic/_gaussian_blob.py @@ -24,12 +24,15 @@ def __init__( one_complement: bool = False, ): """ - A state described by a Gaussian blob. + A state described by a Gaussian blob. Note that the produced function is + not perfectly periodic, especially if the blobs are close to the domain + boundaries. **Arguments**: - - `position`: The position of the blob. - - `covariance`: The covariance matrix of the blob. - - `one_complement`: Whether to return one minus the Gaussian blob. + + - `position`: The position of the blob. + - `covariance`: The covariance matrix of the blob. + - `one_complement`: Whether to return one minus the Gaussian blob. """ self.position = position self.covariance = covariance @@ -78,7 +81,8 @@ def __init__( A state described by a collection of Gaussian blobs. **Arguments**: - - `blob_list`: A tuple of Gaussian blobs. + + - `blob_list`: A tuple of Gaussian blobs. """ self.blob_list = blob_list @@ -111,16 +115,17 @@ def __init__( A random Gaussian blob initial condition generator. **Arguments**: - - `num_spatial_dims`: The number of spatial dimensions. - - `domain_extent`: The extent of the domain. - - `num_blobs`: The number of blobs. - - `position_range`: The range of the position of the blobs. This - will be scaled by the domain extent. Hence, this acts as if the - domain_extent was 1 - - `variance_range`: The range of the variance of the blobs. This will - be scaled by the domain extent. Hence, this acts as if the - domain_extent was 1 - - `one_complement`: Whether to return one minus the Gaussian blob. + + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The extent of the domain. + - `num_blobs`: The number of blobs. + - `position_range`: The range of the position of the blobs. This + will be scaled by the domain extent. Hence, this acts as if the + domain_extent was 1 + - `variance_range`: The range of the variance of the blobs. This will + be scaled by the domain extent. Hence, this acts as if the + domain_extent was 1 + - `one_complement`: Whether to return one minus the Gaussian blob. """ self.num_spatial_dims = num_spatial_dims self.domain_extent = domain_extent @@ -130,6 +135,9 @@ def __init__( self.one_complement = one_complement def gen_blob(self, *, key) -> GaussianBlob: + """ + Generates a single Gaussian blob. + """ position_key, variance_key = jr.split(key) position = jr.uniform( diff --git a/exponax/ic/_gaussian_random_field.py b/exponax/ic/_gaussian_random_field.py index 6fd1435..f3f8ba6 100644 --- a/exponax/ic/_gaussian_random_field.py +++ b/exponax/ic/_gaussian_random_field.py @@ -31,19 +31,20 @@ def __init__( ): """ Random generator for initial states following a power-law spectrum in - Fourier space. + Fourier space, i.e., it decays polynomially with the wavenumber. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `domain_extent`: The extent of the domain in each spatial direction. - - `powerlaw_exponent`: The exponent of the power-law spectrum. - - `zero_mean`: Whether the field should have zero mean. - - `std_one`: Whether to normalize the state to have a standard - deviation of one. Defaults to `False`. Only works if the offset - is zero. - - `max_one`: Whether to normalize the state to have the maximum - absolute value of one. Defaults to `False`. Only one of - `std_one` and `max_one` can be `True`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The extent of the domain in each spatial direction. + - `powerlaw_exponent`: The exponent of the power-law spectrum. + - `zero_mean`: Whether the field should have zero mean. + - `std_one`: Whether to normalize the state to have a standard + deviation of one. Defaults to `False`. Only works if the offset is + zero. + - `max_one`: Whether to normalize the state to have the maximum + absolute value of one. Defaults to `False`. Only one of `std_one` + and `max_one` can be `True`. """ if not zero_mean and std_one: raise ValueError("Cannot have `zero_mean=False` and `std_one=True`.") diff --git a/exponax/ic/_multi_channel.py b/exponax/ic/_multi_channel.py index 09c4a97..8084dea 100644 --- a/exponax/ic/_multi_channel.py +++ b/exponax/ic/_multi_channel.py @@ -14,7 +14,8 @@ def __init__(self, initial_conditions: tuple[BaseIC, ...]): A multi-channel initial condition. **Arguments**: - - `initial_conditions`: A tuple of initial conditions. + + - `initial_conditions`: A tuple of initial conditions. """ self.initial_conditions = initial_conditions @@ -23,10 +24,12 @@ def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "C ... N"]: Evaluate the initial condition. **Arguments**: - - `x`: The grid points. + + - `x`: The grid points. **Returns**: - - `u`: The initial condition evaluated at the grid points. + + - `u`: The initial condition evaluated at the grid points. """ return jnp.concatenate([ic(x) for ic in self.initial_conditions], axis=0) @@ -36,10 +39,34 @@ class RandomMultiChannelICGenerator(eqx.Module): def __init__(self, ic_generators: tuple[BaseRandomICGenerator, ...]): """ - A multi-channel random initial condition generator. + A multi-channel random initial condition generator. Use this for + problems with multiple channels, like Burgers in higher dimensions or + the Gray-Scott dynamics. **Arguments**: - - `ic_generators`: A tuple of initial condition generators. + + - `ic_generators`: A tuple of initial condition generators. + + !!! example + Below is an example for generating a random multi-channel initial + condition for the three-dimensional Burgers equation which has three + channels. For simplicity, we will use the same IC generator for each + channel. + + ```python + import jax + import exponax as ex + + single_channel_ic_gen = ex.ic.RandomTruncatedFourierSeries( + 3, + max_one=True, + ) + multi_channel_ic_gen = ex.ic.RandomMultiChannelICGenerator( + [single_channel_ic_gen,] * 3 + ) + + ic = multi_channel_ic_gen(100, key=jax.random.PRNGKey(0)) + ``` """ self.ic_generators = ic_generators diff --git a/exponax/ic/_scaled.py b/exponax/ic/_scaled.py index 6ecdbec..fde2adf 100644 --- a/exponax/ic/_scaled.py +++ b/exponax/ic/_scaled.py @@ -23,8 +23,9 @@ def __init__(self, ic_gen: BaseRandomICGenerator, scale: float): `max_one=True` or `std_one=True`. **Arguments**: - - `ic_gen`: The initial condition generator. - - `scale`: The scaling factor. + + - `ic_gen`: The initial condition generator. + - `scale`: The scaling factor. """ self.ic_gen = ic_gen self.scale = scale diff --git a/exponax/ic/_sine_waves_1d.py b/exponax/ic/_sine_waves_1d.py index 6cd8f14..de96370 100644 --- a/exponax/ic/_sine_waves_1d.py +++ b/exponax/ic/_sine_waves_1d.py @@ -29,17 +29,18 @@ def __init__( A state described by a collection of sine waves. Only works in 1d. **Arguments**: - - `domain_extent`: The extent of the domain. - - `amplitudes`: A tuple of amplitudes. - - `wavenumbers`: A tuple of wavenumbers. - - `phases`: A tuple of phases. - - `offset`: A constant offset. - - `std_one`: Whether to normalize the state to have a standard - deviation of one. Defaults to `False`. Only works if the offset - is zero. - - `max_one`: Whether to normalize the state to have the maximum - absolute value of one. Defaults to `False`. Only one of - `std_one` and `max_one` can be `True`. + + - `domain_extent`: The extent of the domain. + - `amplitudes`: A tuple of amplitudes. + - `wavenumbers`: A tuple of wavenumbers. + - `phases`: A tuple of phases. + - `offset`: A constant offset. + - `std_one`: Whether to normalize the state to have a standard + deviation of one. Defaults to `False`. Only works if the offset + is zero. + - `max_one`: Whether to normalize the state to have the maximum + absolute value of one. Defaults to `False`. Only one of + `std_one` and `max_one` can be `True`. """ if offset != 0.0 and std_one: raise ValueError("Cannot have non-zero offset and `std_one=True`.") @@ -103,23 +104,29 @@ def __init__( Random generator for initial states described by a collection of sine waves. Only works in 1d. + This is a simplified version of the `RandomTruncatedFourierSeries` + generator that works in arbitrary dimensions. However, only this + generator can produce a functional representation of the initial + condition. + **Arguments**: - - `num_spatial_dims`: The number of spatial dimensions. - - `domain_extent`: The extent of the domain. - - `cutoff`: The cutoff of the wavenumbers. This limits the - "complexity" of the initial state. Note that some dynamics are - very sensitive to high-frequency information. - - `amplitude_range`: The range of the amplitudes. Defaults to - `(-1.0, 1.0)`. - - `phase_range`: The range of the phases. Defaults to `(0.0, 2π)`. - - `offset_range`: The range of the offsets. Defaults to `(0.0, - 0.0)`, meaning **zero-mean** by default. - - `std_one`: Whether to normalize the state to have a standard - deviation of one. Defaults to `False`. Only works if the offset - is zero. - - `max_one`: Whether to normalize the state to have the maximum - absolute value of one. Defaults to `False`. Only one of - `std_one` and `max_one` can be `True`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The extent of the domain. + - `cutoff`: The cutoff of the wavenumbers. This limits the + "complexity" of the initial state. Note that some dynamics are very + sensitive to high-frequency information. + - `amplitude_range`: The range of the amplitudes. Defaults to + `(-1.0, 1.0)`. + - `phase_range`: The range of the phases. Defaults to `(0.0, 2π)`. + - `offset_range`: The range of the offsets. Defaults to `(0.0, + 0.0)`, meaning **zero-mean** by default. + - `std_one`: Whether to normalize the state to have a standard + deviation of one. Defaults to `False`. Only works if the offset is + zero. + - `max_one`: Whether to normalize the state to have the maximum + absolute value of one. Defaults to `False`. Only one of `std_one` + and `max_one` can be `True`. """ if num_spatial_dims != 1: raise ValueError("RandomSineWaves1d only works in 1d.") diff --git a/exponax/ic/_truncated_fourier_series.py b/exponax/ic/_truncated_fourier_series.py index c200a2f..ee4aad0 100644 --- a/exponax/ic/_truncated_fourier_series.py +++ b/exponax/ic/_truncated_fourier_series.py @@ -59,22 +59,27 @@ def __init__( in the range `amplitude_range`. Angles (=angular offsets) are drawn according to a uniform distribution in the range `angle_range`. + See also `exponax.ic.RandomSineWaves1d` for a simplified version that + only works in 1d but can also produce a functional representation of the + initial state. + **Arguments**: - - `num_spatial_dims`: The number of spatial dimensions `d`. - - `cutoff`: The cutoff of the wavenumbers. This limits the - "complexity" of the initial state. Note that some dynamics are - very sensitive to high-frequency information. - - `amplitude_range`: The range of the amplitudes. Defaults to - `(-1.0, 1.0)`. - - `angle_range`: The range of the angles. Defaults to `(0.0, 2π)`. - - `offset_range`: The range of the offsets. Defaults to `(0.0, - 0.0)`, meaning **zero-mean** by default. - - `std_one`: Whether to normalize the state to have a standard - deviation of one. Defaults to `False`. Only works if the offset - is zero. - - `max_one`: Whether to normalize the state to have the maximum - absolute value of one. Defaults to `False`. Only one of - `std_one` and `max_one` can be `True`. + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `cutoff`: The cutoff of the wavenumbers. This limits the + "complexity" of the initial state. Note that some dynamics are very + sensitive to high-frequency information. + - `amplitude_range`: The range of the amplitudes. Defaults to + `(-1.0, 1.0)`. + - `angle_range`: The range of the angles. Defaults to `(0.0, 2π)`. + - `offset_range`: The range of the offsets. Defaults to `(0.0, + 0.0)`, meaning **zero-mean** by default. + - `std_one`: Whether to normalize the state to have a standard + deviation of one. Defaults to `False`. Only works if the offset is + zero. + - `max_one`: Whether to normalize the state to have the maximum + absolute value of one. Defaults to `False`. Only one of `std_one` + and `max_one` can be `True`. """ if offset_range == (0.0, 0.0) and std_one: raise ValueError("Cannot have non-zero offset and `std_one=True`.") diff --git a/exponax/nonlin_fun/_base.py b/exponax/nonlin_fun/_base.py index d04d992..8770b9b 100644 --- a/exponax/nonlin_fun/_base.py +++ b/exponax/nonlin_fun/_base.py @@ -19,6 +19,31 @@ def __init__( *, dealiasing_fraction: Optional[float] = None, ): + """ + Base class for all nonlinear functions. This class provides the basic + functionality to dealias the nonlinear terms and perform forward and + inverse Fourier transforms. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `D`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. + - `dealiasing_fraction`: The fraction of the highest resolved mode to + keep for dealiasing. For example, `2/3` corresponds to Orszag's 2/3 + rule typically used for quadratic nonlinearities. If `None`, no + dealiasing is performed. + + !!! info + Some dealiasing strategies (like Orszag's 2/3 rule) are designed to + not fully remove aliasing (which would require 1/2 in the case of + quadratic nonlinearities), rather to only have aliases being created + in those modes that will be zeroed out anyway in the next + dealiasing step. See also [Orszag + (1971)](https://doi.org/10.1175/1520-0469(1971)028%3C1074:OTEOAI%3E2.0.CO;2) + """ self.num_spatial_dims = num_spatial_dims self.num_points = num_points @@ -39,14 +64,52 @@ def __init__( def dealias( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Dealias the Fourier representation of a state `u_hat` by zeroing out all + the coefficients associated with modes beyond `dealiasing_fraction` set + in the constructor. + + **Arguments:** + + - `u_hat`: The Fourier representation of the state `u`. + + **Returns:** + + - `u_hat_dealiased`: The dealiased Fourier representation of the state + `u`. + """ if self.dealiasing_mask is None: raise ValueError("Nonlinear function was set up without dealiasing") return self.dealiasing_mask * u_hat def fft(self, u: Float[Array, "C ... N"]) -> Complex[Array, "C ... (N//2)+1"]: + """ + Correctly wrapped **real-valued** Fourier transform for the shape of the + state vector associated with this nonlinear function. + + **Arguments:** + + - `u`: The state vector in real space. + + **Returns:** + + - `u_hat`: The (real-valued) Fourier transform of the state vector. + """ return fft(u, num_spatial_dims=self.num_spatial_dims) def ifft(self, u_hat: Complex[Array, "C ... (N//2)+1"]) -> Float[Array, "C ... N"]: + """ + Correctly wrapped **real-valued** inverse Fourier transform for the shape + of the state vector associated with this nonlinear function. + + **Arguments:** + + - `u_hat`: The (real-valued) Fourier transform of the state vector. + + **Returns:** + + - `u`: The state vector in real space. + """ return ifft( u_hat, num_spatial_dims=self.num_spatial_dims, num_points=self.num_points ) @@ -57,6 +120,18 @@ def __call__( u_hat: Complex[Array, "C ... (N//2)+1"], ) -> Complex[Array, "C ... (N//2)+1"]: """ - Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing. + Evaluates the nonlinear function with a pseudo-spectral treatment and + accounts for dealiasing. + + Use this in combination with `exponax.etdrk` routines to solve + semi-linear PDEs in Fourier space. + + **Arguments:** + + - `u_hat`: The Fourier representation of the state `u`. + + **Returns:** + + - `𝒩(u_hat)`: The Fourier representation of the nonlinear term. """ pass diff --git a/exponax/nonlin_fun/_convection.py b/exponax/nonlin_fun/_convection.py index 0e30f9a..ea594ce 100644 --- a/exponax/nonlin_fun/_convection.py +++ b/exponax/nonlin_fun/_convection.py @@ -24,14 +24,20 @@ def __init__( found in the Burgers equation. In 1d and state space, this reads ``` - 𝒩(u) = b₁ 1/2 (u²)ₓ + 𝒩(u) = - b₁ 1/2 (u²)ₓ ``` - with a scale `b₁`. The typical extension to higher dimensions requires u - to have as many channels as spatial dimensions and then gives + with a scale `b₁`. The minus arises because `Exponax` follows the + convention that all nonlinear and linear differential operators are on + the right-hand side of the equation. Typically, the convection term is + on the left-hand side. Hence, the minus is required to move the term to + the right-hand side. + + The typical extension to higher dimensions requires u to have as many + channels as spatial dimensions and then gives ``` - 𝒩(u) = b₁ 1/2 ∇ ⋅ (u ⊗ u) + 𝒩(u) = - b₁ 1/2 ∇ ⋅ (u ⊗ u) ``` with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`. @@ -39,23 +45,24 @@ def __init__( matter the spatial dimensions. This reads ``` - 𝒩(u) = b₁ 1/2 (1⃗ ⋅ ∇)(u²) + 𝒩(u) = - b₁ 1/2 (1⃗ ⋅ ∇)(u²) ``` **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions `d`. - - `num_points`: The number of points `N` used to discretize the - domain. This **includes** the left boundary point and - **excludes** the right boundary point. In higher dimensions; the - number of points in each dimension is the same. - - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)` - that represents the derivative operator in Fourier space. - - `dealiasing_fraction`: The fraction of the highest resolved modes - that are not aliased. Defaults to `2/3` which corresponds to - Orszag's 2/3 rule. - - `scale`: The scale `b₁` of the convection term. Defaults to `1.0`. - - `single_channel`: Whether to use the single-channel hack. Defaults - to `False`. + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. + - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)` + that represents the derivative operator in Fourier space. + - `dealiasing_fraction`: The fraction of the highest resolved modes + that are not aliased. Defaults to `2/3` which corresponds to + Orszag's 2/3 rule. + - `scale`: The scale `b₁` of the convection term. Defaults to `1.0`. + - `single_channel`: Whether to use the single-channel hack. Defaults + to `False`. """ self.derivative_operator = derivative_operator self.scale = scale @@ -69,6 +76,24 @@ def __init__( def _multi_channel_eval( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Evaluates the convection term for a multi-channel state `u_hat` in + Fourier space. The convection term is given by + + ``` + 𝒩(u) = b₁ 1/2 ∇ ⋅ (u ⊗ u) + ``` + + with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`. + + **Arguments:** + + - `u_hat`: The state in Fourier space. + + **Returns:** + + - `convection`: The evaluation of the convection term in Fourier space. + """ num_channels = u_hat.shape[0] if num_channels != self.num_spatial_dims: raise ValueError( @@ -88,6 +113,24 @@ def _multi_channel_eval( def _single_channel_eval( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Evaluates the convection term for a single-channel state `u_hat` in + Fourier space. The convection term is given by + + ``` + 𝒩(u) = b₁ 1/2 (1⃗ ⋅ ∇)(u²) + ``` + + with `∇ ⋅` the divergence operator and `1⃗` a vector of ones. + + **Arguments:** + + - `u_hat`: The state in Fourier space. + + **Returns:** + + - `convection`: The evaluation of the convection term in Fourier space. + """ u_hat_dealiased = self.dealias(u_hat) u = self.ifft(u_hat_dealiased) u_square = u**2 diff --git a/exponax/nonlin_fun/_general_nonlinear.py b/exponax/nonlin_fun/_general_nonlinear.py index 00c848d..c266572 100644 --- a/exponax/nonlin_fun/_general_nonlinear.py +++ b/exponax/nonlin_fun/_general_nonlinear.py @@ -22,9 +22,55 @@ def __init__( zero_mode_fix: bool = True, ): """ - Uses an additional scaling of 0.5 on the latter two components only + Fourier pseudo-spectral evaluation of a nonlinear differential operator + that has a square, convection (with single-channel hack), and gradient + norm term. In 1D and state space, this reads - By default: Burgers equation + ``` + 𝒩(u) = b₀ u² + b₁ 1/2 (u²)ₓ + b₂ 1/2 (uₓ)² + ``` + + The higher-dimensional extension is designed for a single-channel state + `u` (i.e., the number of channels do not grow with the number of spatial + dimensions, see also the description of + `exponax.nonlin_fun.ConvectionNonlinearFun`). The extension reads + + ``` + 𝒩(u) = b₀ u² + b₁ 1/2 (1⃗ ⋅ ∇)(u²) + b₂ 1/2 ‖∇u‖₂² + ``` + + !!! warning + In contrast to the individual nonlinear functions + `exponax.nonlin_fun.ConvectionNonlinearFun` and + `exponax.nonlin_fun.GradientNormNonlinearFun`, there is no minus. + Hence, to have a "propoper" convection term, consider supplying a + negative scale for the convection term, etc. + + **Arguments**: + + - `num_spatial_dims`: The number of spatial dimensions `D`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. + - `derivative_operator`: A complex array of shape `(D, ..., N//2+1)` + that represents the derivative operator in Fourier space. + - `dealiasing_fraction`: The fraction of the highest resolved modes that + are not aliased. Defaults to `2/3` which corresponds to Orszag's 2/3 + rule. + - `scale_list`: A tuple of three floats `[b₀, b₁, b₂]` that represent + the scales of the square, (single-channel) convection, and gradient + norm term, respectively. Defaults to `[0.0, -1.0, 0.0]` which + corresponds to a pure convection term (i.e, in 1D together with a + diffusion linear term, this would be the Burgers equation). !!! + important: note that negation has to be manually provided! + - `zero_mode_fix`: Whether to set the zero mode to zero. In other words, + whether to have mean zero energy after nonlinear function activation. + This exists because the nonlinear operation happens after the + derivative operator is applied. Naturally, the derivative sets any + constant offset to zero. However, the square nonlinearity introduces + again a new constant offset. Setting this argument to `True` removes + this offset. Defaults to `True`. """ if len(scale_list) != 3: raise ValueError("The scale list must have exactly 3 elements") diff --git a/exponax/nonlin_fun/_gradient_norm.py b/exponax/nonlin_fun/_gradient_norm.py index df946c0..53f1fdd 100644 --- a/exponax/nonlin_fun/_gradient_norm.py +++ b/exponax/nonlin_fun/_gradient_norm.py @@ -26,38 +26,45 @@ def __init__( In 1d and state space, this reads ``` - 𝒩(u) = b₂ 1/2 (u²)ₓ + 𝒩(u) = - b₂ 1/2 (uₓ)² ``` - with a scale `b₂`. In higher dimensions, u has to be single channel and - the nonlinear function reads + with a scale `b₂`. The minus arises because `Exponax` follows the + convention that all nonlinear and linear differential operators are on + the right-hand side of the equation. Typically, the gradient norm term + is on the left-hand side. Hence, the minus is required to move the term + to the right-hand side. + + In higher dimensions, u has to be single channel and the nonlinear + function reads ``` - 𝒩(u) = b₂ 1/2 ‖∇u‖₂² + 𝒩(u) = - b₂ 1/2 ‖∇u‖₂² ``` with `‖∇u‖₂²` the squared L2 norm of the gradient of `u`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions `d`. - - `num_points`: The number of points `N` used to discretize the - domain. This **includes** the left boundary point and - **excludes** the right boundary point. In higher dimensions; the - number of points in each dimension is the same. - - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)` - that represents the derivative operator in Fourier space. - - `dealiasing_fraction`: The fraction of the highest resolved modes - that are not aliased. Defaults to `2/3` which corresponds to - Orszag's 2/3 rule. - - `zero_mode_fix`: Whether to set the zero mode to zero. In other - words, whether to have mean zero energy after nonlinear function - activation. This exists because the nonlinear operation happens - after the derivative operator is applied. Naturally, the - derivative sets any constant offset to zero. However, the square - nonlinearity introduces again a new constant offset. Setting - this argument to `True` removes this offset. Defaults to `True`. - - `scale`: The scale `b₂` of the gradient norm term. Defaults to - `1.0`. + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. + - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)` + that represents the derivative operator in Fourier space. + - `dealiasing_fraction`: The fraction of the highest resolved modes + that are not aliased. Defaults to `2/3` which corresponds to + Orszag's 2/3 rule. + - `zero_mode_fix`: Whether to set the zero mode to zero. In other + words, whether to have mean zero energy after nonlinear function + activation. This exists because the nonlinear operation happens + after the derivative operator is applied. Naturally, the derivative + sets any constant offset to zero. However, the square nonlinearity + introduces again a new constant offset. Setting this argument to + `True` removes this offset. Defaults to `True`. + - `scale`: The scale `b₂` of the gradient norm term. Defaults to + `1.0`. """ super().__init__( num_spatial_dims, diff --git a/exponax/nonlin_fun/_vorticity_convection.py b/exponax/nonlin_fun/_vorticity_convection.py index 46231ab..558f56f 100644 --- a/exponax/nonlin_fun/_vorticity_convection.py +++ b/exponax/nonlin_fun/_vorticity_convection.py @@ -25,25 +25,35 @@ def __init__( streamfunction-vorticity formulation. In state space, it reads ``` - 𝒩(ω) = b ([1, -1]ᵀ ⊙ ∇(Δ⁻¹u)) ⋅ ∇u + 𝒩(u) = - b ([1, -1]ᵀ ⊙ ∇(Δ⁻¹u)) ⋅ ∇u ``` with `b` the convection scale, `⊙` the Hadamard product, `∇` the derivative operator, `Δ⁻¹` the inverse Laplacian, and `u` the vorticity. + The minus arises because `Exponax` follows the convention that all + nonlinear and linear differential operators are on the right-hand side + of the equation. Typically, the vorticity convection term is on the + left-hand side. Hence, the minus is required to move the term to the + right-hand side. + + Since the inverse Laplacian is required, it internally performs a + Poisson solve which is straightforward in Fourier space. + **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions `d`. - - `num_points`: The number of points `N` used to discretize the - domain. This **includes** the left boundary point and **excludes** - the right boundary point. In higher dimensions; the number of - points in each dimension is the same. - - `convection_scale`: The scale `b` of the convection term. Defaults to - `1.0`. - - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)` that - represents the derivative operator in Fourier space. - - `dealiasing_fraction`: The fraction of the highest resolved modes that - are not aliased. Defaults to `2/3` which corresponds to Orszag's 2/3 - rule. + + - `num_spatial_dims`: The number of spatial dimensions `D`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. + - `convection_scale`: The scale `b` of the convection term. Defaults + to `1.0`. + - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)` + that represents the derivative operator in Fourier space. + - `dealiasing_fraction`: The fraction of the highest resolved modes + that are not aliased. Defaults to `2/3` which corresponds to + Orszag's 2/3 rule. """ if num_spatial_dims != 2: raise ValueError(f"Expected num_spatial_dims = 2, got {num_spatial_dims}.") @@ -60,7 +70,9 @@ def __init__( laplacian = build_laplace_operator(derivative_operator, order=2) # Uses the UNCHANGED mean solution to the Poisson equation (hence, the - # mean of the "right-hand side" will be the mean of the solution) + # mean of the "right-hand side" will be the mean of the solution). + # However, this does not matter because we subsequently take the + # gradient which would annihilate any mean energy anyway. self.inv_laplacian = jnp.where(laplacian == 0, 1.0, 1 / laplacian) def __call__( @@ -110,11 +122,12 @@ def __init__( In state space, it reads ``` - 𝒩(ω) = b ([1, -1]ᵀ ⊙ ∇(Δ⁻¹u)) ⋅ ∇u - f + 𝒩(u) = - b ([1, -1]ᵀ ⊙ ∇(Δ⁻¹u)) ⋅ ∇u - f ``` For details on the vorticity convective term, see - `VorticityConvection2d`. The forcing term has the form + `exponax.nonlin_fun.VorticityConvection2d`. The forcing term has the + form ``` f = -k (2π/L) γ cos(k (2π/L) x₁) @@ -126,22 +139,21 @@ def __init__( the vorticity is derived via the curl). **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions `d`. - - `num_points`: The number of points `N` used to discretize the - domain. This **includes** the left boundary point and - **excludes** the right boundary point. In higher dimensions; the - number of points in each dimension is the same. - - `convection_scale`: The scale `b` of the convection term. Defaults - to `1.0`. - - `injection_mode`: The wavenumber `k` at which energy is injected. - Defaults to `4`. - - `injection_scale`: The intensity `γ` of the injection term. - Defaults to `1.0`. - - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)` - that represents the derivative operator in Fourier space. - - `dealiasing_fraction`: The fraction of the highest resolved modes - that are not aliased. Defaults to `2/3` which corresponds to - Orszag's 2/3 rule. + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. + - `convection_scale`: The scale `b` of the convection term. Defaults + to `1.0`. + - `injection_mode`: The wavenumber `k` at which energy is injected. + - `injection_scale`: The intensity `γ` of the injection term. + - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)` + that represents the derivative operator in Fourier space. + - `dealiasing_fraction`: The fraction of the highest resolved modes + that are not aliased. Defaults to `2/3` which corresponds to + Orszag's 2/3 rule. """ super().__init__( num_spatial_dims, diff --git a/exponax/nonlin_fun/_zero.py b/exponax/nonlin_fun/_zero.py index f365d11..322146a 100644 --- a/exponax/nonlin_fun/_zero.py +++ b/exponax/nonlin_fun/_zero.py @@ -19,11 +19,12 @@ def __init__( ``` **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions `d`. - - `num_points`: The number of points `N` used to discretize the - domain. This **includes** the left boundary point and - **excludes** the right boundary point. In higher dimensions; the - number of points in each dimension is the same. + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. """ super().__init__( num_spatial_dims, diff --git a/exponax/stepper/_advection.py b/exponax/stepper/_advection.py index 34fc3c8..cbabf0a 100644 --- a/exponax/stepper/_advection.py +++ b/exponax/stepper/_advection.py @@ -61,7 +61,7 @@ def __init__( **Notes:** - - The stepper is unconditionally stable, not matter the choice of + - The stepper is unconditionally stable, no matter the choice of any argument because the equation is solved analytically in Fourier space. **However**, note that initial conditions with modes higher than the Nyquist freuency (`(N//2)+1` with `N` being the diff --git a/exponax/stepper/_advection_diffusion.py b/exponax/stepper/_advection_diffusion.py index 4a00b75..a965176 100644 --- a/exponax/stepper/_advection_diffusion.py +++ b/exponax/stepper/_advection_diffusion.py @@ -78,7 +78,7 @@ def __init__( **Notes:** - - The stepper is unconditionally stable, not matter the choice of + - The stepper is unconditionally stable, no matter the choice of any argument because the equation is solved analytically in Fourier space. **However**, note that initial conditions with modes higher than the Nyquist freuency (`(N//2)+1` with `N` being the diff --git a/exponax/stepper/_diffusion.py b/exponax/stepper/_diffusion.py index ccc1f85..403a66c 100644 --- a/exponax/stepper/_diffusion.py +++ b/exponax/stepper/_diffusion.py @@ -73,7 +73,7 @@ def __init__( **Notes:** - - The stepper is unconditionally stable, not matter the choice of + - The stepper is unconditionally stable, no matter the choice of any argument because the equation is solved analytically in Fourier space. - A `ν > 0` leads to stable and decaying solutions (i.e., energy is diff --git a/exponax/stepper/_dispersion.py b/exponax/stepper/_dispersion.py index b7d8347..a330af1 100644 --- a/exponax/stepper/_dispersion.py +++ b/exponax/stepper/_dispersion.py @@ -75,7 +75,7 @@ def __init__( **Notes:** - - The stepper is unconditionally stable, not matter the choice of + - The stepper is unconditionally stable, no matter the choice of any argument because the equation is solved analytically in Fourier space. **However**, note that initial conditions with modes higher than the Nyquist freuency (`(N//2)+1` with `N` being the diff --git a/exponax/stepper/_hyper_diffusion.py b/exponax/stepper/_hyper_diffusion.py index 73bac08..9c0623f 100644 --- a/exponax/stepper/_hyper_diffusion.py +++ b/exponax/stepper/_hyper_diffusion.py @@ -72,7 +72,7 @@ def __init__( **Notes:** - - The stepper is unconditionally stable, not matter the choice of + - The stepper is unconditionally stable, no matter the choice of any argument because the equation is solved analytically in Fourier space. - Ultimately, only the factor `μ Δt / L⁴` affects the characteristic diff --git a/exponax/stepper/_kuramoto_sivashinsky.py b/exponax/stepper/_kuramoto_sivashinsky.py index 88d7d4c..9799f29 100644 --- a/exponax/stepper/_kuramoto_sivashinsky.py +++ b/exponax/stepper/_kuramoto_sivashinsky.py @@ -31,7 +31,7 @@ def __init__( equation on periodic boundary conditions. Uses the **combustion format** (or non-conservative format). Most deep learning papers in 1d considered the conservative format available as - [`KuramotoSivashinskyConservative`](exponax/stepper/KuramotoSivashinskyConservative). + [`exponax.stepper.KuramotoSivashinskyConservative`][]. In 1d, the KS equation is given by diff --git a/exponax/stepper/generic/_convection.py b/exponax/stepper/generic/_convection.py index c5b5aca..00df9c2 100644 --- a/exponax/stepper/generic/_convection.py +++ b/exponax/stepper/generic/_convection.py @@ -56,47 +56,48 @@ def __init__( Alternatively, with `single_channel=True`, the number of channels can be kept to constant 1 no matter the number of spatial dimensions. - Depending on the collection of linear coefficients can be represented, - for example: + Depending on the collection of linear coefficients a range of dynamics + can be represented, for example: - Burgers equation with `a = (0, 0, 0.01)` with `len(a) = 3` - KdV equation with `a = (0, 0, 0, 0.01)` with `len(a) = 4` **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions `d`. - - `domain_extent`: The size of the domain `L`; in higher dimensions - the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. - - `num_points`: The number of points `N` used to discretize the - domain. This **includes** the left boundary point and - **excludes** the right boundary point. In higher dimensions; the - number of points in each dimension is the same. Hence, the total - number of degrees of freedom is `Nᵈ`. - - `dt`: The timestep size `Δt` between two consecutive states. - - `coefficients` (keyword-only): The list of coefficients `a_j` - corresponding to the derivatives. The length of this tuple - represents the highest occuring derivative. The default value - `(0.0, 0.0, 0.01)` corresponds to the Burgers equation (because - of the diffusion) - - `convection_scale` (keyword-only): The scale `b₁` of the - convection term. Default is `1.0`. - - `single_channel`: Whether to use the single channel mode in higher - dimensions. In this case the the convection is `b₁ (∇ ⋅ 1)(u²)`. - In this case, the state always has a single channel, no matter - the spatial dimension. Default: False. - - `order`: The order of the Exponential Time Differencing Runge - Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` - only solves the linear part of the equation. Use higher values - for higher accuracy and stability. The default choice of `2` is - a good compromise for single precision floats. - - `dealiasing_fraction`: The fraction of the wavenumbers to keep - before evaluating the nonlinearity. The default 2/3 corresponds - to Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. - Default: 2/3. - - `num_circle_points`: How many points to use in the complex contour - integral method to compute the coefficients of the exponential - time differencing Runge Kutta method. Default: 16. - - `circle_radius`: The radius of the contour used to compute the - coefficients of the exponential time differencing Runge Kutta - method. Default: 1.0. + + - `num_spatial_dims`: The number of spatial dimensions `D`. + - `domain_extent`: The size of the domain `L`; in higher dimensions + the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. Hence, the total number of degrees of + freedom is `Nᵈ`. + - `dt`: The timestep size `Δt` between two consecutive states. + - `coefficients` (keyword-only): The list of coefficients `a_j` + corresponding to the derivatives. The length of this tuple + represents the highest occuring derivative. The default value `(0.0, + 0.0, 0.01)` corresponds to the Burgers equation (because of the + diffusion) + - `convection_scale` (keyword-only): The scale `b₁` of the + convection term. Default is `1.0`. + - `single_channel`: Whether to use the single channel mode in higher + dimensions. In this case the the convection is `b₁ (∇ ⋅ 1)(u²)`. In + this case, the state always has a single channel, no matter the + spatial dimension. Default: False. + - `order`: The order of the Exponential Time Differencing Runge + Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only + solves the linear part of the equation. Use higher values for higher + accuracy and stability. The default choice of `2` is a good + compromise for single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep + before evaluating the nonlinearity. The default 2/3 corresponds to + Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default: + 2/3. + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. Default: 16. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta + method. Default: 1.0. """ self.coefficients = coefficients self.convection_scale = convection_scale @@ -166,12 +167,64 @@ def __init__( circle_radius: float = 1.0, ): """ - By default: Behaves like a Burgers with + Time stepper for the **normalized** d-dimensional (`d ∈ {1, 2, 3}`) + semi-linear PDEs consisting of a convection nonlinearity and an + arbitrary combination of (isotropic) linear derivatives. Uses a + normalized interface, i.e., the domain is scaled to `Ω = (0, 1)ᵈ` and + time step size is `Δt = 1.0`. + + See `exponax.stepper.generic.GeneralConvectionStepper` for more details + on the functional form of the PDE. + + In the default configuration, the number of channel grows with the + number of spatial dimensions. Setting the flag `single_channel=True` + activates a single-channel hack. + + Under the default settings, it behaves like the Burgers equation under + the following settings - ``` Burgers( + ```python + + exponax.stepper.Burgers( D=D, L=1, N=N, dt=0.1, diffusivity=0.01, ) ``` + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `D`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. Hence, the total number of degrees of freedom + is `Nᵈ`. + - `normalized_coefficients`: The list of coefficients + `α_j` corresponding to the derivatives. The length of this tuple + represents the highest occuring derivative. The default value `(0.0, + 0.0, 0.01)` corresponds to the Burgers equation (because of the + diffusion contribution). Note that these coefficients are normalized + on the unit domain and unit time step size. + - `normalized_convection_scale`: The scale `β` of the convection term. + Default is `1.0`. + - `single_channel`: Whether to use the single channel mode in higher + dimensions. In this case the the convection is `β (∇ ⋅ 1)(u²)`. In + this case, the state always has a single channel, no matter the + spatial dimension. Default: False. + - `order`: The order of the Exponential Time Differencing Runge + Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only + solves the linear part of the equation. Use higher values for higher + accuracy and stability. The default choice of `2` is a good + compromise for single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep + before evaluating the nonlinearity. The default 2/3 corresponds to + Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default: + 2/3. + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. Default: 16. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta + method. Default: 1.0. """ self.normalized_coefficients = normalized_coefficients self.normalized_convection_scale = normalized_convection_scale @@ -209,8 +262,77 @@ def __init__( circle_radius: float = 1.0, ): """ - By default: Behaves like a Burgers + Timestepper for the **difficulty-based** d-dimensional (`d ∈ {1, 2, 3}`) + semi-linear PDEs consisting of a convection nonlinearity and an + arbitrary combination of (isotropic) linear derivatives. Uses a + difficulty-based interface where the "intensity" of the dynamics reduces + with increasing resolution. This is intended such that emulator learning + problems on two resolutions are comparibly difficult. + + Different to `exponax.stepper.generic.NormalizedConvectionStepper`, the + dynamics are defined by difficulties. The difficulties are a different + combination of normalized dynamics, `num_spatial_dims`, and + `num_points`. + + γᵢ = αᵢ Nⁱ 2ⁱ⁻¹ d + + with `d` the number of spatial dimensions, `N` the number of points, and + `αᵢ` the normalized coefficient. + + The difficulty of the nonlinear convection scale is defined by + + δ₁ = β₁ * M * N * D + + with `M` the maximum absolute value of the input state (typically `1.0` + if one uses the `exponax.ic` random generators with the `max_one=True` + argument). + + This interface is more natural than the normalized interface because the + difficulties for all orders (given by `i`) are around 1.0. Additionally, + they relate to stability condition of explicit Finite Difference schemes + for the particular equations. For example, for advection (`i=1`), the + absolute of the difficulty is the Courant-Friedrichs-Lewy (CFL) number. + + Under the default settings, this timestepper represents the Burgers + equation. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions `D`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. Hence, the total number of degrees of freedom + is `Nᵈ`. + - `linear_difficulties`: The list of difficulties `γᵢ` corresponding to + the derivatives. The length of this tuple represents the highest + occuring derivative. The default value `(0.0, 0.0, 4.5)` corresponds + to the Burgers equation. Note that these coefficients are normalized + on the unit domain and unit time step size. + - `convection_difficulty`: The difficulty `δ` of the convection term. + Default is `5.0`. + - `single_channel`: Whether to use the single channel mode in higher + dimensions. In this case the the convection is `δ (∇ ⋅ 1)(u²)`. In + this case, the state always has a single channel, no matter the + spatial dimension. Default: False. + - `maximum_absolute`: The maximum absolute value of the state. This is + used to extract the normalized dynamics from the convection + difficulty. + - `order`: The order of the Exponential Time Differencing Runge + Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only + solves the linear part of the equation. Use higher values for higher + accuracy and stability. The default choice of `2` is a good + compromise for single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep + before evaluating the nonlinearity. The default 2/3 corresponds to + Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default: + 2/3. + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. Default: 16. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta + method. Default: 1.0. """ self.linear_difficulties = linear_difficulties self.convection_difficulty = convection_difficulty diff --git a/exponax/stepper/generic/_gradient_norm.py b/exponax/stepper/generic/_gradient_norm.py index fe03f7d..cd6e299 100644 --- a/exponax/stepper/generic/_gradient_norm.py +++ b/exponax/stepper/generic/_gradient_norm.py @@ -51,42 +51,43 @@ def __init__( ``` The default configuration coincides with a Kuramoto-Sivashinsky equation - in combustion format. Note that this requires negative values (because - the KS usually defines their linear operators on the left hand side of - the equation) + in combustion format (see `exponax.stepper.KuramotoSivashinsky`). Note + that this requires negative values (because the KS usually defines their + linear operators on the left hand side of the equation) **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions `d`. - - `domain_extent`: The size of the domain `L`; in higher dimensions - the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. - - `num_points`: The number of points `N` used to discretize the - domain. This **includes** the left boundary point and - **excludes** the right boundary point. In higher dimensions; the - number of points in each dimension is the same. Hence, the total - number of degrees of freedom is `Nᵈ`. - - `dt`: The timestep size `Δt` between two consecutive states. - - `coefficients` (keyword-only): The list of coefficients `a_j` - corresponding to the derivatives. The length of this tuple - represents the highest occuring derivative. The default value - `(0.0, 0.0, -1.0, 0.0, -1.0)` corresponds to the Kuramoto- - Sivashinsky equation in combustion format. - - `gradient_norm_scale` (keyword-only): The scale of the gradient - norm term `b₂`. Default: 1.0. - - `order`: The order of the Exponential Time Differencing Runge - Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` - only solves the linear part of the equation. Use higher values - for higher accuracy and stability. The default choice of `2` is - a good compromise for single precision floats. - - `dealiasing_fraction`: The fraction of the wavenumbers to keep - before evaluating the nonlinearity. The default 2/3 corresponds - to Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. - Default: 2/3. - - `num_circle_points`: How many points to use in the complex contour - integral method to compute the coefficients of the exponential - time differencing Runge Kutta method. Default: 16. - - `circle_radius`: The radius of the contour used to compute the - coefficients of the exponential time differencing Runge Kutta - method. Default: 1.0. + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `domain_extent`: The size of the domain `L`; in higher dimensions + the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. Hence, the total number of degrees of + freedom is `Nᵈ`. + - `dt`: The timestep size `Δt` between two consecutive states. + - `coefficients` (keyword-only): The list of coefficients `a_j` + corresponding to the derivatives. The length of this tuple + represents the highest occuring derivative. The default value `(0.0, + 0.0, -1.0, 0.0, -1.0)` corresponds to the Kuramoto- Sivashinsky + equation in combustion format. + - `gradient_norm_scale` (keyword-only): The scale of the gradient + norm term `b₂`. Default: 1.0. + - `order`: The order of the Exponential Time Differencing Runge + Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only + solves the linear part of the equation. Use higher values for higher + accuracy and stability. The default choice of `2` is a good + compromise for single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep + before evaluating the nonlinearity. The default 2/3 corresponds to + Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default: + 2/3. + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. Default: 16. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta + method. Default: 1.0. """ self.coefficients = coefficients self.gradient_norm_scale = gradient_norm_scale @@ -153,11 +154,68 @@ def __init__( circle_radius: float = 1.0, ): """ - the number of channels do **not** grow with the number of spatial - dimensions. They are always 1. + Timestepper for the **normalized** d-dimensional (`d ∈ {1, 2, 3}`) + semi-linear PDEs consisting of a gradient norm nonlinearity and an + arbitrary combination of (isotropic) linear operators. Uses a normalized + interface, i.e., the domain is scaled to `Ω = (0, 1)ᵈ` and time step + size is `Δt = 1.0`. + + See `exponax.stepper.generic.GeneralGradientNormStepper` for more + details on the functional form of the PDE. + + The number of channels do **not** grow with the number of spatial + dimensions. They are always one. + + Under the default settings, it behaves like the Kuramoto-Sivashinsky + equation in combustion format under the following settings. By default: the KS equation on L=60.0 + ```python + + exponax.stepper.KuramotoSivashinsky( + num_spatial_dims=D, domain_extent=60.0, num_points=N, dt=0.1, + gradient_norm_scale=1.0, second_order_diffusivity=1.0, + fourth_order_diffusivity=1.0, + ) + ``` + + Note that the coefficient list requires a negative sign because the + linear derivatives are moved to the right-hand side in this generic + interface. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. Hence, the total number of degrees of + freedom is `Nᵈ`. + - `normalized_coefficients`: The list of coefficients `a_j` + corresponding to the derivatives. The length of this tuple + represents the highest occuring derivative. The default value `(0.0, + 0.0, -1.0 * 0.1 / (60.0**2), 0.0, -1.0 * 0.1 / (60.0**4))` + corresponds to the Kuramoto-Sivashinsky equation in combustion + format on a domain of size `L=60.0` with a time step size of + `Δt=0.1`. + - `normalized_gradient_norm_scale`: The scale of the gradient + norm term `b₂`. Default: `1.0 * 0.1 / (60.0**2)`. + - `order`: The order of the Exponential Time Differencing Runge + Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only + solves the linear part of the equation. Use higher values for higher + accuracy and stability. The default choice of `2` is a good + compromise for single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep + before evaluating the nonlinearity. The default 2/3 corresponds to + Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default: + 2/3. + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. Default: 16. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta + method. Default: 1.0. """ self.normalized_coefficients = normalized_coefficients self.normalized_gradient_norm_scale = normalized_gradient_norm_scale @@ -193,7 +251,73 @@ def __init__( circle_radius: float = 1.0, ): """ - By default: KS equation + Timestepper for the **difficulty-based** d-dimensional (`d ∈ {1, 2, 3}`) + semi-linear PDEs consisting of a gradient norm nonlinearity and an + arbitrary combination of (isotropic) linear operators. Uses a + difficulty-based interface where the "intensity" of the dynamics reduces + with increasing resolution. This is intended such that emulator learning + problems on two resolutions are comparibly difficult. + + Different to `exponax.stepper.generic.NormalizedGradientNormStepper`, + the dynamics are defined by difficulties. The difficulties are a + different combination of normalized dynamics, `num_spatial_dims`, and + `num_points`. + + γᵢ = αᵢ Nⁱ 2ⁱ⁻¹ d + + with `d` the number of spatial dimensions, `N` the number of points, and + `αᵢ` the normalized coefficient. + + The difficulty of the nonlinear convection scale is defined by + + δ₂ = β₂ * M * N² * D + + with `M` the maximum absolute value of the input state (typically `1.0` + if one uses the `exponax.ic` random generators with the `max_one=True` + argument). + + This interface is more natural than the normalized interface because the + difficulties for all orders (given by `i`) are around 1.0. Additionally, + they relate to stability condition of explicit Finite Difference schemes + for the particular equations. For example, for advection (`i=1`), the + absolute of the difficulty is the Courant-Friedrichs-Lewy (CFL) number. + + Under the default settings, this timestepper represents the + Kuramoto-Sivashinsky equation (in combustion format). + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. Hence, the total number of degrees of + freedom is `Nᵈ`. + - `linear_difficulties`: The list of difficulties `γᵢ` corresponding to + the derivatives. The length of this tuple represents the highest + occuring derivative. The default value `(0.0, 0.0, -0.128, 0.0, + -0.32768)` corresponds to the Kuramoto-Sivashinsky equation in + combustion format (because it contains both a negative diffusion and + a negative hyperdiffusion term). + - `gradient_norm_difficulty`: The difficulty of the gradient norm term + `δ₂`. + - `maximum_absolute`: The maximum absolute value of the input state. This + is used to scale the gradient norm term. + - `order`: The order of the Exponential Time Differencing Runge + Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only + solves the linear part of the equation. Use higher values for higher + accuracy and stability. The default choice of `2` is a good + compromise for single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep + before evaluating the nonlinearity. The default 2/3 corresponds to + Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default: + 2/3. + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. Default: 16. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta + method. Default: 1.0. """ self.linear_difficulties = linear_difficulties self.gradient_norm_difficulty = gradient_norm_difficulty diff --git a/exponax/stepper/generic/_nonlinear.py b/exponax/stepper/generic/_nonlinear.py index 67c43ec..60315be 100644 --- a/exponax/stepper/generic/_nonlinear.py +++ b/exponax/stepper/generic/_nonlinear.py @@ -29,7 +29,83 @@ def __init__( circle_radius: float = 1.0, ): """ - By default Burgers equation + Timestepper for d-dimensional (`d ∈ {1, 2, 3}`) semi-linear PDEs + consisting of a quadratic, a single-channel convection, and a gradient + norm nonlinearity together with an arbitrary combination of (isotropic) + linear derivatives. + + In 1d, the PDE is of the form + + ``` + uₜ = b₀ u² + b₁ 1/2 (u²)ₓ + b₂ 1/2 (uₓ)² + sum_j a_j uₓʲ + ``` + + where `b₀`, `b₁`, `b₂` are the coefficients of the quadratic, + convection, and gradient norm nonlinearity, respectively, and `a_j` are + the coefficients of the linear derivatives. Effectively, this + timestepper is a combination of the + `exponax.stepper.generic.GeneralPolynomialStepper` (with only the + coefficient to the quadratic polynomial being set with `b₀`), the + `exponax.stepper.generic.GeneralConvectionStepper` (with the + single-channel hack activated via `single_channel=True` and the + convection scale set with `- b₁`), and the + `exponax.stepper.generic.GeneralGradientNormStepper` (with the gradient + norm scale set with `- b₂`). + + !!! warning + In contrast to the + `exponax.stepper.generic.GeneralConvectionStepper` and the + `exponax.stepper.generic.GeneralGradientNormStepper`, the nonlinear + terms are no considered to be on right-hand side of the PDE. Hence, + in order to get the same dynamics as in the other steppers, the + coefficients must be negated. (This is not relevant for the + coefficient of the quadratic polynomial because in the + `exponax.stepper.generic.GeneralPolynomialStepper` the polynomial + nonlinearity is already on the right-hand side.) + + The higher-dimensional generalization is + + ``` + uₜ = b₀ u² + b₁ 1/2 (1⃗ ⋅ ∇)(u²) + b₂ 1/2 ‖ ∇u ‖₂² + sum_j a_j uₓˢ + ``` + + Under the default configuration, this correspons to a Burgers equation + in single-channel mode. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `domain_extent`: The size of the domain `L`; in higher dimensions the + domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. Hence, the total number of degrees of freedom + is `Nᵈ`. + - `dt`: The timestep size `Δt` between two consecutive states. + - `coefficients_linear`: The list of coefficients `a_j` corresponding to + the derivatives. The length of this tuple represents the highest + occuring derivative. The default value `(0.0, 0.0, 0.01)` together + with the default `coefficients_nonlinear` corresponds to the Burgers + equation. + - `coefficients_nonlinear`: The list of coefficients `b₀`, `b₁`, `b₂` + (in this order). The default value `(0.0, -1.0, 0.0)` corresponds to + a (single-channel) convection nonlinearity scaled with `1.0`. Note + that all nonlinear contributions are considered to be on the + right-hand side of the PDE. Hence, in order to get the "correct + convection" dynamics, the coefficients must be negated. + - `order`: The order of the ETDRK method to use. Must be one of {0, 1, 2, + 3, 4}. The option `0` only solves the linear part of the equation. + Use higher values for higher accuracy and stability. The default + choice of `2` is a good compromise for single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep before + evaluating the nonlinearity. The default value `2/3` corresponds to + Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta method. """ if len(coefficients_nonlinear) != 3: raise ValueError( @@ -96,6 +172,84 @@ def __init__( ): """ By default Burgers. + + Timesteppr for **normalized** d-dimensional (`d ∈ {1, 2, 3}`) + semi-linear PDEs consisting of a quadratic, a single-channel convection, + and a gradient norm nonlinearity together with an arbitrary combination + of (isotropic) linear derivatives. Uses a normalized interface, i.e., + the domain is scaled to `Ω = (0, 1)ᵈ` and time step size is `Δt = 1.0`. + + See `exponax.stepper.generic.GeneralNonlinearStepper` for more details + on the functional form of the PDE. + + Behaves like a single-channel Burgers equation by default under the + following settings + + ```python + + exponax.stepper.Burgers( + num_spatial_dims=num_spatial_dims, domain_extent=1.0, + num_points=num_points, dt=1.0, convection_scale=1.0, + diffusivity=0.1, single_channel=True, + ) + ``` + + Effectively, this timestepper is a combination of the + `exponax.stepper.generic.NormalizedPolynomialStepper` (with only the + coefficient to the quadratic polynomial being set with `b₀`), the + `exponax.stepper.generic.NormalizedConvectionStepper` (with the + single-channel hack activated via `single_channel=True` and the + convection scale set with `- b₁`), and the + `exponax.stepper.generic.NormalizedGradientNormStepper` (with the + gradient norm scale set with `- b₂`). + + !!! warning + In contrast to the + `exponax.stepper.generic.NormalizedConvectionStepper` and the + `exponax.stepper.generic.NormalizedGradientNormStepper`, the + nonlinear terms are no considered to be on right-hand side of the + PDE. Hence, in order to get the same dynamics as in the other + steppers, the coefficients must be negated. (This is not relevant + for the coefficient of the quadratic polynomial because in the + `exponax.stepper.generic.NormalizedPolynomialStepper` the polynomial + nonlinearity is already on the right-hand side.) + + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. Hence, the total number of degrees of freedom + is `Nᵈ`. + - `normalized_coefficients_linear`: The list of coefficients `αⱼ` + corresponding to the linear derivatives. The length of this tuple + represents the highest occuring derivative. The default value `(0.0, + 0.0, 0.1 * 0.1)` together with the default + `normalized_coefficients_nonlinear` corresponds to the Burgers + equation (in single-channel mode). + - `normalized_coefficients_nonlinear`: The list of coefficients `β₀`, + `β₁`, and `β₂` (in this order) corresponding to the quadratic, + (single-channel) convection, and gradient norm nonlinearity, + respectively. The default value `(0.0, -1.0 * 0.1, 0.0)` corresponds + to a (single-channel) convection nonlinearity scaled with `0.1`. + Note that all nonlinear contributions are considered to be on the + right-hand side of the PDE. Hence, in order to get the "correct + convection" dynamics, the coefficients must be negated. (Also + relevant for the gradient norm nonlinearity) + - `order`: The order of the ETDRK method to use. Must be one of {0, 1, 2, + 3, 4}. The option `0` only solves the linear part of the equation. + Use higher values for higher accuracy and stability. The default + choice of `2` is a good compromise for single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep before + evaluating the nonlinearity. The default value `2/3` corresponds to + Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta method. """ self.normalized_coefficients_linear = normalized_coefficients_linear @@ -141,7 +295,78 @@ def __init__( circle_radius: float = 1.0, ): """ - By default Burgers. + Timestepper for **difficulty-based** d-dimensional (`d ∈ {1, 2, 3}`) + semi-linear PDEs consisting of a quadratic, a single-channel convection, + and a gradient norm nonlinearity together with an arbitrary combination + of (isotropic) linear derivatives. Uses a difficulty-based interface + where the "intensity" of the dynamics reduces with increasing + resolution. This is intended such that emulator learning problems on two + resolutions are comparibly difficult. + + Different to `exponax.stepper.generic.NormalizedNonlinearStepper`, the + dynamics are defined by difficulties. The difficulties are a different + combination of normalized dynamics, `num_spatial_dims`, and + `num_points`. + + γᵢ = αᵢ Nⁱ 2ⁱ⁻¹ d + + with `d` the number of spatial dimensions, `N` the number of points, and + `αᵢ` the normalized coefficient. + + The difficulties of the nonlinear terms are + + δ₀ = β₀ + + δ₁ = β₁ * M * N * D + + δ₂ = β₂ * M * N² * D + + with `βᵢ` the normalized coefficient and `M` the maximum absolute value + of the input state (typically `1.0` if one uses the `exponax.ic` random + generators with the `max_one=True` argument). + + This interface is more natural than the normalized interface because the + difficulties for all orders (given by `i`) are around 1.0. Additionally, + they relate to stability condition of explicit Finite Difference schemes + for the particular equations. For example, for advection (`i=1`), the + absolute of the difficulty is the Courant-Friedrichs-Lewy (CFL) number. + + Under the default settings, this timestep corresponds to a Burgers + equation in single-channel mode. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. Hence, the total number of degrees of freedom + is `Nᵈ`. + - `linear_difficulties`: The list of difficulties `γᵢ` corresponding to + the linear derivatives. The length of this tuple represents the + highest occuring derivative. The default value `(0.0, 0.0, 0.1 * 0.1 + / 1.0 * 48**2 * 2)` together with the default `nonlinear_difficulties` + corresponds to the Burgers equation. + - `nonlinear_difficulties`: The list of difficulties `δ₀`, `δ₁`, and `δ₂` + (in this order) corresponding to the quadratic, (single-channel) + convection, and gradient norm nonlinearity, respectively. The default + value `(0.0, -1.0 * 0.1 / 1.0 * 48, 0.0)` corresponds to a + (single-channel) convection nonlinearity. Note that all nonlinear + contributions are considered to be on the right-hand side of the PDE. + - `maximum_absolute`: The maximum absolute value of the input state. This + is used to scale the nonlinear difficulties. + - `order`: The order of the ETDRK method to use. Must be one of {0, 1, 2, + 3, 4}. The option `0` only solves the linear part of the equation. + Use higher values for higher accuracy and stability. The default + choice of `2` is a good compromise for single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep before + evaluating the nonlinearity. The default value `2/3` corresponds to + Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta method. """ self.linear_difficulties = linear_difficulties self.nonlinear_difficulties = nonlinear_difficulties diff --git a/exponax/stepper/generic/_polynomial.py b/exponax/stepper/generic/_polynomial.py index bb95451..372611b 100644 --- a/exponax/stepper/generic/_polynomial.py +++ b/exponax/stepper/generic/_polynomial.py @@ -26,11 +26,84 @@ def __init__( circle_radius: float = 1.0, ): """ - By default: Fisher-KPP with a small diffusion and 10.0 reactivity + Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) semi-linear PDEs + consisting of an arbitrary combination of polynomial nonlinearities and + (isotropic) linear derivatives. This can be used to represent a wide + array of reaction-diffusion equations. - Note that the first two entries in the polynomial_scales list are often zero. + In 1d, the PDE is of the form - The effect of polynomial_scale[1] is similar to the effect of coefficients[0] + ``` + uₜ = ∑ₖ pₖ uᵏ + ∑ⱼ aⱼ uₓʲ + ``` + + where `pₖ` are the polynomial coefficients and `aⱼ` are the linear + coefficients. `uᵏ` denotes `u` pointwise raised to the power of `k` + (hence the polynomial contribution) and `uₓʲ` denotes the `j`-th + derivative of `u`. + + The higher-dimensional generalization reads + + ``` + uₜ = ∑ₖ pₖ uᵏ + ∑ⱼ a_j (1⋅∇ʲ)u + + ``` + + where `∇ʲ` is the `j`-th derivative operator. + + The default configuration corresponds to the Fisher-KPP equation with + the following settings + + ```python + + exponax.stepper.reaction.FisherKPP( + num_spatial_dims=num_spatial_dims, domain_extent=domain_extent, + num_points=num_points, dt=dt, diffusivity=0.01, reactivity=-10.0, + #TODO: Check this + ) + ``` + + Note that the effect of polynomial_scale[1] is similar to the effect of + coefficients[0] with the difference that in ETDRK integration the latter + is treated anlytically and should be preferred. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `domain_extent`: The size of the domain `L`; in higher dimensions + the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. Hence, the total number of degrees of + freedom is `Nᵈ`. + - `dt`: The timestep size `Δt` between two consecutive states. + - `coefficients`: The list of coefficients `a_j` corresponding to the + derivatives. The length of this tuple represents the highest + occuring derivative. The default value `(10.0, 0.0, 0.01)` in + combination with the default `polynomial_scales` corresponds to the + Fisher-KPP equation. + - `polynomial_scales`: The list of scales `pₖ` corresponding to the + polynomial contributions. The length of this tuple represents the + highest occuring polynomial. The default value `(0.0, 0.0, 10.0)` in + combination with the default `coefficients` corresponds to the + Fisher-KPP equation. + - `order`: The order of the Exponential Time Differencing Runge + Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only + solves the linear part of the equation. Use higher values for higher + accuracy and stability. The default choice of `2` is a good + compromise for single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep + before evaluating the nonlinearity. The default 2/3 corresponds to + Orszag's 2/3 rule which is sufficient if the highest occuring + polynomial is quadratic (i.e., there are at maximum three entries in + the `polynomial_scales` tuple). + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta + method. """ self.coefficients = coefficients self.polynomial_scales = polynomial_scales @@ -98,7 +171,48 @@ def __init__( circle_radius: float = 1.0, ): """ - By default: Fisher-KPP + Timestepper for the **normalized** d-dimensional (`d ∈ {1, 2, 3}`) + semi-linear PDEs consisting of an arbitrary combination of polynomial + nonlinearities and (isotropic) linear derivatives. Uses a normalized + interface, i.e., the domain is scaled to `Ω = (0, 1)ᵈ` and time step + size is `Δt = 1.0`. + + See `exponax.stepper.generic.GeneralPolynomialStepper` for more details + on the functional form of the PDE. + + The default settings correspond to the Fisher-KPP equation. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. Hence, the total number of degrees of freedom + is `Nᵈ`. + - `normalized_coefficients`: The list of coefficients `α_j` corresponding + to the derivatives. The length of this tuple represents the highest + occuring derivative. The default value corresponds to the Fisher-KPP + equation. + - `normalized_polynomial_scales`: The list of scales `βₖ` corresponding + to the polynomial contributions. The length of this tuple represents + the highest occuring polynomial. The default value corresponds to the + Fisher-KPP equation. + - `order`: The order of the Exponential Time Differencing Runge Kutta + method. Must be one of {0, 1, 2, 3, 4}. The option `0` only solves + the linear part of the equation. Use higher values for higher accuracy + and stability. The default choice of `2` is a good compromise for + single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep before + evaluating the nonlinearity. The default 2/3 corresponds to Orszag's + 2/3 rule which is sufficient if the highest occuring polynomial is + quadratic (i.e., there are at maximum three entries in the + `normalized_polynomial_scales` tuple). + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta method. """ self.normalized_coefficients = normalized_coefficients self.normalized_polynomial_scales = normalized_polynomial_scales @@ -142,7 +256,63 @@ def __init__( circle_radius: float = 1.0, ): """ - By default: Fisher-KPP + Timestepper for **difficulty-based** d-dimensional (`d ∈ {1, 2, 3}`) + semi-linear PDEs consisting of an arbitrary combination of polynomial + nonlinearities and (isotropic) linear derivatives. Uses a + difficulty-based interface where the "intensity" of the dynamics reduces + with increasing resolution. This is intended such that emulator learning + problems on two resolutions are comparibly difficult. + + Different to `exponax.stepper.generic.NormalizedPolynomialStepper`, the + dynamics are defined by difficulties. The difficulties are a different + combination of normalized dynamics, `num_spatial_dims`, and + `num_points`. + + γᵢ = αᵢ Nⁱ 2ⁱ⁻¹ d + + with `d` the number of spatial dimensions, `N` the number of points, and + `αᵢ` the normalized coefficient. + + Since the polynomial nonlinearity does not contain any derivatives, we + have that + + ``` + normalized_polynomial_scales = polynomial_difficulties + ``` + + The default settings correspond to the Fisher-KPP equation. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `num_points`: The number of points `N` used to discretize the domain. + This **includes** the left boundary point and **excludes** the right + boundary point. In higher dimensions; the number of points in each + dimension is the same. Hence, the total number of degrees of freedom + is `Nᵈ`. + - `linear_difficulties`: The list of difficulties `γ_j` corresponding to + the derivatives. The length of this tuple represents the highest + occuring derivative. The default value corresponds to the Fisher-KPP + equation. + - `polynomial_difficulties`: The list of difficulties `δₖ` corresponding + to the polynomial contributions. The length of this tuple represents + the highest occuring polynomial. The default value corresponds to the + Fisher-KPP equation. + - `order`: The order of the Exponential Time Differencing Runge Kutta + method. Must be one of {0, 1, 2, 3, 4}. The option `0` only solves + the linear part of the equation. Use higher values for higher accuracy + and stability. The default choice of `2` is a good compromise for + single precision floats. + - `dealiasing_fraction`: The fraction of the wavenumbers to keep before + evaluating the nonlinearity. The default 2/3 corresponds to Orszag's + 2/3 rule which is sufficient if the highest occuring polynomial is + quadratic (i.e., there are at maximum three entries in the + `polynomial_difficulties` tuple). + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta method. """ self.linear_difficulties = linear_difficulties self.polynomial_difficulties = polynomial_difficulties diff --git a/exponax/stepper/generic/_utils.py b/exponax/stepper/generic/_utils.py index 6049e34..4263913 100644 --- a/exponax/stepper/generic/_utils.py +++ b/exponax/stepper/generic/_utils.py @@ -11,11 +11,24 @@ def normalize_coefficients( Normalize the coefficients to a linear time stepper to be used with the normalized linear stepper. + αᵢ = aᵢ Δt / Lⁱ + + !!! warning + A consequence of this normalization is that the normalized coefficients + for high order derivatives will be very small. + **Arguments:** + - `coefficients`: coefficients for the linear operator, `coefficients[i]` is the coefficient for the `i`-th derivative - `domain_extent`: extent of the domain - `dt`: time step + + **Returns:** + + - `normalized_coefficients`: normalized coefficients for the linear + operator, `normalized_coefficients[i]` is the coefficient for the `i`-th + derivative """ normalized_coefficients = tuple( c * dt / (domain_extent**i) for i, c in enumerate(coefficients) @@ -31,14 +44,22 @@ def denormalize_coefficients( ) -> tuple[float, ...]: """ Denormalize the coefficients as they were used in the normalized linear to - then be used again in a regular linear stepper. + then be used again in a genric linear stepper with a physical interface. + + aᵢ = αᵢ Lⁱ / Δt **Arguments:** + - `normalized_coefficients`: coefficients for the linear operator, `normalized_coefficients[i]` is the coefficient for the `i`-th derivative - `domain_extent`: extent of the domain - `dt`: time step + + **Returns:** + + - `coefficients`: coefficients for the linear operator, `coefficients[i]` is + the coefficient for the `i`-th derivative """ coefficients = tuple( c_n / dt * domain_extent**i for i, c_n in enumerate(normalized_coefficients) @@ -52,6 +73,23 @@ def normalize_convection_scale( domain_extent: float, dt: float, ) -> float: + """ + Normalize the scale (=coefficient) in front of the convection term to be + used with the normalized generic steppers. + + β₁ = b₁ Δt / L + + **Arguments:** + + - `convection_scale`: scale in front of the convection term, i.e., the `b_1` + in `𝒩(u) = - b₁ 1/2 (u²)ₓ` + - `domain_extent`: extent of the domain + - `dt`: time step + + **Returns:** + + - `normalized_convection_scale`: normalized scale in front of the convection + """ normalized_convection_scale = convection_scale * dt / domain_extent return normalized_convection_scale @@ -62,6 +100,24 @@ def denormalize_convection_scale( domain_extent: float, dt: float, ) -> float: + """ + Denormalize the scale in front of the convection term as it was used in the + normalized generic steppers to then be used again in a generic stepper with + a physical interface. + + b₁ = β₁ L / Δt + + **Arguments:** + + - `normalized_convection_scale`: normalized scale in front of the convection + - `domain_extent`: extent of the domain + - `dt`: time step + + **Returns:** + + - `convection_scale`: scale in front of the convection term, i.e., the `b_1` + in `𝒩(u) = - b₁ 1/2 (u²)ₓ` + """ convection_scale = normalized_convection_scale / dt * domain_extent return convection_scale @@ -72,6 +128,24 @@ def normalize_gradient_norm_scale( domain_extent: float, dt: float, ): + """ + Normalize the scale in front of the gradient norm term to be used with the + normalized generic steppers. + + β₂ = b₂ Δt / L² + + **Arguments:** + + - `gradient_norm_scale`: scale in front of the gradient norm term, i.e., the + `b_2` in `𝒩(u) = - b₂ 1/2 ‖∇u‖₂²` + - `domain_extent`: extent of the domain + - `dt`: time step + + **Returns:** + + - `normalized_gradient_norm_scale`: normalized scale in front of the + gradient norm term + """ normalized_gradient_norm_scale = ( gradient_norm_scale * dt / jnp.square(domain_extent) ) @@ -84,6 +158,25 @@ def denormalize_gradient_norm_scale( domain_extent: float, dt: float, ): + """ + Denormalize the scale in front of the gradient norm term as it was used in + the normalized generic steppers to then be used again in a generic stepper + with a physical interface. + + b₂ = β₂ L² / Δt + + **Arguments:** + + - `normalized_gradient_norm_scale`: normalized scale in front of the gradient + norm term + - `domain_extent`: extent of the domain + - `dt`: time step + + **Returns:** + + - `gradient_norm_scale`: scale in front of the gradient norm term, i.e., the + `b_2` in `𝒩(u) = - b₂ 1/2 ‖∇u‖₂²` + """ gradient_norm_scale = ( normalized_gradient_norm_scale / dt * jnp.square(domain_extent) ) @@ -101,11 +194,18 @@ def normalize_polynomial_scales( stepper. **Arguments:** - - `polynomial_scales`: scales for the polynomial operator, - `polynomial_scales[i]` is the scale for the `i`-th derivative - - `domain_extent`: extent of the domain (not needed, kept for - compatibility with other normalization APIs) - - `dt`: time step + + - `polynomial_scales`: scales for the polynomial operator, + `polynomial_scales[i]` is the scale for the `i`-th degree polynomial + - `domain_extent`: extent of the domain (not needed, kept for + compatibility with other normalization APIs) + - `dt`: time step + + **Returns:** + + - `normalized_polynomial_scales`: normalized scales for the polynomial + operator, `normalized_polynomial_scales[i]` is the scale for the `i`-th + degree polynomial """ normalized_polynomial_scales = tuple(c * dt for c in polynomial_scales) return normalized_polynomial_scales @@ -122,12 +222,17 @@ def denormalize_polynomial_scales( polynomial to then be used again in a regular polynomial stepper. **Arguments:** - - `normalized_polynomial_scales`: scales for the polynomial operator, - `normalized_polynomial_scales[i]` is the scale for the `i`-th - derivative - - `domain_extent`: extent of the domain (not needed, kept for - compatibility with other normalization APIs) - - `dt`: time step + + - `normalized_polynomial_scales`: scales for the polynomial operator, + `normalized_polynomial_scales[i]` is the scale for the `i`-th degree + polynomial + - `domain_extent`: extent of the domain (not needed, kept for + compatibility with other normalization APIs) + - `dt`: time step + + **Returns:** + + - `polynomial_scales`: scales for the polynomial operator, """ polynomial_scales = tuple(c_n / dt for c_n in normalized_polynomial_scales) return polynomial_scales @@ -139,6 +244,31 @@ def reduce_normalized_coefficients_to_difficulty( num_spatial_dims: int, num_points: int, ): + """ + Reduce the normalized coefficients for a linear operator to a difficulty + based interface. This interface is designed to "reduce the intensity of the + dynamics" at higher resolutions to make emulator learning across resolutions + comparible. Thereby, it resembles the stability numbers of the most compact + finite difference scheme of the respective PDE. + + γ₀ = α₀ + + γⱼ = αⱼ Nʲ 2ʲ⁻¹ D + + **Arguments:** + + - `normalized_coefficients`: normalized coefficients for the linear + operator, `normalized_coefficients[i]` is the coefficient for the `i`-th + derivative + - `num_spatial_dims`: number of spatial dimensions `d` + - `num_points`: number of points `N` used to discretize the domain per + dimension + + **Returns:** + + - `difficulty_coefficients`: difficulty coefficients for the linear operator, + `difficulty_coefficients[i]` is the coefficient for the `i`-th derivative + """ difficulty_coefficients = list( alpha * num_points**j * 2 ** (j - 1) * num_spatial_dims for j, alpha in enumerate(normalized_coefficients) @@ -155,6 +285,27 @@ def extract_normalized_coefficients_from_difficulty( num_spatial_dims: int, num_points: int, ): + """ + Extract the normalized coefficients for a linear operator from a difficulty + based interface. + + α₀ = γ₀ + + αⱼ = γⱼ / (Nʲ 2ʲ⁻¹ D) + + **Arguments:** + + - `difficulty_coefficients`: difficulty coefficients for the linear operator, + `difficulty_coefficients[i]` is the coefficient for the `i`-th derivative + - `num_spatial_dims`: number of spatial dimensions `d` + - `num_points`: number of points `N` used to discretize the domain per + dimension + + **Returns:** + + - `normalized_coefficients`: normalized coefficients for the linear operator, + `normalized_coefficients[i]` is the coefficient for the `i`-th derivative + """ normalized_coefficients = list( gamma / (num_points**j * 2 ** (j - 1) * num_spatial_dims) for j, gamma in enumerate(difficulty_coefficients) @@ -172,6 +323,25 @@ def reduce_normalized_convection_scale_to_difficulty( num_points: int, maximum_absolute: float, ): + """ + Reduce the normalized convection scale to a difficulty based interface. + + δ₁ = β₁ * M * N * D + + **Arguments:** + + - `normalized_convection_scale`: normalized convection scale, see also + `exponax.stepper.generic.normalize_convection_scale` + - `num_spatial_dims`: number of spatial dimensions `d` + - `num_points`: number of points `N` used to discretize the domain per + dimension + - `maximum_absolute`: maximum absolute value of the input state the + resulting stepper is applied to + + **Returns:** + + - `difficulty_convection_scale`: difficulty convection scale + """ difficulty_convection_scale = ( normalized_convection_scale * maximum_absolute * num_points * num_spatial_dims ) @@ -185,6 +355,25 @@ def extract_normalized_convection_scale_from_difficulty( num_points: int, maximum_absolute: float, ): + """ + Extract the normalized convection scale from a difficulty based interface. + + β₁ = δ₁ / (M * N * D) + + **Arguments:** + + - `difficulty_convection_scale`: difficulty convection scale + - `num_spatial_dims`: number of spatial dimensions `d` + - `num_points`: number of points `N` used to discretize the domain per + dimension + - `maximum_absolute`: maximum absolute value of the input state the + resulting stepper is applied to + + **Returns:** + + - `normalized_convection_scale`: normalized convection scale, see also + `exponax.stepper.generic.normalize_convection_scale` + """ normalized_convection_scale = difficulty_convection_scale / ( maximum_absolute * num_points * num_spatial_dims ) @@ -198,6 +387,25 @@ def reduce_normalized_gradient_norm_scale_to_difficulty( num_points: int, maximum_absolute: float, ): + """ + Reduce the normalized gradient norm scale to a difficulty based interface. + + δ₂ = β₂ * M * N² * D + + **Arguments:** + + - `normalized_gradient_norm_scale`: normalized gradient norm scale, see also + `exponax.stepper.generic.normalize_gradient_norm_scale` + - `num_spatial_dims`: number of spatial dimensions `d` + - `num_points`: number of points `N` used to discretize the domain per + dimension + - `maximum_absolute`: maximum absolute value of the input state the + resulting stepper is applied to + + **Returns:** + + - `difficulty_gradient_norm_scale`: difficulty gradient norm scale + """ difficulty_gradient_norm_scale = ( normalized_gradient_norm_scale * maximum_absolute @@ -214,6 +422,25 @@ def extract_normalized_gradient_norm_scale_from_difficulty( num_points: int, maximum_absolute: float, ): + """ + Extract the normalized gradient norm scale from a difficulty based interface. + + β₂ = δ₂ / (M * N² * D) + + **Arguments:** + + - `difficulty_gradient_norm_scale`: difficulty gradient norm scale + - `num_spatial_dims`: number of spatial dimensions `d` + - `num_points`: number of points `N` used to discretize the domain per + dimension + - `maximum_absolute`: maximum absolute value of the input state the + resulting stepper is applied to + + **Returns:** + + - `normalized_gradient_norm_scale`: normalized gradient norm scale, see also + `exponax.stepper.generic.normalize_gradient_norm_scale` + """ normalized_gradient_norm_scale = difficulty_gradient_norm_scale / ( maximum_absolute * jnp.square(num_points) * num_spatial_dims ) @@ -227,6 +454,34 @@ def reduce_normalized_nonlinear_scales_to_difficulty( num_points: int, maximum_absolute: float, ): + """ + Reduce the normalized nonlinear scales associated with a quadratic, a + (single-channel) convection term, and a gradient norm term to a difficulty + based interface. + + δ₀ = β₀ + + δ₁ = β₁ * M * N * D + + δ₂ = β₂ * M * N² * D + + **Arguments:** + + - `normalized_nonlinear_scales`: normalized nonlinear scales associated with + a quadratic, a (single-channel) convection term, and a gradient norm + term (in this order) + - `num_spatial_dims`: number of spatial dimensions `d` + - `num_points`: number of points `N` used to discretize the domain per + dimension + - `maximum_absolute`: maximum absolute value of the input state the + resulting stepper is applied to + + **Returns:** + + - `nonlinear_difficulties`: difficulty nonlinear scales associated with a + quadratic, a (single-channel) convection term, and a gradient norm term + (in this order) + """ nonlinear_difficulties = ( normalized_nonlinear_scales[0], # Polynomial: normalized == difficulty reduce_normalized_convection_scale_to_difficulty( @@ -252,6 +507,34 @@ def extract_normalized_nonlinear_scales_from_difficulty( num_points: int, maximum_absolute: float, ): + """ + Extract the normalized nonlinear scales associated with a quadratic, a + (single-channel) convection term, and a gradient norm term from a difficulty + based interface. + + β₀ = δ₀ + + β₁ = δ₁ / (M * N * D) + + β₂ = δ₂ / (M * N² * D) + + **Arguments:** + + - `nonlinear_difficulties`: difficulty nonlinear scales associated with a + quadratic, a (single-channel) convection term, and a gradient norm term + (in this order) + - `num_spatial_dims`: number of spatial dimensions `d` + - `num_points`: number of points `N` used to discretize the domain per + dimension + - `maximum_absolute`: maximum absolute value of the input state the + resulting stepper is applied to + + **Returns:** + + - `normalized_nonlinear_scales`: normalized nonlinear scales associated with + a quadratic, a (single-channel) convection term, and a gradient norm term + (in this order) + """ normalized_nonlinear_scales = ( nonlinear_difficulties[0], # Polynomial: normalized == difficulty extract_normalized_convection_scale_from_difficulty( diff --git a/exponax/stepper/generic/_vorticity_convection.py b/exponax/stepper/generic/_vorticity_convection.py index 8f64e81..f75b2c4 100644 --- a/exponax/stepper/generic/_vorticity_convection.py +++ b/exponax/stepper/generic/_vorticity_convection.py @@ -28,6 +28,64 @@ def __init__( num_circle_points: int = 16, circle_radius: float = 1.0, ): + """ + Timestepper for 2D PDEs consisting of vorticity convection term and an + arbitrary combination of (isotropic) linear derivatives. + + ``` + uₜ + b ([1, -1]ᵀ ⊙ ∇(Δ⁻¹u)) ⋅ ∇u = sum_j a_j (1⋅∇ʲ)u + ``` + + where `b` is the vorticity convection scale, `a_j` are the coefficients + of the linear derivatives, and `∇ʲ` is the `j`-th derivative operator. + + In the default configuration, this corresponds to the 2D Navier-Stokes + simulation with a viscosity of `ν = 0.001` (the resulting Reynols number + depends on the `domain_extent`. In the case of a unit square domain, + i.e., `domain_extent = 1`, the Reynols number is `Re = 1/ν = 1000`). + Depending on the initial state, this corresponds to a decaying 2D + turbulence. + + Additionally, one can set an `injection_mode` and `injection_scale` to + inject energy into the system. For example, this allows for the + simulation of forced turbulence (=Kolmogorov flow). + + **Arguments:** + + - `num_spatial_dims`: number of spatial dimensions `D`. + - `domain_extent`: The size of the domain `L`; in higher dimensions + the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. Hence, the total number of degrees of + freedom is `Nᵈ`. + - `dt`: The timestep size `Δt` between two consecutive states. + - `coefficients`: The list of coefficients `a_j` + corresponding to the derivatives. The length of this tuple + represents the highest occuring derivative. The default value `(0.0, + 0.0, 0.001)` corresponds to pure regular diffusion. + - `vorticity_convection_scale`: The scale `b` of the vorticity + convection term. + - `injection_mode`: The mode of the injection. + - `injection_scale`: The scale of the injection. Defaults to `0.0` which + means no injection. Hence, the flow will decay over time. + - `dealiasing_fraction`: The fraction of the modes that are kept after + dealiasing. The default value `2/3` corresponds to the 2/3 rule. + - `order`: The order of the ETDRK method to use. Must be one of {0, 1, + 2, 3, 4}. The option `0` only solves the linear part of the + equation. Hence, only use this for linear PDEs. For nonlinear PDEs, + a higher order method tends to be more stable and accurate. `2` is + often a good compromis in single-precision. Use `4` together with + double precision (`jax.config.update("jax_enable_x64", True)`) for + highest accuracy. + - `num_circle_points`: How many points to use in the complex contour + integral method to compute the coefficients of the exponential time + differencing Runge Kutta method. + - `circle_radius`: The radius of the contour used to compute the + coefficients of the exponential time differencing Runge Kutta + method. + """ if num_spatial_dims != 2: raise ValueError(f"Expected num_spatial_dims = 2, got {num_spatial_dims}.") self.vorticity_convection_scale = vorticity_convection_scale