diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 2e222bf6c612..14eb6edebc20 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -729,14 +729,54 @@ def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array: return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis]) -@util.implements(np.fliplr, lax_description=_ARRAY_VIEW_DOC) def fliplr(m: ArrayLike) -> Array: + """Reverse the order of elements of an array along axis 1. + + JAX implementation of :func:`numpy.fliplr`. + + Args: + m: Array with atleast two dimenssions. + + Returns: + An array with the elements in reverse order along axis 1. + + See Also: + - :func:`jax.numpy.flip`: reverse the order along the given axis + - :func:`jax.numpy.flipud`: reverse the order along axis 0 + + Example: + >>> x = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.fliplr(x1) + Array([[2, 1], + [4, 3]], dtype=int32) + """ util.check_arraylike("fliplr", m) return _flip(asarray(m), 1) -@util.implements(np.flipud, lax_description=_ARRAY_VIEW_DOC) def flipud(m: ArrayLike) -> Array: + """Reverse the order of elements of an array along axis 0. + + JAX implementation of :func:`numpy.flipud`. + + Args: + m: Array with atleast one dimension. + + Returns: + An array with the elements in reverse order along axis 0. + + See Also: + - :func:`jax.numpy.flip`: reverse the order along the given axis + - :func:`jax.numpy.fliplr`: reverse the order along axis 1 + + Example: + >>> x = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.flipud(x) + Array([[3, 4], + [1, 2]], dtype=int32) + """ util.check_arraylike("flipud", m) return _flip(asarray(m), 0)