Skip to content

Commit

Permalink
Now reporting the correct source code line numbers when using the imp…
Browse files Browse the repository at this point in the history
…ort hook
  • Loading branch information
patrick-kidger committed Jun 13, 2024
1 parent 342a567 commit 59aeeec
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions jaxtyping/_import_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,17 @@ def _optimized_cache_from_source(typechecker_hash, /, path, debug_override=None)
# for importlib and decorator lookup.
# Version 8: Now using new-style `jaxtyped(typechecker=...)` rather than old-style
# double-decorators.
# Version 9: Now reporting the correct source code lines. (Important when used with
# a debugger.)
return cache_from_source(
path, debug_override, optimization=f"jaxtyping8{typechecker_hash}"
path, debug_override, optimization=f"jaxtyping9{typechecker_hash}"
)


class Typechecker:
lookup = {}

def __init__(self, typechecker):
self.ast = None

if isinstance(typechecker, str):
# If the typechecker is a string, then we parse it
string_to_eval = (
Expand Down Expand Up @@ -121,18 +121,17 @@ def get_hash(self):
return self.hash

def get_ast(self):
# we compile AST only if we missed importlib cache
if self.ast is None:
self.ast = (
ast.parse(
f"@jaxtyping.jaxtyped(typechecker=jaxtyping._import_hook.Typechecker.lookup['{self.hash}'])\n"
"def _():\n ..."
)
.body[0]
.decorator_list[0]
# Note that we compile AST only if we missed importlib cache.
# No caching on this function! We modify the return type every time, with
# its appropriate source code location.
return (
ast.parse(
f"@jaxtyping.jaxtyped(typechecker=jaxtyping._import_hook.Typechecker.lookup['{self.hash}'])\n"
"def _():\n ..."
)

return self.ast
.body[0]
.decorator_list[0]
)


class JaxtypingTransformer(ast.NodeVisitor):
Expand All @@ -159,7 +158,9 @@ def visit_Module(self, node: ast.Module):
def visit_ClassDef(self, node: ast.ClassDef):
# Place at the start of the decorator list, so that `@dataclass` decorators get
# called first.
node.decorator_list.insert(0, self._typechecker.get_ast())
decorator = self._typechecker.get_ast()
ast.copy_location(decorator, node)
node.decorator_list.insert(0, decorator)
self._parents.append(node)
self.generic_visit(node)
self._parents.pop()
Expand All @@ -173,6 +174,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
# had type annotations in the body of the function (or
# `assert isinstance(..., SomeType)`).

decorator = self._typechecker.get_ast()
ast.copy_location(decorator, node)
# Place at the end of the decorator list, because:
# - as otherwise we wrap e.g. `jax.custom_{jvp,vjp}` and lose the ability
# to `defjvp` etc.
Expand All @@ -187,7 +190,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
# case we're just going to have to need to ask the user to remove their
# typechecking annotation (and let this decorator do it instead).
# It's more important we be compatible with normal JAX code.
node.decorator_list.append(self._typechecker.get_ast())
node.decorator_list.append(decorator)

self._parents.append(node)
self.generic_visit(node)
Expand Down

0 comments on commit 59aeeec

Please sign in to comment.