From b9f69f4f1b318817bd5f3696eb60b3e574a81b4a Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 25 Sep 2024 22:33:31 +0530 Subject: [PATCH] Update Returns --- jax/_src/numpy/lax_numpy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b7c8e1da5411..5bff629dc8a8 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6630,8 +6630,8 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int out: Not used by JAX. Returns: - A 0-D array for 2-D input, and in general a N-2 dimensional array for - N-dimensional input. + An array of dimension x.ndim-2 containing the sum of the diagonal elements + along axes (axis1, axis2) See also: - :func:`jax.numpy.diag`: Returns the specified diagonal or constructs a diagonal