Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow mode casting for Stores #2249

Merged
merged 21 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,31 @@ async def empty(self) -> bool: ...
@abstractmethod
async def clear(self) -> None: ...

@abstractmethod
def with_mode(self, mode: AccessModeLiteral) -> Self:
    """
    Return a new store of the same type pointing to the same location with a new mode.

    This method is abstract: each concrete store supplies its own implementation.

    The returned Store is not automatically opened. Call :meth:`Store.open` before
    using it.

    Parameters
    ----------
    mode : AccessModeLiteral
        The new mode to use.

    Returns
    -------
    store : Self
        A new store of the same type with the new mode.

    Examples
    --------
    >>> writer = zarr.store.MemoryStore(mode="w")
    >>> reader = writer.with_mode("r")
    """
    ...

@property
def mode(self) -> AccessMode:
"""Access mode of the store."""
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/store/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def make_store_path(
result = store_like
elif isinstance(store_like, Store):
if mode is not None:
assert AccessMode.from_literal(mode) == store_like.mode
store_like = store_like.with_mode(mode)
await store_like._ensure_open()
result = StorePath(store_like)
elif store_like is None:
Expand Down
5 changes: 4 additions & 1 deletion src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import shutil
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self

from zarr.abc.store import ByteRangeRequest, Store
from zarr.core.buffer import Buffer
Expand Down Expand Up @@ -103,6 +103,9 @@ async def empty(self) -> bool:
else:
return True

def with_mode(self, mode: AccessModeLiteral) -> Self:
    """Return a new instance of this store's class with the same ``root`` and the given ``mode``."""
    store_cls = type(self)
    return store_cls(root=self.root, mode=mode)

def __str__(self) -> str:
    """Return ``self.root`` prefixed with ``file://``."""
    root = self.root
    return f"file://{root}"

Expand Down
12 changes: 11 additions & 1 deletion src/zarr/store/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self

from zarr.abc.store import AccessMode, ByteRangeRequest, Store

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator, Iterable

from zarr.core.buffer import Buffer, BufferPrototype
from zarr.core.common import AccessModeLiteral


class LoggingStore(Store):
Expand All @@ -27,6 +28,8 @@ def __init__(
):
self._store = store
self.counter = defaultdict(int)
self.log_level = log_level
self.log_handler = log_handler

self._configure_logger(log_level, log_handler)

Expand Down Expand Up @@ -162,3 +165,10 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
with self.log():
async for key in self._store.list_dir(prefix=prefix):
yield key

def with_mode(self, mode: AccessModeLiteral) -> Self:
    """
    Return a new logging wrapper around the wrapped store re-created with ``mode``.

    The mode change is delegated to the wrapped store's own ``with_mode``; the
    ``log_level`` and ``log_handler`` of this wrapper are passed on unchanged.
    """
    inner = self._store.with_mode(mode)
    wrapper_cls = type(self)
    return wrapper_cls(inner, log_level=self.log_level, log_handler=self.log_handler)
jhamman marked this conversation as resolved.
Show resolved Hide resolved
50 changes: 41 additions & 9 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self

from zarr.abc.store import ByteRangeRequest, Store
from zarr.core.buffer import Buffer, gpu
Expand Down Expand Up @@ -41,6 +41,9 @@ async def empty(self) -> bool:
async def clear(self) -> None:
    """Remove every entry from the backing ``_store_dict`` in place."""
    backing = self._store_dict
    backing.clear()

def with_mode(self, mode: AccessModeLiteral) -> Self:
    """Return a new instance of this store's class that reuses the same ``_store_dict`` with ``mode``."""
    store_cls = type(self)
    return store_cls(store_dict=self._store_dict, mode=mode)

def __str__(self) -> str:
    """Return ``memory://`` followed by the ``id()`` of the backing ``_store_dict``."""
    dict_id = id(self._store_dict)
    return f"memory://{dict_id}"

Expand Down Expand Up @@ -152,29 +155,58 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:

class GpuMemoryStore(MemoryStore):
"""A GPU only memory store that stores every chunk in GPU memory irrespective
of the original location. This guarantees that chunks will always be in GPU
memory for downstream processing. For location agnostic use cases, it would
be better to use `MemoryStore` instead.
of the original location.

The dictionary of buffers to initialize this memory store with *must* be
GPU Buffers.

Writing data to this store through ``.set`` will move the buffer to the GPU
if necessary.

Parameters
----------
store_dict: MutableMapping, optional
A mutable mapping with string keys and :class:`zarr.core.buffer.gpu.Buffer`
values.
"""

_store_dict: MutableMapping[str, Buffer]
_store_dict: MutableMapping[str, gpu.Buffer] # type: ignore[assignment]

def __init__(
self,
store_dict: MutableMapping[str, Buffer] | None = None,
store_dict: MutableMapping[str, gpu.Buffer] | None = None,
*,
mode: AccessModeLiteral = "r",
) -> None:
super().__init__(mode=mode)
if store_dict:
self._store_dict = {k: gpu.Buffer.from_buffer(store_dict[k]) for k in iter(store_dict)}
super().__init__(store_dict=store_dict, mode=mode) # type: ignore[arg-type]

def __str__(self) -> str:
    """Return ``gpumemory://`` followed by the ``id()`` of the backing ``_store_dict``."""
    dict_id = id(self._store_dict)
    return f"gpumemory://{dict_id}"

def __repr__(self) -> str:
    """Return ``GpuMemoryStore(...)`` wrapping the quoted ``str()`` form of this store."""
    location = str(self)
    return f"GpuMemoryStore({location!r})"

@classmethod
def from_dict(cls, store_dict: MutableMapping[str, Buffer]) -> Self:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this API!

"""
Create a GpuMemoryStore from a dictionary of buffers at any location.

The dictionary backing the newly created ``GpuMemoryStore`` will not be
the same as ``store_dict``.

Parameters
----------
store_dict: mapping
A mapping of string keys to arbitrary Buffers. The buffer data
will be moved into a :class:`gpu.Buffer`.

Returns
-------
GpuMemoryStore
"""
gpu_store_dict = {k: gpu.Buffer.from_buffer(v) for k, v in store_dict.items()}
return cls(gpu_store_dict)

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
self._check_writable()
assert isinstance(key, str)
Expand Down
10 changes: 9 additions & 1 deletion src/zarr/store/remote.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Self

import fsspec

Expand Down Expand Up @@ -95,6 +95,14 @@ async def clear(self) -> None:
async def empty(self) -> bool:
    """Return ``True`` when ``fs._find`` reports nothing under ``self.path`` (directories included)."""
    found = await self.fs._find(self.path, withdirs=True)
    return not found

def with_mode(self, mode: AccessModeLiteral) -> Self:
    """
    Return a new instance of this store's class that uses ``mode``.

    The filesystem object, path and allowed exceptions are taken from this store.
    """
    store_cls = type(self)
    return store_cls(
        fs=self.fs,
        mode=mode,
        path=self.path,
        allowed_exceptions=self.allowed_exceptions,
    )

def __repr__(self) -> str:
    """Return ``<RemoteStore(FS, PATH)>`` using the filesystem's class name and ``self.path``."""
    fs_name = type(self.fs).__name__
    return f"<RemoteStore({fs_name}, {self.path})>"

Expand Down
5 changes: 4 additions & 1 deletion src/zarr/store/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import zipfile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, Self

from zarr.abc.store import ByteRangeRequest, Store
from zarr.core.buffer import Buffer, BufferPrototype
Expand Down Expand Up @@ -112,6 +112,9 @@ async def empty(self) -> bool:
with self._lock:
return not self._zf.namelist()

def with_mode(self, mode: ZipStoreAccessModeLiteral) -> Self:  # type: ignore[override]
    """
    Always raise: this store does not support being reopened with a new mode.

    Raises
    ------
    NotImplementedError
        On every call, regardless of ``mode``.
    """
    message = "ZipStore cannot be reopened with a new mode."
    raise NotImplementedError(message)

def __str__(self) -> str:
    """Return ``self.path`` prefixed with ``zip://``."""
    path = self.path
    return f"zip://{path}"

Expand Down
38 changes: 37 additions & 1 deletion src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pickle
from typing import Any, Generic, TypeVar
from typing import Any, Generic, TypeVar, cast

import pytest

from zarr.abc.store import AccessMode, Store
from zarr.core.buffer import Buffer, default_buffer_prototype
from zarr.core.common import AccessModeLiteral
from zarr.core.sync import _collect_aiterator, collect_aiterator
from zarr.store._utils import _normalize_interval_index
from zarr.testing.utils import assert_bytes_equal
Expand Down Expand Up @@ -273,3 +274,38 @@ async def test_list_dir(self, store: S) -> None:

keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
assert sorted(keys_expected) == sorted(keys_observed)

async def test_with_mode(self, store: S) -> None:
    """
    Check that a store returned by ``with_mode`` shares data with the original.

    For each of the modes ``"r"`` and ``"a"`` the clone must report the requested
    mode, be of the same type as the original, see writes made through the
    original both before and after the clone was created, and either propagate
    its own writes back to the original (``"a"``) or reject them (``"r"``).
    """
    data = b"0000"
    self.set(store, "key", self.buffer_cls.from_bytes(data))
    assert self.get(store, "key").to_bytes() == data

    modes: tuple[AccessModeLiteral, ...] = ("r", "a")
    for mode in modes:
        clone = store.with_mode(mode)
        # with_mode does not open the returned store, so open it before use.
        await clone._ensure_open()
        assert clone.mode == AccessMode.from_literal(mode)
        assert isinstance(clone, type(store))

        # earlier writes are visible
        result = await clone.get("key", default_buffer_prototype())
        assert result is not None
        assert result.to_bytes() == data

        # writes to the original after with_mode are visible
        self.set(store, "key-2", self.buffer_cls.from_bytes(data))
        result = await clone.get("key-2", default_buffer_prototype())
        assert result is not None
        assert result.to_bytes() == data

        if mode == "a":
            # writes to the clone can be read back through the clone ...
            await clone.set("key-3", self.buffer_cls.from_bytes(data))
            result = await clone.get("key-3", default_buffer_prototype())
            assert result is not None
            assert result.to_bytes() == data
            # ... and are visible in the original store as well
            assert self.get(store, "key-3").to_bytes() == data
        else:
            with pytest.raises(ValueError, match="store mode"):
                await clone.set("key-3", self.buffer_cls.from_bytes(data))
8 changes: 8 additions & 0 deletions tests/v3/test_store/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import zarr
import zarr.store
from zarr.core.buffer import default_buffer_prototype
from zarr.store.logging import LoggingStore

Expand Down Expand Up @@ -48,3 +49,10 @@ async def test_logging_store_counter(store: Store) -> None:
assert wrapped.counter["list"] == 0
assert wrapped.counter["list_dir"] == 0
assert wrapped.counter["list_prefix"] == 0


async def test_with_mode() -> None:
    """``LoggingStore.with_mode`` changes the mode and keeps the configured log level."""
    wrapped = LoggingStore(store=zarr.store.MemoryStore(mode="w"), log_level="INFO")
    # with_mode is a plain synchronous method that returns the new store,
    # so its result must not be awaited.
    new = wrapped.with_mode(mode="r")
    assert new.mode.str == "r"
    assert new.log_level == "INFO"
25 changes: 22 additions & 3 deletions tests/v3/test_store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,14 @@ def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
def get(self, store: MemoryStore, key: str) -> Buffer:
return store._store_dict[key]

@pytest.fixture(params=[None, {}])
def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]:
return {"store_dict": request.param, "mode": "r+"}
@pytest.fixture(params=[None, True])
def store_kwargs(
    self, request: pytest.FixtureRequest
) -> dict[str, str | None | dict[str, Buffer]]:
    """
    Build constructor kwargs for the store under test.

    The ``True`` parameter yields a fresh empty ``store_dict``; the ``None``
    parameter leaves ``store_dict`` as ``None``. The mode is always ``"r+"``.
    """
    # Create the dict here, per call, rather than in the params list.
    store_dict = {} if request.param is True else None
    return {"store_dict": store_dict, "mode": "r+"}

