Skip to content

Commit

Permalink
Consistent docs (#26)
Browse files Browse the repository at this point in the history
* Add documentation to base

* Add documentation to etdrk core

* Adapt to Equinox style

* Change to Equinox style

* Change to Equinox style

* Change to Equinox style

* Fix typo

* Consistent documentation

* Improve nonlin fun base documentation

* Change to Equinox style

* Improve docs for convection

* Improve docs for gradient norm

* Improve docs for combined general nonlinearity

* Fix docstring

* Improve and fix documentation

* Fix broken link

* Adapt to Equinox style

* Improve documentation

* Add clarification

* better wording

* Fix convection difficulty computation

* Improve doc of generic gradient norm stepper

* Add docstring

* Add documentation

* Add documentation

* Add documentation
  • Loading branch information
Ceyron authored Sep 3, 2024
1 parent 9ea7788 commit bea4079
Show file tree
Hide file tree
Showing 40 changed files with 1,999 additions and 429 deletions.
52 changes: 38 additions & 14 deletions exponax/_base_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,14 @@ def _build_linear_operator(
Assemble the L operator in Fourier space.
**Arguments:**
- `derivative_operator`: The derivative operator, shape `( D, ...,
N//2+1 )`. The ellipsis are (D-1) axis of size N (**not** of size
N//2+1).
- `derivative_operator`: The derivative operator, shape `( D, ...,
N//2+1 )`. The ellipsis are (D-1) axis of size N (**not** of size
N//2+1).
**Returns:**
- `L`: The linear operator, shape `( C, ..., N//2+1 )`.
- `L`: The linear operator, shape `( C, ..., N//2+1 )`.
"""
pass

Expand All @@ -183,12 +185,15 @@ def _build_nonlinear_fun(
transforms to Fourier space, and evaluates derivatives there.
**Arguments:**
- `derivative_operator`: The derivative operator, shape `( D, ..., N//2+1 )`.
- `derivative_operator`: The derivative operator, shape `( D, ...,
N//2+1 )`.
**Returns:**
- `nonlinear_fun`: A function that evaluates the nonlinearities in
time space, transforms to Fourier space, and evaluates the
derivatives there. Should be a subclass of `BaseNonlinearFun`.
- `nonlinear_fun`: A function that evaluates the nonlinearities in
time space, transforms to Fourier space, and evaluates the
derivatives there. Should be a subclass of `BaseNonlinearFun`.
"""
pass

Expand All @@ -197,10 +202,12 @@ def step(self, u: Float[Array, "C ... N"]) -> Float[Array, "C ... N"]:
Perform one step of the time integration.
**Arguments:**
- `u`: The state vector, shape `(C, ..., N,)`.
- `u`: The state vector, shape `(C, ..., N,)`.
**Returns:**
- `u_next`: The state vector after one step, shape `(C, ..., N,)`.
- `u_next`: The state vector after one step, shape `(C, ..., N,)`.
"""
u_hat = fft(u, num_spatial_dims=self.num_spatial_dims)
u_next_hat = self.step_fourier(u_hat)
Expand All @@ -220,11 +227,13 @@ def step_fourier(
transforms.
**Arguments:**
- `u_hat`: The (real) Fourier transform of the state vector
- `u_hat`: The (real) Fourier transform of the state vector
**Returns:**
- `u_next_hat`: The (real) Fourier transform of the state vector
after one step
- `u_next_hat`: The (real) Fourier transform of the state vector
after one step
"""
return self._integrator.step_fourier(u_hat)

Expand All @@ -233,7 +242,22 @@ def __call__(
u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
"""
Performs a check
Perform one step of the time integration for a single state.
**Arguments:**
- `u`: The state vector, shape `(C, ..., N,)`.
**Returns:**
- `u_next`: The state vector after one step, shape `(C, ..., N,)`.
!!! tip
Use this call method together with `exponax.rollout` to efficiently
produce temporal trajectories.
!!! info
For batched operation, use `jax.vmap` on this function.
"""
expected_shape = (self.num_channels,) + spatial_shape(
self.num_spatial_dims, self.num_points
Expand Down
32 changes: 19 additions & 13 deletions exponax/_forced_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(
transient integrators to forced problems.
**Arguments**:
- `stepper`: The stepper to be transformed.
- `stepper`: The stepper to be transformed.
"""
self.stepper = stepper

Expand All @@ -49,11 +50,13 @@ def step(
The forcing term `f` is assumed to be evaluated on the same grid as `u`.
**Arguments**:
- `u`: The current state.
- `f`: The forcing term.
- `u`: The current state.
- `f`: The forcing term.
**Returns**:
- `u_next`: The state after one time step.
- `u_next`: The state after one time step.
"""
u_with_force = u + self.stepper.dt * f
return self.stepper.step(u_with_force)
Expand All @@ -71,11 +74,13 @@ def step_fourier(
`u_hat`.
**Arguments**:
- `u_hat`: The current state in Fourier space.
- `f_hat`: The forcing term in Fourier space.
- `u_hat`: The current state in Fourier space.
- `f_hat`: The forcing term in Fourier space.
**Returns**:
- `u_next_hat`: The state after one time step in Fourier space.
- `u_next_hat`: The state after one time step in Fourier space.
"""
u_hat_with_force = u_hat + self.stepper.dt * f_hat
return self.stepper.step_fourier(u_hat_with_force)
Expand All @@ -91,12 +96,13 @@ def __call__(
The forcing term `f` is assumed to be evaluated on the same grid as `u`.
**Arguments**:
- `u`: The current state.
- `f`: The forcing term.
**Arguments:**
**Returns**:
- `u_next`: The state after one time step.
"""
- `u`: The current state.
- `f`: The forcing term.
**Returns:**
- `u_next`: The state after one time step.
"""
return self.step(u, f)
36 changes: 26 additions & 10 deletions exponax/_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ def __init__(
It is included for completion.
**Arguments:**
- `num_spatial_dims`: The number of spatial dimensions.
- `domain_extent`: The extent of the domain.
- `num_points`: The number of points in each spatial dimension.
- `order`: The order of the Poisson equation. Defaults to 2. You can
also set `order=4` for the biharmonic equation.
- `num_spatial_dims`: The number of spatial dimensions.
- `domain_extent`: The extent of the domain.
- `num_points`: The number of points in each spatial dimension.
- `order`: The order of the Poisson equation. Defaults to 2. You can
also set `order=4` for the biharmonic equation.
"""
self.num_spatial_dims = num_spatial_dims
self.domain_extent = domain_extent
Expand All @@ -71,10 +72,12 @@ def step_fourier(
Solve the Poisson equation in Fourier space.
**Arguments:**
- `f_hat`: The Fourier transform of the right hand side.
- `f_hat`: The Fourier transform of the right hand side.
**Returns:**
- `u_hat`: The Fourier transform of the solution.
- `u_hat`: The Fourier transform of the solution.
"""
return -self._inv_operator * f_hat

Expand All @@ -83,13 +86,15 @@ def step(
f: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
"""
Solve the Poisson equation in real space.
Solve the Poisson equation in state space.
**Arguments:**
- `f`: The right hand side.
- `f`: The right hand side.
**Returns:**
- `u`: The solution.
- `u`: The solution.
"""
f_hat = fft(f, num_spatial_dims=self.num_spatial_dims)
u_hat = self.step_fourier(f_hat)
Expand All @@ -104,6 +109,17 @@ def __call__(
self,
f: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
"""
Solve the Poisson equation in state space.
**Arguments:**
- `f`: The right hand side.
**Returns:**
- `u`: The solution.
"""
if f.shape[1:] != spatial_shape(self.num_spatial_dims, self.num_points):
raise ValueError(
f"Shape of f[1:] is {f.shape[1:]} but should be {spatial_shape(self.num_spatial_dims, self.num_points)}"
Expand Down
32 changes: 29 additions & 3 deletions exponax/_repeated_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ def __init__(
time step of X/Y and then wrap it in a RepeatedStepper with num_sub_steps=Y.
**Arguments:**
- `stepper`: The stepper to repeat.
- `num_sub_steps`: The number of substeps to perform.
- `stepper`: The stepper to repeat.
- `num_sub_steps`: The number of substeps to perform.
"""
self.stepper = stepper
self.num_sub_steps = num_sub_steps
Expand All @@ -52,8 +53,16 @@ def step(
u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
"""
Step the PDE forward in time by self.num_sub_steps time steps given the
Step the PDE forward in time by `self.num_sub_steps` time steps given the
current state `u`.
**Arguments:**
- `u`: The current state.
**Returns:**
- `u_next`: The state after `self.num_sub_steps` time steps.
"""
return repeat(self.stepper.step, self.num_sub_steps)(u)

Expand All @@ -64,6 +73,15 @@ def step_fourier(
"""
Step the PDE forward in time by self.num_sub_steps time steps given the
current state `u_hat` in real-valued Fourier space.
**Arguments:**
- `u_hat`: The current state in Fourier space.
**Returns:**
- `u_next_hat`: The state after `self.num_sub_steps` time steps in Fourier
space.
"""
return repeat(self.stepper.step_fourier, self.num_sub_steps)(u_hat)

Expand All @@ -74,5 +92,13 @@ def __call__(
"""
Step the PDE forward in time by self.num_sub_steps time steps given the
current state `u`.
**Arguments:**
- `u`: The current state.
**Returns:**
- `u_next`: The state after `self.num_sub_steps` time steps.
"""
return repeat(self.stepper, self.num_sub_steps)(u)
Loading

0 comments on commit bea4079

Please sign in to comment.