Skip to content

Commit

Permalink
support max_concurrency in upload_blob and download_blob operations (#420)

Browse files Browse the repository at this point in the history

* spec: expose max_concurrency param in blob_upload/blob_download

* prefer batch_size in batched async methods
  • Loading branch information
pmrowla authored Aug 5, 2023
1 parent cdd9513 commit 55ba981
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 9 deletions.
71 changes: 65 additions & 6 deletions adlfs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from azure.storage.blob._shared.base_client import create_configuration
from azure.storage.blob.aio import BlobServiceClient as AIOBlobServiceClient
from azure.storage.blob.aio._list_blobs_helper import BlobPrefix
from fsspec.asyn import AsyncFileSystem, get_loop, sync, sync_wrapper
from fsspec.asyn import AsyncFileSystem, _get_batch_size, get_loop, sync, sync_wrapper
from fsspec.spec import AbstractBufferedFile
from fsspec.utils import infer_storage_options

Expand Down Expand Up @@ -165,6 +165,10 @@ class AzureBlobFileSystem(AsyncFileSystem):
False throws if retrieving container properties fails, which might happen if your
authentication is only valid at the storage container level, and not the
storage account level.
max_concurrency:
The number of concurrent connections to use when uploading or downloading a blob.
If None it will be inferred from fsspec.asyn._get_batch_size().
Pass on to fsspec:
skip_instance_cache: to control reuse of instances
Expand Down Expand Up @@ -227,6 +231,7 @@ def __init__(
default_cache_type: str = "bytes",
version_aware: bool = False,
assume_container_exists: Optional[bool] = None,
max_concurrency: Optional[int] = None,
**kwargs,
):
super_kwargs = {
Expand Down Expand Up @@ -292,6 +297,12 @@ def __init__(
if self.credential is not None:
weakref.finalize(self, sync, self.loop, close_credential, self)

if max_concurrency is None:
batch_size = _get_batch_size()
if batch_size > 0:
max_concurrency = batch_size
self.max_concurrency = max_concurrency

@classmethod
def _strip_protocol(cls, path: str):
"""
Expand Down Expand Up @@ -1426,7 +1437,9 @@ async def _dir_exists(self, container, path):
except ResourceNotFoundError:
return False

async def _pipe_file(self, path, value, overwrite=True, **kwargs):
async def _pipe_file(
self, path, value, overwrite=True, max_concurrency=None, **kwargs
):
"""Set the bytes of given file"""
container_name, path, _ = self.split_path(path)
async with self.service_client.get_blob_client(
Expand All @@ -1436,14 +1449,23 @@ async def _pipe_file(self, path, value, overwrite=True, **kwargs):
data=value,
overwrite=overwrite,
metadata={"is_directory": "false"},
max_concurrency=max_concurrency or self.max_concurrency,
**kwargs,
)
self.invalidate_cache(self._parent(path))
return result

pipe_file = sync_wrapper(_pipe_file)

async def _cat_file(self, path, start=None, end=None, **kwargs):
async def _pipe(self, *args, batch_size=None, max_concurrency=None, **kwargs):
    """Pipe many blobs in a batch.

    The batch already fans out across blobs via ``batch_size``, so the
    per-blob transfer concurrency is capped at 1 unless the caller
    explicitly asks for more.
    """
    per_blob_concurrency = max_concurrency if max_concurrency else 1
    return await super()._pipe(
        *args,
        batch_size=batch_size,
        max_concurrency=per_blob_concurrency,
        **kwargs,
    )

async def _cat_file(
self, path, start=None, end=None, max_concurrency=None, **kwargs
):
path = self._strip_protocol(path)
if end is not None:
start = start or 0 # download_blob requires start if length is provided.
Expand All @@ -1456,7 +1478,10 @@ async def _cat_file(self, path, start=None, end=None, **kwargs):
) as bc:
try:
stream = await bc.download_blob(
offset=start, length=length, version_id=version_id
offset=start,
length=length,
version_id=version_id,
max_concurrency=max_concurrency or self.max_concurrency,
)
except ResourceNotFoundError as e:
raise FileNotFoundError from e
Expand Down Expand Up @@ -1497,6 +1522,12 @@ def cat(self, path, recursive=False, on_error="raise", **kwargs):
else:
return self.cat_file(paths[0])

async def _cat_ranges(self, *args, batch_size=None, max_concurrency=None, **kwargs):
    """Read many byte ranges in a batch.

    Concurrency across ranges comes from ``batch_size``; each individual
    download therefore defaults to a single connection unless the caller
    overrides ``max_concurrency``.
    """
    per_range_concurrency = max_concurrency if max_concurrency else 1
    return await super()._cat_ranges(
        *args,
        batch_size=batch_size,
        max_concurrency=per_range_concurrency,
        **kwargs,
    )

def url(self, path, expires=3600, **kwargs):
return sync(self.loop, self._url, path, expires, **kwargs)

Expand Down Expand Up @@ -1593,7 +1624,14 @@ async def _expand_path(
return list(sorted(out))

async def _put_file(
self, lpath, rpath, delimiter="/", overwrite=False, callback=None, **kwargws
self,
lpath,
rpath,
delimiter="/",
overwrite=False,
callback=None,
max_concurrency=None,
**kwargws,
):
"""
Copy single file to remote
Expand Down Expand Up @@ -1621,6 +1659,7 @@ async def _put_file(
raw_response_hook=make_callback(
"upload_stream_current", callback
),
max_concurrency=max_concurrency or self.max_concurrency,
)
self.invalidate_cache()
except ResourceExistsError:
Expand All @@ -1633,6 +1672,12 @@ async def _put_file(

put_file = sync_wrapper(_put_file)

async def _put(self, *args, batch_size=None, max_concurrency=None, **kwargs):
    """Upload many files in a batch.

    ``batch_size`` already parallelizes across files, so each upload is
    limited to one connection by default; pass ``max_concurrency`` to
    raise the per-file limit.
    """
    per_file_concurrency = max_concurrency if max_concurrency else 1
    return await super()._put(
        *args,
        batch_size=batch_size,
        max_concurrency=per_file_concurrency,
        **kwargs,
    )

async def _cp_file(self, path1, path2, **kwargs):
"""Copy the file at path1 to path2"""
container1, path1, version_id = self.split_path(path1, delimiter="/")
Expand Down Expand Up @@ -1668,7 +1713,14 @@ def download(self, rpath, lpath, recursive=False, **kwargs):
return self.get(rpath, lpath, recursive=recursive, **kwargs)

async def _get_file(
self, rpath, lpath, recursive=False, delimiter="/", callback=None, **kwargs
self,
rpath,
lpath,
recursive=False,
delimiter="/",
callback=None,
max_concurrency=None,
**kwargs,
):
"""Copy single file remote to local"""
if os.path.isdir(lpath):
Expand All @@ -1683,6 +1735,7 @@ async def _get_file(
"download_stream_current", callback
),
version_id=version_id,
max_concurrency=max_concurrency or self.max_concurrency,
)
with open(lpath, "wb") as my_blob:
await stream.readinto(my_blob)
Expand All @@ -1691,6 +1744,12 @@ async def _get_file(

get_file = sync_wrapper(_get_file)

async def _get(self, *args, batch_size=None, max_concurrency=None, **kwargs):
    """Download many files in a batch.

    Parallelism across files is governed by ``batch_size``; each single
    download defaults to one connection unless the caller explicitly
    supplies ``max_concurrency``.
    """
    per_file_concurrency = max_concurrency if max_concurrency else 1
    return await super()._get(
        *args,
        batch_size=batch_size,
        max_concurrency=per_file_concurrency,
        **kwargs,
    )

def getxattr(self, path, attr):
meta = self.info(path).get("metadata", {})
return meta[attr]
Expand Down
13 changes: 10 additions & 3 deletions adlfs/tests/test_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,9 @@ def test_info_missing(storage, path):

def test_time_info(storage):
fs = AzureBlobFileSystem(
account_name=storage.account_name, connection_string=CONN_STR
account_name=storage.account_name,
connection_string=CONN_STR,
max_concurrency=1,
)

creation_time = fs.created("data/root/d/file_with_metadata.txt")
Expand Down Expand Up @@ -1486,7 +1488,10 @@ async def test_cat_file_versioned(storage, mocker):

await fs._cat_file(f"data/root/a/file.txt?versionid={DEFAULT_VERSION_ID}")
download_blob.assert_called_once_with(
offset=None, length=None, version_id=DEFAULT_VERSION_ID
offset=None,
length=None,
version_id=DEFAULT_VERSION_ID,
max_concurrency=fs.max_concurrency,
)

download_blob.reset_mock()
Expand Down Expand Up @@ -1744,7 +1749,9 @@ async def test_get_file_versioned(storage, mocker, tmp_path):
f"data/root/a/file.txt?versionid={DEFAULT_VERSION_ID}", tmp_path / "file.txt"
)
download_blob.assert_called_once_with(
raw_response_hook=mocker.ANY, version_id=DEFAULT_VERSION_ID
raw_response_hook=mocker.ANY,
version_id=DEFAULT_VERSION_ID,
max_concurrency=fs.max_concurrency,
)
download_blob.reset_mock()
download_blob.side_effect = ResourceNotFoundError
Expand Down

0 comments on commit 55ba981

Please sign in to comment.