diff --git a/bg_atlasapi/utils.py b/bg_atlasapi/utils.py
index a1eecbb9..591ab491 100644
--- a/bg_atlasapi/utils.py
+++ b/bg_atlasapi/utils.py
@@ -1,6 +1,7 @@
 import configparser
 import json
 import logging
+import re
 from typing import Callable, Optional
 
 import requests
@@ -167,6 +168,13 @@ def retrieve_over_http(
     try:
         with progress:
             tot = int(response.headers.get("content-length", 0))
+
+            if tot == 0:
+                try:
+                    tot = get_download_size(url)
+                except Exception:
+                    tot = 0
+
             task_id = progress.add_task(
                 "download",
                 filename=output_file_path.name,
@@ -175,16 +183,18 @@ def retrieve_over_http(
             )
 
             with open(output_file_path, "wb") as fout:
-                advanced = 0
+                completed = 0
                 for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
                     fout.write(chunk)
                     adv = len(chunk)
-                    progress.update(task_id, advance=adv, refresh=True)
+                    completed += adv
+                    progress.update(
+                        task_id, completed=min(completed, tot), refresh=True
+                    )
 
                     if fn_update:
                         # update handler with completed and total bytes
-                        advanced += adv
-                        fn_update(advanced, tot)
+                        fn_update(completed, tot)
 
     except requests.exceptions.ConnectionError:
         output_file_path.unlink()
@@ -193,6 +203,62 @@ def retrieve_over_http(
         )
 
 
+def get_download_size(url: str) -> int:
+    """Get file size based on the MB value on the "src" page of each atlas
+
+    The "raw" component of ``url`` is replaced with "src", that page is
+    fetched, and the first size string found in it (e.g. "7.3 MB") is
+    converted to bytes using decimal multipliers (KB=1e3, MB=1e6, GB=1e9).
+
+    Parameters
+    ----------
+    url : str
+        atlas file url (in a repo, make sure the "raw" url is passed)
+
+    Returns
+    -------
+    int
+        size of the file to download, in bytes
+
+    Raises
+    ------
+    requests.exceptions.HTTPError
+        If there's an issue with HTTP request.
+    ValueError
+        If the file size cannot be extracted from the response.
+    IndexError
+        If the url is not formatted as expected
+
+    """
+    # Replace the 'raw' in the url with 'src'. Only this step can fail on a
+    # malformed url, so it is checked on its own instead of inside a broad
+    # try block that would mislabel unrelated IndexErrors.
+    url_split = url.split("/")
+    if len(url_split) < 6:
+        raise IndexError("Improperly formatted URL")
+    url_split[5] = "src"
+    url = "/".join(url_split)
+
+    response = requests.get(url)
+    response.raise_for_status()
+
+    response_string = response.content.decode("utf-8")
+    # Accept any number of decimals so that e.g. "1.25 MB" is not misread
+    # as "25 MB".
+    search_result = re.search(
+        r"([0-9]+(?:\.[0-9]+)?) ([MGK])B", response_string
+    )
+
+    # Explicit check, not `assert`: asserts are stripped under `python -O`.
+    if search_result is None:
+        raise ValueError("File size information not found in the response.")
+
+    size, prefix = search_result.groups()
+    multipliers = {"K": 1e3, "M": 1e6, "G": 1e9}
+
+    return int(float(size) * multipliers[prefix])
+
+
 def conf_from_url(url):
     """Read conf file from an URL.
     Parameters
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9928ebea..1d194872 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,7 +1,13 @@
+from unittest import mock
+
 import pytest
+import requests
+from requests import HTTPError
 
 from bg_atlasapi import utils
 
+test_url = "https://gin.g-node.org/BrainGlobe/atlases/raw/master/example_mouse_100um_v1.2.tar.gz"
+
 
 def test_http_check():
     assert utils.check_internet_connection()
@@ -14,3 +20,76 @@ def test_http_check():
     assert not utils.check_internet_connection(
         url="http://asd", raise_error=False
     )
+
+
+def test_get_download_size_bad_url():
+    with pytest.raises(IndexError):
+        utils.get_download_size(url="http://asd")
+
+
+def test_get_download_size_no_size_url():
+    with pytest.raises(ValueError):
+        utils.get_download_size(
+            "https://gin.g-node.org/BrainGlobe/atlases/src/master/last_versions.conf"
+        )
+
+
+@pytest.mark.parametrize(
+    "url, real_size",
+    [
+        (
+            "https://gin.g-node.org/BrainGlobe/atlases/raw/master/example_mouse_100um_v1.2.tar.gz",
+            7.3,
+        ),
+        (
+            "https://gin.g-node.org/BrainGlobe/atlases/raw/master/allen_mouse_100um_v1.2.tar.gz",
+            61,
+        ),
+        (
+            "https://gin.g-node.org/BrainGlobe/atlases/raw/master/admba_3d_p56_mouse_25um_v1.0.tar.gz",
+            335,
+        ),
+        (
+            "https://gin.g-node.org/BrainGlobe/atlases/raw/master/osten_mouse_10um_v1.1.tar.gz",
+            3600,
+        ),
+    ],
+)
+def test_get_download_size(url, real_size):
+    size = utils.get_download_size(url)
+
+    real_size = real_size * 1e6
+
+    assert size == real_size
+
+
+def test_get_download_size_kb():
+    with mock.patch("requests.get", autospec=True) as mock_request:
+        mock_response = mock.Mock(spec=requests.Response)
+        mock_response.status_code = 200
+        mock_response.content = b"asd 24.7 KB 123sd"
+        mock_request.return_value = mock_response
+
+        size = utils.get_download_size(test_url)
+
+        assert size == 24700
+
+
+def test_get_download_size_HTTPError():
+    with mock.patch("requests.get", autospec=True) as mock_request:
+        mock_request.side_effect = HTTPError()
+
+        with pytest.raises(HTTPError):
+            utils.get_download_size(test_url)
+
+
+def test_get_download_size_multi_decimal():
+    with mock.patch("requests.get", autospec=True) as mock_request:
+        mock_response = mock.Mock(spec=requests.Response)
+        mock_response.status_code = 200
+        mock_response.content = b"asd 1.25 MB 123sd"
+        mock_request.return_value = mock_response
+
+        size = utils.get_download_size(test_url)
+
+        assert size == 1250000