Skip to content

Commit

Permalink
Merge pull request #24113 from rajasekharporeddy:testbranch1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 682344324
  • Loading branch information
Google-ML-Automation committed Oct 4, 2024
2 parents 8f423e0 + 321b9bc commit 46b7bfa
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3122,9 +3122,48 @@ def isnan(x: ArrayLike, /) -> Array:
return lax.ne(x, x)


@implements(np.heaviside, module='numpy')
@jit
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""Compute the heaviside step function.
JAX implementation of :obj:`numpy.heaviside`.
The heaviside step function is defined by:
.. math::
\mathrm{heaviside}(x1, x2) = \begin{cases}
0., & x < 0\\
x2, & x = 0\\
1., & x > 0.
\end{cases}
Args:
x1: input array or scalar. ``complex`` dtype are not supported.
x2: scalar or array. Specifies the return values when ``x1`` is ``0``. ``complex``
dtype are not supported. ``x1`` and ``x2`` must either have same shape or
broadcast compatible.
Returns:
An array containing the heaviside step function of ``x1``, promoting to
inexact dtype.
Examples:
>>> x1 = jnp.array([[-2, 0, 3],
... [5, -1, 0],
... [0, 7, -3]])
>>> x2 = jnp.array([2, 0.5, 1])
>>> jnp.heaviside(x1, x2)
Array([[0. , 0.5, 1. ],
[1. , 0. , 1. ],
[2. , 1. , 0. ]], dtype=float32)
>>> jnp.heaviside(x1, 0.5)
Array([[0. , 0.5, 1. ],
[1. , 0. , 0.5],
[0.5, 1. , 0. ]], dtype=float32)
>>> jnp.heaviside(-3, x2)
Array([0., 0., 0.], dtype=float32)
"""
check_arraylike("heaviside", x1, x2)
x1, x2 = promote_dtypes_inexact(x1, x2)
zero = _lax_const(x1, 0)
Expand Down

0 comments on commit 46b7bfa

Please sign in to comment.