Skip to content

Commit

Permalink
Add num_gdrive_retries to maybe_download
Browse files Browse the repository at this point in the history
  • Loading branch information
Suqi Sun committed Mar 24, 2022
1 parent 636a7f8 commit aef70d6
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion forte/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def maybe_download(
path: Union[str, PathLike],
filenames: Optional[List[str]] = None,
extract: bool = False,
num_gdrive_retries: int = 1,
) -> List[str]:
...

Expand All @@ -51,6 +52,7 @@ def maybe_download(
path: Union[str, PathLike],
filenames: Optional[str] = None,
extract: bool = False,
num_gdrive_retries: int = 1,
) -> str:
...

Expand All @@ -60,6 +62,7 @@ def maybe_download(
path: Union[str, PathLike],
filenames: Union[List[str], str, None] = None,
extract: bool = False,
num_gdrive_retries: int = 1,
):
r"""Downloads a set of files.
Expand All @@ -70,6 +73,8 @@ def maybe_download(
must have the same length with ``urls``. If `None`,
filenames are extracted from ``urls``.
extract: Whether to extract compressed files.
num_gdrive_retries: An integer specifying the number of attempts
to download file from Google Drive. Default value is 1.
Returns:
A list of paths to the downloaded files.
Expand Down Expand Up @@ -108,7 +113,9 @@ def maybe_download(
# if not tf.gfile.Exists(filepath):
if not os.path.exists(filepath):
if "drive.google.com" in url:
filepath = _download_from_google_drive(url, filename, path)
filepath = _download_from_google_drive(
url, filename, path, num_gdrive_retries
)
else:
filepath = _download(url, filename, path)

Expand Down

0 comments on commit aef70d6

Please sign in to comment.