@pytest.fixture
def store(self, store_kwargs: str | None | dict[str, gpu.Buffer]) -> GpuMemoryStore:
Expand All @@ -80,3 +85,17 @@ def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None:

def test_list_prefix(self, store: GpuMemoryStore) -> None:
    # Trivial body: always passes and never touches ``store``.
    # NOTE(review): presumably a placeholder overriding an inherited test — confirm.
    assert True

def test_dict_reference(self, store: GpuMemoryStore) -> None:
    """A dict passed as ``store_dict`` becomes the store's backing mapping itself, not a copy."""
    backing = {}
    gpu_store = GpuMemoryStore(store_dict=backing)
    assert gpu_store._store_dict is backing

def test_from_dict(self):
    """``from_dict`` stores every value as a ``gpu.Buffer``, whatever buffer type went in."""
    mixed_buffers = {
        "a": gpu.Buffer.from_bytes(b"aaaa"),
        "b": cpu.Buffer.from_bytes(b"bbbb"),
    }
    gpu_store = GpuMemoryStore.from_dict(mixed_buffers)
    # Exact type check on purpose: each stored value must be precisely gpu.Buffer.
    assert all(type(buf) is gpu.Buffer for buf in gpu_store._store_dict.values())
4 changes: 4 additions & 0 deletions tests/v3/test_store/test_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,7 @@ def test_api_integration(self, store: ZipStore) -> None:
del root["bar"]

store.close()

async def test_with_mode(self, store: ZipStore) -> None:
    """Running the inherited ``with_mode`` test against a ZipStore must raise ``NotImplementedError``."""
    expect_unsupported = pytest.raises(NotImplementedError, match="new mode")
    with expect_unsupported:
        await super().test_with_mode(store)
Loading