Skip to content

Commit

Permalink
Improve docs for jnp.average
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 8, 2024
1 parent 023f2a7 commit 5620dfd
Showing 1 changed file with 54 additions and 5 deletions.
59 changes: 54 additions & 5 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from jax._src import dtypes
from jax._src.numpy.util import (
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
promote_dtypes_inexact, promote_dtypes_numeric, _where)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
from jax._src.util import (
Expand Down Expand Up @@ -700,9 +700,8 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
An array of the mean along the given axis.
See also:
- :func:`jax.numpy.sum`: Compute the sum of array elements over a given axis.
- :func:`jax.numpy.max`: Compute the maximum of array elements over given axis.
- :func:`jax.numpy.min`: Compute the minimum of array elements over given axis.
- :func:`jax.numpy.average`: Compute the weighted average of array elements
- :func:`jax.numpy.sum`: Compute the sum of array elements.
Examples:
By default, the mean is computed along all the axes.
Expand Down Expand Up @@ -782,9 +781,59 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, *
@overload
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ...
@implements(np.average)
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]:
"""Compute the weighed average.
JAX Implementation of :func:`numpy.average`.
Args:
a: array to be averaged
axis: optional, int or sequence of ints, default=None. Axis along which
the mean to be computed. If None, mean is computed along all the axes.
weights: an optional array of weights for a weighted average. Must be
broadcast-compatible with ``a``.
returned: bool, default=False. If true then return the normalization factor
(i.e. the sum of weights).
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
Returns:
An array ``average`` or tuple of arrays ``(average, normalization)`` if
``returned`` is True.
See also:
- :func:`jax.numpy.mean`: unweighted mean.
Examples:
Simple average:
>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)
Weighted average:
>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)
Use ``returned=True`` to optionally return the normalization, i.e. the
sum of weights:
>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))
Weighted average along a specified axis:
>>> x = jnp.array([[8, 2, 7],
... [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)
"""
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)

@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
Expand Down

0 comments on commit 5620dfd

Please sign in to comment.