Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jnp.zeros{,_like} no longer produces a Zero. #3

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions quax/zero/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import jax.core
import jax.lax as lax
import jax.numpy as jnp
import numpy as np

from .._core import ArrayValue, DenseArrayValue, quaxify_keepwrap, register
from .._core import ArrayValue, quaxify_keepwrap, register


class Zero(ArrayValue):
Expand Down Expand Up @@ -40,19 +39,8 @@ def materialise(self):


@register(lax.broadcast_in_dim_p)
def _(value: DenseArrayValue, *, broadcast_dimensions, shape) -> ArrayValue:
arraylike = value.array
aval = jax.core.get_aval(arraylike)
if isinstance(aval, jax.core.ConcreteArray) and aval.shape == () and aval.val == 0:
return Zero(shape, np.result_type(arraylike))
else:
# Avoid an infinite loop, by pushing a new interpreter to the dynamic
# interpreter stack.
with jax.ensure_compile_time_eval():
out = lax.broadcast_in_dim_p.bind(
arraylike, broadcast_dimensions=broadcast_dimensions, shape=shape
)
return DenseArrayValue(out) # pyright: ignore
def _(value: Zero, *, broadcast_dimensions, shape) -> Zero:
return Zero(shape, value.dtype)


@register(lax.broadcast_in_dim_p)
Expand Down
7 changes: 4 additions & 3 deletions tests/test_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def run(a):

out = run(1)
out2 = quax.quaxify(jnp.zeros)((3, 4))
zero = quax.zero.Zero((3, 4), jnp.float32)
assert eqx.tree_equal(out, zero)
assert eqx.tree_equal(out2, zero)
assert isinstance(out, jax.Array)
assert isinstance(out2, jax.Array)
assert not isinstance(out, quax.zero.Zero)
assert not isinstance(out2, quax.zero.Zero)
Loading