From 7e05c6e53864e077659a9dc59263aaa10ed3c1c3 Mon Sep 17 00:00:00 2001 From: Alex Fan Date: Thu, 13 Jun 2024 20:04:38 +0800 Subject: [PATCH] Add support for numpy struct dtype --- jaxtyping/__init__.py | 1 + jaxtyping/_array_types.py | 38 ++++++++++++++++++++++++++++++++++++++ test/test_array.py | 18 ++++++++++++++++++ 3 files changed, 57 insertions(+) diff --git a/jaxtyping/__init__.py b/jaxtyping/__init__.py index a4694eb..f50dd92 100644 --- a/jaxtyping/__init__.py +++ b/jaxtyping/__init__.py @@ -28,6 +28,7 @@ AbstractArray as AbstractArray, AbstractDtype as AbstractDtype, get_array_name_format as get_array_name_format, + make_numpy_struct_dtype as make_numpy_struct_dtype, set_array_name_format as set_array_name_format, ) from ._config import config as config diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py index fb539ae..91838e8 100644 --- a/jaxtyping/_array_types.py +++ b/jaxtyping/_array_types.py @@ -157,6 +157,10 @@ def _check_dims( return "" +def _dtype_is_numpy_struct_array(dtype): + return dtype.type.__name__ == "void" and dtype is not np.dtype(np.void) + + class _MetaAbstractArray(type): _skip_instancecheck: bool = False @@ -177,6 +181,9 @@ def __instancecheck_str__(cls, obj: Any) -> str: if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"): # JAX, numpy dtype = obj.dtype.type.__name__ + # numpy structured array is strictly a subtype of np.void + if _dtype_is_numpy_struct_array(obj.dtype): + dtype = str(obj.dtype) elif hasattr(obj.dtype, "as_numpy_dtype"): # TensorFlow dtype = obj.dtype.as_numpy_dtype.__name__ @@ -755,3 +762,34 @@ class _Cls(AbstractDtype): Shaped = _make_dtype(_any_dtype, "Shaped") Key = _make_dtype(_prng_key, "Key") + + +def make_numpy_struct_dtype(dtype: np.dtype, name: str): + """Creates a type annotation for [numpy structured array](https://numpy.org/doc/stable/user/basics.rec.html#structured-arrays) + It does exact match on the name, order, and dtype of all its fields. + + !!! Example + + ```python + label_t = np.dtype([('first', np.uint8), ('second', np.int8)]) + Label = make_numpy_struct_dtype(label_t, 'Label') + ``` + after that, you can use it just like any AbstractDtype + ```python + a: Label[np.ndarray, 'a b'] = np.array([[(1, 0), (0, 1)]], dtype=label_t) + ``` + + **Arguments:** + + - `dtype`: The numpy dtype that the returned annotation matches + + - `name`: The python class name for the returned dtype annotation + + **Returns:** + + A type annotation with classname `name` and matching exactly `dtype`. + It can be used like any usual subclasses of AbstractDtypes. + """ + if not (isinstance(dtype, np.dtype) and _dtype_is_numpy_struct_array(dtype)): + raise ValueError(f"Expecting a numpy structured array dtype, not {dtype}") + return _make_dtype(str(dtype), name) diff --git a/test/test_array.py b/test/test_array.py index 33c4ab7..4f24cdb 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -92,6 +92,24 @@ def test_dtypes(): assert key == val.__name__ +def test_numpy_struct_dtype(): + from jaxtyping import make_numpy_struct_dtype + + dtype1 = np.dtype([("first", np.uint8), ("second", bool)]) + Dtype1 = make_numpy_struct_dtype(dtype1, "Dtype1") + arr = np.array([0, False], dtype=dtype1) + + assert isinstance(arr, Dtype1[np.ndarray, "_"]) + + dtype2 = np.dtype([("third", np.uint8), ("second", bool)]) + Dtype2 = make_numpy_struct_dtype(dtype2, "Dtype2") + assert not isinstance(arr, Dtype2[np.ndarray, "_"]) + + dtype3 = np.dtype([("second", bool), ("first", np.uint8)]) + Dtype3 = make_numpy_struct_dtype(dtype3, "Dtype3") + assert not isinstance(arr, Dtype3[np.ndarray, "_"]) + + def test_return(jaxtyp, typecheck, getkey): @jaxtyp(typecheck) def g(x: Float[Array, "b c"]) -> Float[Array, "c b"]: