Skip to content

Commit

Permalink
Compatibility with Python 3.10 + with generic array types
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jul 12, 2024
1 parent f7c57c3 commit c2f19db
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
36 changes: 28 additions & 8 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,16 @@
import types
import typing
from dataclasses import dataclass
from typing import Any, Literal, NoReturn, Optional, TypeVar, Union
from typing import (
Any,
get_args,
get_origin,
Literal,
NoReturn,
Optional,
TypeVar,
Union,
)


# Bit of a hack, but jaxtyping provides nicer error messages than typeguard. This means
Expand Down Expand Up @@ -358,7 +367,7 @@ class for `Float32[Array, "foo"]`.

_not_made = object()

_union_types = [typing.Union]
_union_types = [Union]
if sys.version_info >= (3, 10):
_union_types.append(types.UnionType)

Expand Down Expand Up @@ -517,6 +526,9 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
# Allow Python built-in numeric types.
# TODO: do something more generic than this? Should we _make all types
# that have `shape` and `dtype` attributes or something?
array_origin = get_origin(array_type)
if array_origin is not None:
array_type = array_origin
if array_type is bool:
if _check_scalar("bool", dtypes, dims):
return array_type
Expand Down Expand Up @@ -547,7 +559,7 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
return array_type
else:
return _not_made
if issubclass(array_type, AbstractArray):
if array_type is not Any and issubclass(array_type, AbstractArray):
if dtypes is _any_dtype:
dtypes = array_type.dtypes
elif array_type.dtypes is not _any_dtype:
Expand Down Expand Up @@ -588,11 +600,15 @@ def _make_array(*args, **kwargs):

if type(out) is tuple:
array_type, name, dtypes, dims, index_variadic, dim_str = out
metaclass = _make_metaclass(type(array_type))
metaclass = (
_make_metaclass(type)
if array_type is Any
else _make_metaclass(type(array_type))
)

out = metaclass(
name,
(array_type, AbstractArray),
(AbstractArray,) if array_type is Any else (array_type, AbstractArray),
dict(
array_type=array_type,
dtypes=dtypes,
Expand Down Expand Up @@ -629,14 +645,18 @@ def __getitem__(cls, item: tuple[Any, str]):
if isinstance(array_type, TypeVar):
bound = array_type.__bound__
if bound is None:
array_type = Any
constraints = array_type.__constraints__
if constraints == ():
array_type = Any
else:
array_type = Union[constraints]
else:
array_type = bound
del item
if typing.get_origin(array_type) in _union_types:
if get_origin(array_type) in _union_types:
out = [
_make_array(x, dim_str, cls.dtypes, cls.__name__)
for x in typing.get_args(array_type)
for x in get_args(array_type)
]
out = tuple(x for x in out if x is not _not_made)
if len(out) == 0:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "jaxtyping"
version = "0.2.32"
version = "0.2.33"
description = "Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees."
readme = "README.md"
requires-python ="~=3.9"
Expand Down

0 comments on commit c2f19db

Please sign in to comment.