Skip to content

Commit

Permalink
typeguard warnings, on not having any annotations, now name the corre…
Browse files Browse the repository at this point in the history
…ct function.
  • Loading branch information
patrick-kidger committed Jun 25, 2024
1 parent 40f494b commit 4716fc1
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 15 deletions.
23 changes: 14 additions & 9 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,16 +393,18 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore
param_signature = full_signature.replace(
return_annotation=inspect.Signature.empty
)
name = getattr(fn, "__name__", "<no name found>")
qualname = getattr(fn, "__qualname__", "<no qualname found>")
module = getattr(fn, "__module__", "generated")

# Use the same name so that typeguard warnings look correct.
full_fn, output_name = _make_fn_with_signature(
"check_return", full_signature, module, output=True
name, qualname, module, full_signature, output=True
)
full_fn = typechecker(full_fn)

param_fn = _make_fn_with_signature(
"check_params", param_signature, module, output=False
name, qualname, module, param_signature, output=False
)
full_fn = typechecker(full_fn)
param_fn = typechecker(param_fn)

@ft.wraps(fn)
Expand Down Expand Up @@ -565,17 +567,19 @@ def _check_dataclass_annotations(self, typechecker):
values[field.name] = value

signature = inspect.Signature(parameters)
module = self.__class__.__module__
f = _make_fn_with_signature(
self.__class__.__name__, signature, module, output=False
self.__class__.__name__,
self.__class__.__qualname__,
self.__class__.__module__,
signature,
output=False,
)
f.__qualname__ = self.__class__.__qualname__
f = jaxtyped(f, typechecker=typechecker)
f(self, **values)


def _make_fn_with_signature(
name: str, signature: inspect.Signature, module: str, output: bool
name: str, qualname: str, module: str, signature: inspect.Signature, output: bool
):
"""Dynamically creates a function `fn` with name `name` and signature `signature`.
Expand Down Expand Up @@ -697,6 +701,7 @@ def _make_fn_with_signature(
exec(fnstr, scope)
fn = scope[name]
fn.__module__ = module
fn.__qualname__ = qualname
assert fn is not None
if output:
return fn, output_name
Expand Down Expand Up @@ -745,7 +750,7 @@ def _get_problem_arg(
assert keep_annotation is not sentinel
new_signature = inspect.Signature(new_parameters)
fn = _make_fn_with_signature(
"check_single_arg", new_signature, module, output=False
"check_single_arg", "check_single_arg", module, new_signature, output=False
)
fn = typechecker(fn) # but no `jaxtyped`; keep the same environment.
try:
Expand Down
2 changes: 1 addition & 1 deletion test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def g(x: Shaped[Array, "a b"]) -> Shaped[Array, "a b"]:
g(jnp.array([[1, 2], [3, 4]], dtype=jnp.int8))
g(jnp.array([[1, 2], [3, 4]], dtype=jnp.uint4))
g(jnp.array([[1, 2], [3, 4]], dtype=jnp.uint16))
g(jr.normal(getkey(), (3, 4), dtype=jnp.complex128))
g(jr.normal(getkey(), (3, 4), dtype=jnp.complex64))
g(jr.normal(getkey(), (3, 4), dtype=jnp.bfloat16))

with pytest.raises(ParamError):
Expand Down
6 changes: 3 additions & 3 deletions test/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,23 @@ def test_context(getkey):

def test_varargs(jaxtyp, typecheck):
@jaxtyp(typecheck)
def f(*args):
def f(*args) -> None:
pass

f(1, 2)


def test_varkwargs(jaxtyp, typecheck):
@jaxtyp(typecheck)
def f(**kwargs):
def f(**kwargs) -> None:
pass

f(a=1, b=2)


def test_defaults(jaxtyp, typecheck):
@jaxtyp(typecheck)
def f(x: int, y=1):
def f(x: int, y=1) -> None:
pass

f(1)
Expand Down
4 changes: 2 additions & 2 deletions test/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def gen(x: Float[Array, "*"]) -> Iterator[Float[Array, "*"]]:
yield x

@jaxtyp(typecheck)
def foo():
def foo() -> None:
next(gen(jnp.zeros(2)))
next(gen(jnp.zeros((3, 4))))

Expand Down Expand Up @@ -81,7 +81,7 @@ def g(x: Shaped[torch.Tensor, "*"]) -> Iterator[Shaped[torch.Tensor, "*"]]:
yield x

@jaxtyp(typecheck)
def f():
def f() -> None:
next(g(torch.zeros(1)))
next(g(torch.zeros(2)))

Expand Down

0 comments on commit 4716fc1

Please sign in to comment.