diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index bd6befce7..920eaa799 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -88,6 +88,31 @@ async def empty(self) -> bool: ... @abstractmethod async def clear(self) -> None: ... + @abstractmethod + def with_mode(self, mode: AccessModeLiteral) -> Self: + """ + Return a new store of the same type pointing to the same location with a new mode. + + The returned Store is not automatically opened. Call :meth:`Store.open` before + using. + + Parameters + ---------- + mode: AccessModeLiteral + The new mode to use. + + Returns + ------- + store: + A new store of the same type with the new mode. + + Examples + -------- + >>> writer = zarr.store.MemoryStore(mode="w") + >>> reader = writer.with_mode("r") + """ + ... + @property def mode(self) -> AccessMode: """Access mode of the store.""" diff --git a/src/zarr/store/common.py b/src/zarr/store/common.py index ea0edbe5e..2d9b1e82c 100644 --- a/src/zarr/store/common.py +++ b/src/zarr/store/common.py @@ -92,8 +92,8 @@ async def make_store_path( assert AccessMode.from_literal(mode) == store_like.store.mode result = store_like elif isinstance(store_like, Store): - if mode is not None: - assert AccessMode.from_literal(mode) == store_like.mode + if mode is not None and mode != store_like.mode.str: + store_like = store_like.with_mode(mode) await store_like._ensure_open() result = StorePath(store_like) elif store_like is None: diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 23a87ea49..0dc5c79e7 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -4,7 +4,7 @@ import os import shutil from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self from zarr.abc.store import ByteRangeRequest, Store from zarr.core.buffer import Buffer @@ -110,6 +110,9 @@ async def empty(self) -> bool: else: return True + def with_mode(self, mode: AccessModeLiteral) -> Self: + return type(self)(root=self.root, 
mode=mode) + def __str__(self) -> str: return f"file://{self.root}" diff --git a/src/zarr/store/logging.py b/src/zarr/store/logging.py index 3a4ae26c3..a9113aabe 100644 --- a/src/zarr/store/logging.py +++ b/src/zarr/store/logging.py @@ -5,7 +5,7 @@ import time from collections import defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self from zarr.abc.store import AccessMode, ByteRangeRequest, Store from zarr.core.buffer import Buffer @@ -14,6 +14,7 @@ from collections.abc import AsyncGenerator, Generator, Iterable from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.common import AccessModeLiteral class LoggingStore(Store): @@ -28,6 +29,8 @@ def __init__( ) -> None: self._store = store self.counter = defaultdict(int) + self.log_level = log_level + self.log_handler = log_handler self._configure_logger(log_level, log_handler) @@ -96,6 +99,14 @@ def _is_open(self) -> bool: # type: ignore[override] with self.log(): return self._store._is_open + async def _open(self) -> None: + with self.log(): + return await self._store._open() + + async def _ensure_open(self) -> None: + with self.log(): + return await self._store._ensure_open() + async def empty(self) -> bool: with self.log(): return await self._store.empty() @@ -167,3 +178,11 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: with self.log(): async for key in self._store.list_dir(prefix=prefix): yield key + + def with_mode(self, mode: AccessModeLiteral) -> Self: + with self.log(): + return type(self)( + self._store.with_mode(mode), + log_level=self.log_level, + log_handler=self.log_handler, + ) diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index d5294c9d2..6aec9a6d0 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self from zarr.abc.store import 
ByteRangeRequest, Store from zarr.core.buffer import Buffer, gpu @@ -41,6 +41,9 @@ async def empty(self) -> bool: async def clear(self) -> None: self._store_dict.clear() + def with_mode(self, mode: AccessModeLiteral) -> Self: + return type(self)(store_dict=self._store_dict, mode=mode) + def __str__(self) -> str: return f"memory://{id(self._store_dict)}" @@ -156,22 +159,30 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: class GpuMemoryStore(MemoryStore): """A GPU only memory store that stores every chunk in GPU memory irrespective - of the original location. This guarantees that chunks will always be in GPU - memory for downstream processing. For location agnostic use cases, it would - be better to use `MemoryStore` instead. + of the original location. + + The dictionary of buffers to initialize this memory store with *must* be + GPU Buffers. + + Writing data to this store through ``.set`` will move the buffer to the GPU + if necessary. + + Parameters + ---------- + store_dict: MutableMapping, optional + A mutable mapping with string keys and :class:`zarr.core.buffer.gpu.Buffer` + values. """ - _store_dict: MutableMapping[str, Buffer] + _store_dict: MutableMapping[str, gpu.Buffer] # type: ignore[assignment] def __init__( self, - store_dict: MutableMapping[str, Buffer] | None = None, + store_dict: MutableMapping[str, gpu.Buffer] | None = None, *, mode: AccessModeLiteral = "r", ) -> None: - super().__init__(mode=mode) - if store_dict: - self._store_dict = {k: gpu.Buffer.from_buffer(store_dict[k]) for k in iter(store_dict)} + super().__init__(store_dict=store_dict, mode=mode) # type: ignore[arg-type] def __str__(self) -> str: return f"gpumemory://{id(self._store_dict)}" @@ -179,6 +190,27 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"GpuMemoryStore({str(self)!r})" + @classmethod + def from_dict(cls, store_dict: MutableMapping[str, Buffer]) -> Self: + """ + Create a GpuMemoryStore from a dictionary of buffers at any location. 
+ + The dictionary backing the newly created ``GpuMemoryStore`` will not be + the same as ``store_dict``. + + Parameters + ---------- + store_dict: mapping + A mapping of string keys to arbitrary Buffers. The buffer data + will be moved into a :class:`gpu.Buffer`. + + Returns + ------- + GpuMemoryStore + """ + gpu_store_dict = {k: gpu.Buffer.from_buffer(v) for k, v in store_dict.items()} + return cls(gpu_store_dict) + async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: self._check_writable() assert isinstance(key, str) diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 6cc631d3b..9ef40ae22 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Self import fsspec @@ -96,6 +96,14 @@ async def clear(self) -> None: async def empty(self) -> bool: return not await self.fs._find(self.path, withdirs=True) + def with_mode(self, mode: AccessModeLiteral) -> Self: + return type(self)( + fs=self.fs, + mode=mode, + path=self.path, + allowed_exceptions=self.allowed_exceptions, + ) + def __repr__(self) -> str: return f"" diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index 82ce7d024..116d6de83 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -5,7 +5,7 @@ import time import zipfile from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, Self from zarr.abc.store import ByteRangeRequest, Store from zarr.core.buffer import Buffer, BufferPrototype @@ -112,6 +112,9 @@ async def empty(self) -> bool: with self._lock: return not self._zf.namelist() + def with_mode(self, mode: ZipStoreAccessModeLiteral) -> Self: # type: ignore[override] + raise NotImplementedError("ZipStore cannot be reopened with a new mode.") + def __str__(self) -> str: return f"zip://{self.path}" diff --git 
a/src/zarr/testing/store.py b/src/zarr/testing/store.py index ed49936d1..5495e6fdf 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,10 +1,11 @@ import pickle -from typing import Any, Generic, TypeVar +from typing import Any, Generic, TypeVar, cast import pytest from zarr.abc.store import AccessMode, Store from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.core.common import AccessModeLiteral from zarr.core.sync import _collect_aiterator, collect_aiterator from zarr.store._utils import _normalize_interval_index from zarr.testing.utils import assert_bytes_equal @@ -274,6 +275,41 @@ async def test_list_dir(self, store: S) -> None: keys_observed = await _collect_aiterator(store.list_dir(root + "/")) assert sorted(keys_expected) == sorted(keys_observed) + async def test_with_mode(self, store: S) -> None: + data = b"0000" + self.set(store, "key", self.buffer_cls.from_bytes(data)) + assert self.get(store, "key").to_bytes() == data + + for mode in ["r", "a"]: + mode = cast(AccessModeLiteral, mode) + clone = store.with_mode(mode) + # await store.close() + await clone._ensure_open() + assert clone.mode == AccessMode.from_literal(mode) + assert isinstance(clone, type(store)) + + # earlier writes are visible + result = await clone.get("key", default_buffer_prototype()) + assert result is not None + assert result.to_bytes() == data + + # writes to the original after with_mode are visible in the clone + self.set(store, "key-2", self.buffer_cls.from_bytes(data)) + result = await clone.get("key-2", default_buffer_prototype()) + assert result is not None + assert result.to_bytes() == data + + if mode == "a": + # writes to the clone are visible in the original + await clone.set("key-3", self.buffer_cls.from_bytes(data)) + result = await store.get("key-3", default_buffer_prototype()) + assert result is not None + assert result.to_bytes() == data + + else: + with pytest.raises(ValueError, match="store mode"): + await clone.set("key-3", 
self.buffer_cls.from_bytes(data)) + async def test_set_if_not_exists(self, store: S) -> None: key = "k" data_buf = self.buffer_cls.from_bytes(b"0000") diff --git a/tests/v3/test_store/test_logging.py b/tests/v3/test_store/test_logging.py index a263c2ae0..b03c9b94f 100644 --- a/tests/v3/test_store/test_logging.py +++ b/tests/v3/test_store/test_logging.py @@ -5,6 +5,7 @@ import pytest import zarr +import zarr.store from zarr.core.buffer import default_buffer_prototype from zarr.store.logging import LoggingStore @@ -48,3 +49,10 @@ async def test_logging_store_counter(store: Store) -> None: assert wrapped.counter["list"] == 0 assert wrapped.counter["list_dir"] == 0 assert wrapped.counter["list_prefix"] == 0 + + +async def test_with_mode(): + wrapped = LoggingStore(store=zarr.store.MemoryStore(mode="w"), log_level="INFO") + new = wrapped.with_mode(mode="r") + assert new.mode.str == "r" + assert new.log_level == "INFO" diff --git a/tests/v3/test_store/test_memory.py b/tests/v3/test_store/test_memory.py index 441304717..efb61b332 100644 --- a/tests/v3/test_store/test_memory.py +++ b/tests/v3/test_store/test_memory.py @@ -58,9 +58,14 @@ def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None: def get(self, store: MemoryStore, key: str) -> Buffer: return store._store_dict[key] - @pytest.fixture(params=[None, {}]) - def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]: - return {"store_dict": request.param, "mode": "r+"} + @pytest.fixture(params=[None, True]) + def store_kwargs( + self, request: pytest.FixtureRequest + ) -> dict[str, str | None | dict[str, Buffer]]: + kwargs = {"store_dict": None, "mode": "r+"} + if request.param is True: + kwargs["store_dict"] = {} + return kwargs @pytest.fixture def store(self, store_kwargs: str | None | dict[str, gpu.Buffer]) -> GpuMemoryStore: @@ -80,3 +85,17 @@ def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None: def test_list_prefix(self, store: GpuMemoryStore) -> None: 
assert True + + def test_dict_reference(self, store: GpuMemoryStore) -> None: + store_dict = {} + result = GpuMemoryStore(store_dict=store_dict) + assert result._store_dict is store_dict + + def test_from_dict(self): + d = { + "a": gpu.Buffer.from_bytes(b"aaaa"), + "b": cpu.Buffer.from_bytes(b"bbbb"), + } + result = GpuMemoryStore.from_dict(d) + for v in result._store_dict.values(): + assert type(v) is gpu.Buffer diff --git a/tests/v3/test_store/test_zip.py b/tests/v3/test_store/test_zip.py index 595d1a3e5..e99b921be 100644 --- a/tests/v3/test_store/test_zip.py +++ b/tests/v3/test_store/test_zip.py @@ -96,3 +96,7 @@ def test_api_integration(self, store: ZipStore) -> None: del root["bar"] store.close() + + async def test_with_mode(self, store: ZipStore) -> None: + with pytest.raises(NotImplementedError, match="new mode"): + await super().test_with_mode(store)