Skip to content

Commit

Permalink
[v3] clean up create_array signatures in group/asyncgroup classes (#…
Browse files Browse the repository at this point in the history
…2132)

* clean up create_array signatures in group/asyncgroup classes

* fix members test
  • Loading branch information
jhamman authored Aug 30, 2024
1 parent 2f9cf22 commit 0b5483a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
16 changes: 9 additions & 7 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,21 +311,23 @@ def info(self) -> None:

async def create_group(
self,
path: str,
name: str,
*,
exists_ok: bool = False,
attributes: dict[str, Any] | None = None,
) -> AsyncGroup:
attributes = attributes or {}
return await type(self).create(
self.store_path / path,
self.store_path / name,
attributes=attributes,
exists_ok=exists_ok,
zarr_format=self.metadata.zarr_format,
)

async def create_array(
self,
path: str,
name: str,
*,
shape: ChunkCoords,
dtype: npt.DTypeLike = "float64",
fill_value: Any | None = None,
Expand Down Expand Up @@ -356,7 +358,7 @@ async def create_array(
Parameters
----------
path: str
name: str
The name of the array.
shape: tuple[int, ...]
The shape of the array.
Expand Down Expand Up @@ -392,7 +394,7 @@ async def create_array(
"""
return await AsyncArray.create(
self.store_path / path,
self.store_path / name,
shape=shape,
dtype=dtype,
chunk_shape=chunk_shape,
Expand Down Expand Up @@ -789,7 +791,7 @@ def create_array(
return Array(
self._sync(
self._async_group.create_array(
path=name,
name=name,
shape=shape,
dtype=dtype,
fill_value=fill_value,
Expand Down Expand Up @@ -912,7 +914,7 @@ def array(
return Array(
self._sync(
self._async_group.create_array(
path=name,
name=name,
shape=shape,
dtype=dtype,
fill_value=fill_value,
Expand Down
34 changes: 17 additions & 17 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,11 +564,11 @@ async def test_asyncgroup_getitem(store: LocalStore | MemoryStore, zarr_format:
"""
agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format)

sub_array_path = "sub_array"
array_name = "sub_array"
sub_array = await agroup.create_array(
path=sub_array_path, shape=(10,), dtype="uint8", chunk_shape=(2,)
name=array_name, shape=(10,), dtype="uint8", chunk_shape=(2,)
)
assert await agroup.getitem(sub_array_path) == sub_array
assert await agroup.getitem(array_name) == sub_array

sub_group_path = "sub_group"
sub_group = await agroup.create_group(sub_group_path, attributes={"foo": 100})
Expand All @@ -581,29 +581,29 @@ async def test_asyncgroup_getitem(store: LocalStore | MemoryStore, zarr_format:

async def test_asyncgroup_delitem(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format)
sub_array_path = "sub_array"
array_name = "sub_array"
_ = await agroup.create_array(
path=sub_array_path, shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100}
name=array_name, shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100}
)
await agroup.delitem(sub_array_path)
await agroup.delitem(array_name)

# todo: clean up the code duplication here
if zarr_format == 2:
assert not await agroup.store_path.store.exists(sub_array_path + "/" + ".zarray")
assert not await agroup.store_path.store.exists(sub_array_path + "/" + ".zattrs")
assert not await agroup.store_path.store.exists(array_name + "/" + ".zarray")
assert not await agroup.store_path.store.exists(array_name + "/" + ".zattrs")
elif zarr_format == 3:
assert not await agroup.store_path.store.exists(sub_array_path + "/" + "zarr.json")
assert not await agroup.store_path.store.exists(array_name + "/" + "zarr.json")
else:
raise AssertionError

sub_group_path = "sub_group"
_ = await agroup.create_group(sub_group_path, attributes={"foo": 100})
await agroup.delitem(sub_group_path)
if zarr_format == 2:
assert not await agroup.store_path.store.exists(sub_array_path + "/" + ".zgroup")
assert not await agroup.store_path.store.exists(sub_array_path + "/" + ".zattrs")
assert not await agroup.store_path.store.exists(array_name + "/" + ".zgroup")
assert not await agroup.store_path.store.exists(array_name + "/" + ".zattrs")
elif zarr_format == 3:
assert not await agroup.store_path.store.exists(sub_array_path + "/" + "zarr.json")
assert not await agroup.store_path.store.exists(array_name + "/" + "zarr.json")
else:
raise AssertionError

Expand All @@ -615,7 +615,7 @@ async def test_asyncgroup_create_group(
agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format)
sub_node_path = "sub_group"
attributes = {"foo": 999}
subnode = await agroup.create_group(path=sub_node_path, attributes=attributes)
subnode = await agroup.create_group(name=sub_node_path, attributes=attributes)

assert isinstance(subnode, AsyncGroup)
assert subnode.attrs == attributes
Expand Down Expand Up @@ -645,7 +645,7 @@ async def test_asyncgroup_create_array(

sub_node_path = "sub_array"
subnode = await agroup.create_array(
path=sub_node_path,
name=sub_node_path,
shape=shape,
dtype=dtype,
chunk_shape=chunk_shape,
Expand Down Expand Up @@ -684,11 +684,11 @@ async def test_group_members_async(store: LocalStore | MemoryStore):
GroupMetadata(),
store_path=StorePath(store=store, path="root"),
)
a0 = await group.create_array("a0", (1,))
a0 = await group.create_array("a0", shape=(1,))
g0 = await group.create_group("g0")
a1 = await g0.create_array("a1", (1,))
a1 = await g0.create_array("a1", shape=(1,))
g1 = await g0.create_group("g1")
a2 = await g1.create_array("a2", (1,))
a2 = await g1.create_array("a2", shape=(1,))
g2 = await g1.create_group("g2")

# immediate children
Expand Down

0 comments on commit 0b5483a

Please sign in to comment.