From 99928904380eaabbcfa999608fc5e79b9794065f Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 21 Dec 2023 08:33:14 -0800 Subject: [PATCH] jnp.zeros{,_like} no longer produces a Zero. This is from some discussion in #2. The existing approach is a little bit magic. This also means that we can probably remove dynamic tracing. --- quax/zero/_core.py | 18 +++--------------- tests/test_zero.py | 7 ++++--- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/quax/zero/_core.py b/quax/zero/_core.py index c16e8ab..a94332c 100644 --- a/quax/zero/_core.py +++ b/quax/zero/_core.py @@ -5,9 +5,8 @@ import jax.core import jax.lax as lax import jax.numpy as jnp -import numpy as np -from .._core import ArrayValue, DenseArrayValue, quaxify_keepwrap, register +from .._core import ArrayValue, quaxify_keepwrap, register class Zero(ArrayValue): @@ -40,19 +39,8 @@ def materialise(self): @register(lax.broadcast_in_dim_p) -def _(value: DenseArrayValue, *, broadcast_dimensions, shape) -> ArrayValue: - arraylike = value.array - aval = jax.core.get_aval(arraylike) - if isinstance(aval, jax.core.ConcreteArray) and aval.shape == () and aval.val == 0: - return Zero(shape, np.result_type(arraylike)) - else: - # Avoid an infinite loop, by pushing a new interpreter to the dynamic - # interpreter stack. - with jax.ensure_compile_time_eval(): - out = lax.broadcast_in_dim_p.bind( - arraylike, broadcast_dimensions=broadcast_dimensions, shape=shape - ) - return DenseArrayValue(out) # pyright: ignore +def _(value: Zero, *, broadcast_dimensions, shape) -> Zero: + return Zero(shape, value.dtype) @register(lax.broadcast_in_dim_p) diff --git a/tests/test_zero.py b/tests/test_zero.py index de2093c..b732e4e 100644 --- a/tests/test_zero.py +++ b/tests/test_zero.py @@ -126,6 +126,7 @@ def run(a): out = run(1) out2 = quax.quaxify(jnp.zeros)((3, 4)) - zero = quax.zero.Zero((3, 4), jnp.float32) - assert eqx.tree_equal(out, zero) - assert eqx.tree_equal(out2, zero) + assert isinstance(out, jax.Array) + assert isinstance(out2, jax.Array) + assert not isinstance(out, quax.zero.Zero) + assert not isinstance(out2, quax.zero.Zero)