Improve docs for jnp.partition & argpartition
jakevdp committed May 29, 2024
1 parent 26b4848 commit 1c5319d
Showing 1 changed file with 99 additions and 18 deletions.
117 changes: 99 additions & 18 deletions jax/_src/numpy/lax_numpy.py
@@ -5517,17 +5517,57 @@ def argsort(
return lax.rev(indices, dimensions=[dimension]) if descending else indices


@util.implements(np.partition, lax_description="""
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.
The JAX version differs from the NumPy version in the treatment of NaN entries;
NaNs which have the negative bit set are sorted to the beginning of the array.
""")
@partial(jit, static_argnames=['kth', 'axis'])
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
"""Returns a partially-sorted copy of an array.
JAX implementation of :func:`numpy.partition`. The JAX version differs from
NumPy in the treatment of NaN entries: NaNs which have the negative bit set
are sorted to the beginning of the array.
Args:
a: array to be partitioned.
kth: static integer index about which to partition the array.
axis: static integer axis along which to partition the array; default is -1.
Returns:
A copy of ``a`` partitioned at the ``kth`` value along ``axis``. The entries
before ``kth`` are values no larger than ``take(a, kth, axis)``, and entries
after ``kth`` are values no smaller than ``take(a, kth, axis)``.
Note:
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.
See Also:
- :func:`jax.numpy.sort`: full sort
- :func:`jax.numpy.argpartition`: indirect partial sort
- :func:`jax.lax.top_k`: directly find the top k entries
- :func:`jax.lax.approx_max_k`: compute the approximate top k entries
- :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries
Examples:
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
The result is a partially-sorted copy of the input. All values before ``kth``
are no larger than the pivot value, and all values after ``kth`` are no smaller
than the pivot value:
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]
Notice that among ``smallest_values`` and ``largest_values``, the returned
order is arbitrary and implementation-dependent.
"""
# TODO(jakevdp): handle NaN values like numpy.
util.check_arraylike("partition", a)
arr = asarray(a)
@@ -5543,17 +5583,58 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
return swapaxes(out, -1, axis)
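
A minimal sketch of two points from the ``partition`` docstring above: the invariant around the pivot, and the suggestion to call :func:`jax.lax.top_k` directly when only the extreme values are needed. The variable names (``left_ok``, ``right_ok``, ``largest``, ``smallest``) are illustrative and not part of the commit, and negating the input to obtain the smallest values is an assumption of this sketch rather than a description of the library internals.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4

# Partial sort: only the element at position kth is guaranteed to be in
# its fully-sorted position.
part = jnp.partition(x, kth)
pivot = part[kth]

# Invariant: nothing before kth exceeds the pivot, nothing after is below it.
left_ok = bool(jnp.all(part[:kth] <= pivot))
right_ok = bool(jnp.all(part[kth + 1:] >= pivot))

# If only the k largest values are needed, top_k returns them directly,
# sorted in descending order, together with their indices.
k = 3
largest, largest_idx = lax.top_k(x, k)

# The k smallest values can be obtained by negating the input first.
# (Assumption of this sketch: negation is safe for the dtype in use; the
# minimum value of a signed integer type would overflow.)
neg_smallest, smallest_idx = lax.top_k(-x, k)
smallest = -neg_smallest

print(pivot, left_ok, right_ok)
print(largest, smallest)
```

Unlike the two halves returned by ``partition``, the values returned by ``top_k`` are themselves sorted, which is often what callers want when they slice off the top or bottom ``k`` entries.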


@util.implements(np.argpartition, lax_description="""
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.
The JAX version differs from the NumPy version in the treatment of NaN entries;
NaNs which have the negative bit set are sorted to the beginning of the array.
""")
@partial(jit, static_argnames=['kth', 'axis'])
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
"""Returns indices that partially sort an array.
JAX implementation of :func:`numpy.argpartition`. The JAX version differs from
NumPy in the treatment of NaN entries: NaNs which have the negative bit set are
sorted to the beginning of the array.
Args:
a: array to be partitioned.
kth: static integer index about which to partition the array.
axis: static integer axis along which to partition the array; default is -1.
Returns:
Indices which partition ``a`` at the ``kth`` value along ``axis``. The entries
before ``kth`` are indices of values no larger than ``take(a, kth, axis)``, and
entries after ``kth`` are indices of values no smaller than ``take(a, kth, axis)``.
Note:
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.
See Also:
- :func:`jax.numpy.partition`: direct partial sort
- :func:`jax.numpy.argsort`: full indirect sort
- :func:`jax.lax.top_k`: directly find the top k entries
- :func:`jax.lax.approx_max_k`: compute the approximate top k entries
- :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries
Examples:
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> idx = jnp.argpartition(x, kth)
>>> idx
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
The result is a sequence of indices that partially sort the input. All indices
before ``kth`` are of values smaller than the pivot value, and all indices
after ``kth`` are of values larger than the pivot value:
>>> x_partitioned = x[idx]
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [6 8 9 7 5]
Notice that among ``smallest_values`` and ``largest_values``, the returned
order is arbitrary and implementation-dependent.
"""
# TODO(jakevdp): handle NaN values like numpy.
util.check_arraylike("partition", a)
arr = asarray(a)
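
The docstring example indexes a 1D array with ``x[idx]``; for arrays with more than one dimension, the indices returned by ``argpartition`` are usually applied with ``jnp.take_along_axis``. The sketch below also shows one way to satisfy the static ``kth`` requirement inside a user-defined jitted function. The helper ``kth_smallest_per_row`` is hypothetical and not part of JAX or of this commit.

```python
from functools import partial

import jax
import jax.numpy as jnp


# k is marked static because argpartition requires kth to be a concrete
# Python integer at trace time, not a traced array.
@partial(jax.jit, static_argnames=["k"])
def kth_smallest_per_row(x, k):
    """Hypothetical helper: k-th smallest value of each row of a 2D array."""
    idx = jnp.argpartition(x, k, axis=1)
    # Apply the per-row indices along the same axis they were computed on.
    part = jnp.take_along_axis(x, idx, axis=1)
    return part[:, k]


x = jnp.array([[6, 8, 4, 3, 1],
               [9, 7, 5, 2, 3]])

pivots = kth_smallest_per_row(x, k=2)
print(pivots)  # [4 5]
```

Because ``k`` is static, each distinct value of ``k`` triggers a separate compilation of the helper; this is the usual trade-off for static arguments under ``jax.jit``.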