Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for numpy struct dtype #211

Merged
merged 1 commit into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AbstractArray as AbstractArray,
AbstractDtype as AbstractDtype,
get_array_name_format as get_array_name_format,
make_numpy_struct_dtype as make_numpy_struct_dtype,
set_array_name_format as set_array_name_format,
)
from ._config import config as config
Expand Down
38 changes: 38 additions & 0 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ def _check_dims(
return ""


def _dtype_is_numpy_struct_array(dtype):
return dtype.type.__name__ == "void" and dtype is not np.dtype(np.void)


class _MetaAbstractArray(type):
_skip_instancecheck: bool = False

Expand All @@ -177,6 +181,9 @@ def __instancecheck_str__(cls, obj: Any) -> str:
if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
# JAX, numpy
dtype = obj.dtype.type.__name__
# numpy structured array is strictly a subtype of np.void
if _dtype_is_numpy_struct_array(obj.dtype):
dtype = str(obj.dtype)
elif hasattr(obj.dtype, "as_numpy_dtype"):
# TensorFlow
dtype = obj.dtype.as_numpy_dtype.__name__
Expand Down Expand Up @@ -755,3 +762,34 @@ class _Cls(AbstractDtype):
Shaped = _make_dtype(_any_dtype, "Shaped")

Key = _make_dtype(_prng_key, "Key")


def make_numpy_struct_dtype(dtype: np.dtype, name: str):
"""Creates a type annotation for [numpy structured array](https://numpy.org/doc/stable/user/basics.rec.html#structured-arrays)
It does exact match on the name, order, and dtype of all its fields.

!!! Example

```python
label_t = np.dtype([('first', np.uint8), ('second', np.int8)])
Label = make_numpy_struct_dtype(label_t, 'Label')
```
after that, you can use it just like any AbstractDtype
```python
a: Label[np.ndarray, 'a b'] = np.array([[(1, 0), (0, 1)]], dtype=label_t)
```

**Arguments:**

- `dtype`: The numpy dtype that the returned annotation matches

- `name`: The python class name for the returned dtype annotation

**Returns:**

A type annotation with classname `name` and matching exactly `dtype`.
It can be used like any usual subclasses of AbstractDtypes.
"""
if not (isinstance(dtype, np.dtype) and _dtype_is_numpy_struct_array(dtype)):
raise ValueError(f"Expecting a numpy structured array dtype, not {dtype}")
return _make_dtype(str(dtype), name)
18 changes: 18 additions & 0 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,24 @@ def test_dtypes():
assert key == val.__name__


def test_numpy_struct_dtype():
from jaxtyping import make_numpy_struct_dtype

dtype1 = np.dtype([("first", np.uint8), ("second", bool)])
Dtype1 = make_numpy_struct_dtype(dtype1, "Dtype1")
arr = np.array([0, False], dtype=dtype1)

assert isinstance(arr, Dtype1[np.ndarray, "_"])

dtype2 = np.dtype([("third", np.uint8), ("second", bool)])
Dtype2 = make_numpy_struct_dtype(dtype2, "Dtype2")
assert not isinstance(arr, Dtype2[np.ndarray, "_"])

dtype3 = np.dtype([("second", bool), ("first", np.uint8)])
Dtype3 = make_numpy_struct_dtype(dtype3, "Dtype3")
assert not isinstance(arr, Dtype3[np.ndarray, "_"])


def test_return(jaxtyp, typecheck, getkey):
@jaxtyp(typecheck)
def g(x: Float[Array, "b c"]) -> Float[Array, "c b"]:
Expand Down
Loading