Skip to content

Commit

Permalink
Fix fill_value handling for complex dtypes (#2200)
Browse files Browse the repository at this point in the history
* Fix fill_value handling for complex & datetime dtypes

* cleanup

* more cleanup

* more cleanup

* Fix default fill_value

* Fixes

* Add booleans

* Add v2, v3 specific dtypes

* Add version.py to gitignore

* cleanpu

* style: pre-commit fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dcherian and pre-commit-ci[bot] authored Sep 25, 2024
1 parent fafd0bf commit 692593b
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 65 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,5 @@ fixture/
.DS_Store
tests/.hypothesis
.hypothesis/

zarr/version.py
7 changes: 0 additions & 7 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,6 @@ async def _create_v3(
shape = parse_shapelike(shape)
codecs = list(codecs) if codecs is not None else [BytesCodec()]

if fill_value is None:
if dtype == np.dtype("bool"):
fill_value = False
else:
fill_value = 0

if chunk_key_encoding is None:
chunk_key_encoding = ("default", "/")
assert chunk_key_encoding is not None
Expand All @@ -281,7 +275,6 @@ async def _create_v3(
)

array = cls(metadata=metadata, store_path=store_path)

await array._save_metadata(metadata)
return array

Expand Down
7 changes: 6 additions & 1 deletion src/zarr/core/buffer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,14 @@ def __repr__(self) -> str:

def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
"""Compare to `other` using np.array_equal."""
if other is None:
# Handle None fill_value for Zarr V2
return False
# use array_equal to obtain equal_nan=True functionality
data, other = np.broadcast_arrays(self._data, other)
result = np.array_equal(self._data, other, equal_nan=equal_nan)
result = np.array_equal(
self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "US" else False
)
return result

def fill(self, value: Any) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def parse_fill_value(
if fill_value is None:
return dtype.type(0)
if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
if dtype in (np.complex64, np.complex128):
if dtype.type in (np.complex64, np.complex128):
dtype = cast(COMPLEX_DTYPE, dtype)
if len(fill_value) == 2:
# complex datatypes serialize to JSON arrays with two elements
Expand Down Expand Up @@ -391,7 +391,7 @@ def parse_fill_value(
pass
elif fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value):
pass
elif dtype.kind == "f":
elif dtype.kind in "cf":
# float comparison is not exact, especially when dtype <float64
# so we us np.isclose for this comparison.
# this also allows us to compare nan fill_values
Expand Down
112 changes: 62 additions & 50 deletions src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
from typing import Any
from typing import Any, Literal

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
Expand All @@ -19,6 +18,35 @@
max_leaves=3,
)


def v3_dtypes() -> st.SearchStrategy[np.dtype]:
return (
npst.boolean_dtypes()
| npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=")
| npst.complex_number_dtypes(endianness="=")
# | npst.byte_string_dtypes(endianness="=")
# | npst.unicode_string_dtypes()
# | npst.datetime64_dtypes()
# | npst.timedelta64_dtypes()
)


def v2_dtypes() -> st.SearchStrategy[np.dtype]:
return (
npst.boolean_dtypes()
| npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=")
| npst.complex_number_dtypes(endianness="=")
| npst.byte_string_dtypes(endianness="=")
| npst.unicode_string_dtypes(endianness="=")
| npst.datetime64_dtypes()
# | npst.timedelta64_dtypes()
)


# From https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#node-names
# 1. must not be the empty string ("")
# 2. must not include the character "/"
Expand All @@ -33,21 +61,29 @@
array_names = node_names
attrs = st.none() | st.dictionaries(_attr_keys, _attr_values)
paths = st.lists(node_names, min_size=1).map(lambda x: "/".join(x)) | st.just("/")
np_arrays = npst.arrays(
# TODO: re-enable timedeltas once they are supported
dtype=npst.scalar_dtypes().filter(
lambda x: (x.kind not in ["m", "M"]) and (x.byteorder not in [">"])
),
shape=npst.array_shapes(max_dims=4),
)
stores = st.builds(MemoryStore, st.just({}), mode=st.just("w"))
compressors = st.sampled_from([None, "default"])
format = st.sampled_from([2, 3])
zarr_formats: st.SearchStrategy[Literal[2, 3]] = st.sampled_from([2, 3])
array_shapes = npst.array_shapes(max_dims=4)


@st.composite # type: ignore[misc]
def numpy_arrays(
draw: st.DrawFn,
*,
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats,
) -> Any:
"""
Generate numpy arrays that can be saved in the provided Zarr format.
"""
zarr_format = draw(zarr_formats)
return draw(npst.arrays(dtype=v3_dtypes() if zarr_format == 3 else v2_dtypes(), shape=shapes))


@st.composite # type: ignore[misc]
def np_array_and_chunks(
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = np_arrays
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = numpy_arrays
) -> tuple[np.ndarray, tuple[int]]: # type: ignore[type-arg]
"""A hypothesis strategy to generate small sized random arrays.
Expand All @@ -66,73 +102,49 @@ def np_array_and_chunks(
def arrays(
draw: st.DrawFn,
*,
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
compressors: st.SearchStrategy = compressors,
stores: st.SearchStrategy[StoreLike] = stores,
arrays: st.SearchStrategy[np.ndarray] = np_arrays,
paths: st.SearchStrategy[None | str] = paths,
array_names: st.SearchStrategy = array_names,
arrays: st.SearchStrategy | None = None,
attrs: st.SearchStrategy = attrs,
format: st.SearchStrategy = format,
zarr_formats: st.SearchStrategy = zarr_formats,
) -> Array:
store = draw(stores)
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
path = draw(paths)
name = draw(array_names)
attributes = draw(attrs)
zarr_format = draw(format)
zarr_format = draw(zarr_formats)
if arrays is None:
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
# test that None works too.
fill_value = draw(st.one_of([st.none(), npst.from_dtype(nparray.dtype)]))
# compressor = draw(compressors)

# TODO: clean this up
# if path is None and name is None:
# array_path = None
# array_name = None
# elif path is None and name is not None:
# array_path = f"{name}"
# array_name = f"/{name}"
# elif path is not None and name is None:
# array_path = path
# array_name = None
# elif path == "/":
# assert name is not None
# array_path = name
# array_name = "/" + name
# else:
# assert name is not None
# array_path = f"{path}/{name}"
# array_name = "/" + array_path

expected_attrs = {} if attributes is None else attributes

array_path = path + ("/" if not path.endswith("/") else "") + name
root = Group.from_store(store, zarr_format=zarr_format)
fill_value_args: tuple[Any, ...] = tuple()
if nparray.dtype.kind == "M":
m = re.search(r"\[(.+)\]", nparray.dtype.str)
if not m:
raise ValueError(f"Couldn't find precision for dtype '{nparray.dtype}.")

fill_value_args = (
# e.g. ns, D
m.groups()[0],
)

a = root.create_array(
array_path,
shape=nparray.shape,
chunks=chunks,
dtype=nparray.dtype.str,
dtype=nparray.dtype,
attributes=attributes,
# compressor=compressor, # TODO: FIXME
fill_value=nparray.dtype.type(0, *fill_value_args),
# compressor=compressor, # FIXME
fill_value=fill_value,
)

assert isinstance(a, Array)
assert a.fill_value is not None
assert isinstance(root[array_path], Array)
assert nparray.shape == a.shape
assert chunks == a.chunks
assert array_path == a.path, (path, name, array_path, a.name, a.path)
# assert array_path == a.name, (path, name, array_path, a.name, a.path)
# assert a.basename is None # TODO
# assert a.store == normalize_store_arg(store)
assert a.basename == name, (a.basename, name)
assert dict(a.attrs) == expected_attrs

a[:] = nparray
Expand Down
10 changes: 5 additions & 5 deletions tests/v3/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import hypothesis.strategies as st # noqa: E402
from hypothesis import given # noqa: E402

from zarr.testing.strategies import arrays, basic_indices, np_arrays # noqa: E402
from zarr.testing.strategies import arrays, basic_indices, numpy_arrays, zarr_formats # noqa: E402


@given(st.data())
def test_roundtrip(data: st.DataObject) -> None:
nparray = data.draw(np_arrays)
zarray = data.draw(arrays(arrays=st.just(nparray)))
@given(data=st.data(), zarr_format=zarr_formats)
def test_roundtrip(data: st.DataObject, zarr_format: int) -> None:
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)))
assert_array_equal(nparray, zarray[:])


Expand Down

0 comments on commit 692593b

Please sign in to comment.