Skip to content

Commit

Permalink
Improve documentation for jnp.fliplr and jnp.flipud
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasekharporeddy committed Jun 7, 2024
1 parent 55d0f5e commit d3b9461
Showing 1 changed file with 42 additions and 2 deletions.
44 changes: 42 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,14 +729,54 @@ def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis])


@util.implements(np.fliplr, lax_description=_ARRAY_VIEW_DOC)
def fliplr(m: ArrayLike) -> Array:
"""Reverse the order of elements of an array along axis 1.
JAX implementation of :func:`numpy.fliplr`.
Args:
m: Array with atleast two dimenssions.
Returns:
An array with the elements in reverse order along axis 1.
See Also:
- :func:`jax.numpy.flip`: reverse the order along the given axis
- :func:`jax.numpy.flipud`: reverse the order along axis 0
Example:
>>> x = jnp.array([[1, 2],
... [3, 4]])
>>> jnp.fliplr(x1)
Array([[2, 1],
[4, 3]], dtype=int32)
"""
util.check_arraylike("fliplr", m)
return _flip(asarray(m), 1)


@util.implements(np.flipud, lax_description=_ARRAY_VIEW_DOC)
def flipud(m: ArrayLike) -> Array:
"""Reverse the order of elements of an array along axis 0.
JAX implementation of :func:`numpy.flipud`.
Args:
m: Array with atleast one dimension.
Returns:
An array with the elements in reverse order along axis 0.
See Also:
- :func:`jax.numpy.flip`: reverse the order along the given axis
- :func:`jax.numpy.fliplr`: reverse the order along axis 1
Example:
>>> x = jnp.array([[1, 2],
... [3, 4]])
>>> jnp.flipud(x)
Array([[3, 4],
[1, 2]], dtype=int32)
"""
util.check_arraylike("flipud", m)
return _flip(asarray(m), 0)

Expand Down

0 comments on commit d3b9461

Please sign in to comment.