diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index db779d8..9fa014c 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -342,6 +342,7 @@ def modify_annotation(ann): @ft.wraps(fn) def wrapped_fn(*args, **kwargs): # pyright: ignore bound = signature.bind(*args, **kwargs) + bound.apply_defaults() memos = push_shape_memo(bound.arguments) try: return fn(*args, **kwargs) @@ -410,6 +411,7 @@ def wrapped_fn(*args, **kwargs): # Raise bind-time errors before we do any shape analysis. (I.e. skip # the pointless jaxtyping information for a non-typechecking failure.) bound = param_signature.bind(*args, **kwargs) + bound.apply_defaults() memos = push_shape_memo(bound.arguments) try: diff --git a/test/test_decorator.py b/test/test_decorator.py index c6d3c1f..02fc5d2 100644 --- a/test/test_decorator.py +++ b/test/test_decorator.py @@ -117,6 +117,17 @@ def f(x: int, y=1): f(1) +def test_default_bindings(getkey, jaxtyp, typecheck): + @jaxtyp(typecheck) + def f(x: int, y: int = 1) -> Float[Array, "x {y}"]: + return jr.normal(getkey(), (x, y)) + + f(1) + f(1, 1) + f(1, 0) + f(1, 5) + + class _GlobalFoo: pass