Skip to content

Commit

Permalink
Better documentation for jnp.nan_to_num
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 9, 2024
1 parent 2f67710 commit f19ee75
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3400,11 +3400,53 @@ def fix(x: ArrayLike, out: None = None) -> Array:
return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x))


@util.implements(np.nan_to_num)
@jit
def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
posinf: ArrayLike | None = None,
neginf: ArrayLike | None = None) -> Array:
"""Replace NaN and infinite entries in an array.
JAX implementation of :func:`numpy.nan_to_num`.
Args:
x: array of values to be replaced. If it does not have an inexact
dtype it will be returned unmodified.
copy: unused by JAX
nan: value to substitute for NaN entries. Defaults to 0.0.
posinf: value to substitute for positive infinite entries.
Defaults to the maximum representable value.
neginf: value to substitute for positive infinite entries.
Defaults to the minimum representable value.
Returns:
A copy of ``x`` with the requested substitutions.
See also:
- :func:`jax.numpy.isnan`: return True where the array contains NaN
- :func:`jax.numpy.isposinf`: return True where the array contains +inf
- :func:`jax.numpy.isneginf`: return True where the array contains -inf
Examples:
>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
Default substitution values:
>>> jnp.nan_to_num(x)
Array([ 0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 3.4028235e+38,
2.0000000e+00, -3.4028235e+38], dtype=float32)
Overriding substitutions for ``-inf`` and ``+inf``:
>>> jnp.nan_to_num(x, posinf=999, neginf=-999)
Array([ 0., 0., 1., 999., 2., -999.], dtype=float32)
If you only wish to substitute for NaN values while leaving ``inf`` values
untouched, using :func:`~jax.numpy.where` with :func:`jax.numpy.isnan` is
a better option:
>>> jnp.where(jnp.isnan(x), 0, x)
Array([ 0., 0., 1., inf, 2., -inf], dtype=float32)
"""
del copy
util.check_arraylike("nan_to_num", x)
dtype = _dtype(x)
Expand Down

0 comments on commit f19ee75

Please sign in to comment.