diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index 8f8bba2..6f96637 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -393,16 +393,18 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore param_signature = full_signature.replace( return_annotation=inspect.Signature.empty ) + name = getattr(fn, "__name__", "") + qualname = getattr(fn, "__qualname__", "") module = getattr(fn, "__module__", "generated") + # Use the same name so that typeguard warnings look correct. full_fn, output_name = _make_fn_with_signature( - "check_return", full_signature, module, output=True + name, qualname, module, full_signature, output=True ) - full_fn = typechecker(full_fn) - param_fn = _make_fn_with_signature( - "check_params", param_signature, module, output=False + name, qualname, module, param_signature, output=False ) + full_fn = typechecker(full_fn) param_fn = typechecker(param_fn) @ft.wraps(fn) @@ -565,17 +567,19 @@ def _check_dataclass_annotations(self, typechecker): values[field.name] = value signature = inspect.Signature(parameters) - module = self.__class__.__module__ f = _make_fn_with_signature( - self.__class__.__name__, signature, module, output=False + self.__class__.__name__, + self.__class__.__qualname__, + self.__class__.__module__, + signature, + output=False, ) - f.__qualname__ = self.__class__.__qualname__ f = jaxtyped(f, typechecker=typechecker) f(self, **values) def _make_fn_with_signature( - name: str, signature: inspect.Signature, module: str, output: bool + name: str, qualname: str, module: str, signature: inspect.Signature, output: bool ): """Dynamically creates a function `fn` with name `name` and signature `signature`. @@ -697,6 +701,7 @@ def _make_fn_with_signature( exec(fnstr, scope) fn = scope[name] fn.__module__ = module + fn.__qualname__ = qualname assert fn is not None if output: return fn, output_name @@ -745,7 +750,7 @@ def _get_problem_arg( assert keep_annotation is not sentinel new_signature = inspect.Signature(new_parameters) fn = _make_fn_with_signature( - "check_single_arg", new_signature, module, output=False + "check_single_arg", "check_single_arg", module, new_signature, output=False ) fn = typechecker(fn) # but no `jaxtyped`; keep the same environment. try: diff --git a/test/test_array.py b/test/test_array.py index 4f24cdb..1185a51 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -154,7 +154,7 @@ def g(x: Shaped[Array, "a b"]) -> Shaped[Array, "a b"]: g(jnp.array([[1, 2], [3, 4]], dtype=jnp.int8)) g(jnp.array([[1, 2], [3, 4]], dtype=jnp.uint4)) g(jnp.array([[1, 2], [3, 4]], dtype=jnp.uint16)) - g(jr.normal(getkey(), (3, 4), dtype=jnp.complex128)) + g(jr.normal(getkey(), (3, 4), dtype=jnp.complex64)) g(jr.normal(getkey(), (3, 4), dtype=jnp.bfloat16)) with pytest.raises(ParamError): diff --git a/test/test_decorator.py b/test/test_decorator.py index 0b34a8f..93aab75 100644 --- a/test/test_decorator.py +++ b/test/test_decorator.py @@ -96,7 +96,7 @@ def test_context(getkey): def test_varargs(jaxtyp, typecheck): @jaxtyp(typecheck) - def f(*args): + def f(*args) -> None: pass f(1, 2) @@ -104,7 +104,7 @@ def f(*args): def test_varkwargs(jaxtyp, typecheck): @jaxtyp(typecheck) - def f(**kwargs): + def f(**kwargs) -> None: pass f(a=1, b=2) @@ -112,7 +112,7 @@ def f(**kwargs): def test_defaults(jaxtyp, typecheck): @jaxtyp(typecheck) - def f(x: int, y=1): + def f(x: int, y=1) -> None: pass f(1) diff --git a/test/test_generators.py b/test/test_generators.py index 1d04862..29a2dd2 100644 --- a/test/test_generators.py +++ b/test/test_generators.py @@ -20,7 +20,7 @@ def gen(x: Float[Array, "*"]) -> Iterator[Float[Array, "*"]]: yield x @jaxtyp(typecheck) - def foo(): + def foo() -> None: next(gen(jnp.zeros(2))) next(gen(jnp.zeros((3, 4)))) @@ -81,7 +81,7 @@ def g(x: Shaped[torch.Tensor, "*"]) -> Iterator[Shaped[torch.Tensor, "*"]]: yield x @jaxtyp(typecheck) - def f(): + def f() -> None: next(g(torch.zeros(1))) next(g(torch.zeros(2)))