diff --git a/raster_loader/cli/bigquery.py b/raster_loader/cli/bigquery.py
index 3ff8b0e..ac9abde 100644
--- a/raster_loader/cli/bigquery.py
+++ b/raster_loader/cli/bigquery.py
@@ -1,10 +1,11 @@
 import os
-import uuid
+from urllib.parse import urlparse

 import click

 from functools import wraps, partial

-from raster_loader.io.bigquery import BigQueryConnection
+from raster_loader.utils import get_default_table_name
+from raster_loader.io.bigquery import BigQueryConnection, AccessTokenCredentials

 def catch_exception(func=None, *, handle=Exception):
@@ -30,8 +31,14 @@ def bigquery(args=None):


 @bigquery.command(help="Upload a raster file to Google BigQuery.")
-@click.option("--file_path", help="The path to the raster file.", required=True)
+@click.option(
+    "--file_path", help="The path to the raster file.", required=False, default=None
+)
+@click.option(
+    "--file_url", help="The URL of the raster file.", required=False, default=None
+)
 @click.option("--project", help="The name of the Google Cloud project.", required=True)
+@click.option("--token", help="An access token to authenticate with.", default=None)
 @click.option("--dataset", help="The name of the dataset.", required=True)
 @click.option("--table", help="The name of the table.", default=None)
 @click.option(
@@ -67,7 +74,9 @@ def bigquery(args=None):
 @catch_exception()
 def upload(
     file_path,
+    file_url,
     project,
+    token,
     dataset,
     table,
     band,
@@ -82,6 +91,14 @@ def upload(
         get_block_dims,
     )

+    if file_path is None and file_url is None:
+        raise ValueError("Either --file_path or --file_url must be provided.")
+
+    if file_path and file_url:
+        raise ValueError("Only one of --file_path or --file_url must be provided.")
+
+    is_local_file = file_path is not None
+
     # check that band and band_name are the same length
     # if band_name provided
     if band_name != (None,):
@@ -95,23 +112,30 @@ def upload(

     # create default table name if not provided
     if table is None:
-        table = os.path.basename(file_path).split(".")[0]
-        table = "_".join([table, "band", str(band), str(uuid.uuid4())])
+        table = get_default_table_name(file_path if is_local_file else urlparse(file_url).path, band)
+
+    credentials = None
+    if token is not None:
+        credentials = AccessTokenCredentials(token)

-    connector = BigQueryConnection(project)
+    connector = BigQueryConnection(project, credentials)
+
+    source = file_path if is_local_file else file_url

     # introspect raster file
-    num_blocks = get_number_of_blocks(file_path)
-    file_size_mb = os.path.getsize(file_path) / 1024 / 1024
+    num_blocks = get_number_of_blocks(source)
+    file_size_mb = 0
+    if is_local_file:
+        file_size_mb = os.path.getsize(file_path) / 1024 / 1024

     click.echo("Preparing to upload raster file to BigQuery...")
-    click.echo("File Path: {}".format(file_path))
+    click.echo("File Path: {}".format(source))
     click.echo("File Size: {} MB".format(file_size_mb))
-    print_band_information(file_path)
+    print_band_information(source)
     click.echo("Source Band: {}".format(band))
     click.echo("Band Name: {}".format(band_name))
     click.echo("Number of Blocks: {}".format(num_blocks))
-    click.echo("Block Dims: {}".format(get_block_dims(file_path)))
+    click.echo("Block Dims: {}".format(get_block_dims(source)))
     click.echo("Project: {}".format(project))
     click.echo("Dataset: {}".format(dataset))
     click.echo("Table: {}".format(table))
@@ -121,7 +145,7 @@ def upload(
     fqn = f"{project}.{dataset}.{table}"

     connector.upload_raster(
-        file_path,
+        source,
         fqn,
         bands_info,
         chunk_size,
@@ -138,8 +162,13 @@
 @click.option("--dataset", help="The name of the dataset.", required=True)
 @click.option("--table", help="The name of the table.", required=True)
 @click.option("--limit", help="Limit number of rows returned", default=10)
-def describe(project, dataset, table, limit):
-    connector = BigQueryConnection(project)
+@click.option("--token", help="An access token to authenticate with.", required=False, default=None)
+def describe(project, dataset, table, limit, token):
+    credentials = None
+    if token is not None:
+        credentials = AccessTokenCredentials(token)
+
+    connector = BigQueryConnection(project, credentials)

     fqn = f"{project}.{dataset}.{table}"
     df = connector.get_records(fqn, limit)
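Note: a quick way to exercise the new BigQuery flags end to end is click's CliRunner, the same harness the test suite uses. The sketch below is illustrative only; it assumes the click group is exposed as raster_loader.cli.main (as the tests import it) and uses placeholder project, dataset, URL and token values.

    # Hedged sketch: drive `bigquery upload` with --file_url and --token.
    # `main` as the entry point and every value below are assumptions.
    from click.testing import CliRunner

    from raster_loader.cli import main

    runner = CliRunner()
    result = runner.invoke(
        main,
        [
            "bigquery",
            "upload",
            "--file_url",
            "http://example.com/raster.tif",  # remote raster instead of --file_path
            "--project",
            "my-project",
            "--dataset",
            "my_dataset",
            "--token",
            "ya29.placeholder-access-token",  # skips application-default credentials
        ],
    )
    print(result.exit_code)
    print(result.output)  # the echoed upload summary, or a validation error message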
the dataset.", required=True) @click.option("--table", help="The name of the table.", required=True) @click.option("--limit", help="Limit number of rows returned", default=10) -def describe(project, dataset, table, limit): - connector = BigQueryConnection(project) +@click.option("--token", help="An access token to authenticate with.", required=False, default=None) +def describe(project, dataset, table, limit, token): + credentials = None + if token is not None: + credentials = AccessTokenCredentials(token) + + connector = BigQueryConnection(project, credentials) fqn = f"{project}.{dataset}.{table}" df = connector.get_records(fqn, limit) diff --git a/raster_loader/cli/snowflake.py b/raster_loader/cli/snowflake.py index 1e60b79..be73e18 100644 --- a/raster_loader/cli/snowflake.py +++ b/raster_loader/cli/snowflake.py @@ -1,9 +1,10 @@ import os -import uuid +from urllib.parse import urlparse import click from functools import wraps, partial +from raster_loader.utils import get_default_table_name from raster_loader.io.snowflake import SnowflakeConnection @@ -31,9 +32,12 @@ def snowflake(args=None): @snowflake.command(help="Upload a raster file to Snowflake.") @click.option("--account", help="The Swnoflake account.", required=True) -@click.option("--username", help="The username.", required=True) -@click.option("--password", help="The password.", required=True) -@click.option("--file_path", help="The path to the raster file.", required=True) +@click.option("--username", help="The username.", required=False, default=None) +@click.option("--password", help="The password.", required=False, default=None) +@click.option("--token", help="An access token to authenticate with.", required=False, default=None) +@click.option("--role", help="The role to use for the file upload.", default=None) +@click.option("--file_path", help="The path to the raster file.", required=False, default=None) +@click.option("--file_url", help="The path to the raster file.", required=False, default=None) @click.option("--database", help="The name of the database.", required=True) @click.option("--schema", help="The name of the schema.", required=True) @click.option("--table", help="The name of the table.", default=None) @@ -72,7 +76,10 @@ def upload( account, username, password, + token, + role, file_path, + file_url, database, schema, table, @@ -88,6 +95,17 @@ def upload( get_block_dims, ) + if (token is None and (username is None or password is None)) or all(v is not None for v in [token, username, password]): + raise ValueError("Either --token or --username and --password must be provided.") + + if file_path is None and file_url is None: + raise ValueError("Either --file_path or --file_url must be provided.") + + if file_path and file_url: + raise ValueError("Only one of --file_path or --file_url must be provided.") + + is_local_file = file_path is not None + # check that band and band_name are the same length # if band_name provided if band_name != (None,): @@ -101,29 +119,34 @@ def upload( # create default table name if not provided if table is None: - table = os.path.basename(file_path).split(".")[0] - table = "_".join([table, "band", str(band), str(uuid.uuid4())]) + table = get_default_table_name(file_path if is_local_file else urlparse(file_url).path, band) connector = SnowflakeConnection( username=username, password=password, + token=token, account=account, database=database, schema=schema, + role=role, ) + source = file_path if is_local_file else file_url + # introspect raster file - num_blocks = 
diff --git a/raster_loader/io/bigquery.py b/raster_loader/io/bigquery.py
index 37bf8f8..4a40de2 100644
--- a/raster_loader/io/bigquery.py
+++ b/raster_loader/io/bigquery.py
@@ -20,6 +20,8 @@

 try:
     from google.cloud import bigquery
+    from google.auth.credentials import Credentials
+
 except ImportError:  # pragma: no cover
     _has_bigquery = False
 else:
@@ -28,11 +30,24 @@
 from raster_loader.io.datawarehouse import DataWarehouseConnection


+class AccessTokenCredentials(Credentials):
+    def __init__(self, access_token):
+        super(AccessTokenCredentials, self).__init__()
+        self._access_token = access_token
+
+    def refresh(self, request):
+        pass
+
+    def apply(self, headers, token=None):
+        headers["Authorization"] = f"Bearer {self._access_token}"
+
+
 class BigQueryConnection(DataWarehouseConnection):
-    def __init__(self, project):
+    def __init__(self, project, credentials: Credentials = None):
         if not _has_bigquery:  # pragma: no cover
             import_error_bigquery()
-        self.client = bigquery.Client(project=project)
+
+        self.client = bigquery.Client(project=project, credentials=credentials)

     def execute(self, sql):
         return self.client.query(sql).result()
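Note: outside the CLI, the new AccessTokenCredentials can be handed straight to BigQueryConnection; it only injects a static Authorization header and never refreshes, so the token must remain valid for the whole operation. A hedged sketch with placeholder project, table and token values:

    # Hedged sketch: drive BigQueryConnection with a raw OAuth access token.
    # The project, table and token strings are placeholders.
    from raster_loader.io.bigquery import AccessTokenCredentials, BigQueryConnection

    credentials = AccessTokenCredentials("ya29.placeholder-access-token")

    # apply() is all the class does: it sets a fixed Bearer header.
    headers = {}
    credentials.apply(headers)
    assert headers["Authorization"] == "Bearer ya29.placeholder-access-token"

    connector = BigQueryConnection("my-gcp-project", credentials)
    df = connector.get_records("my-gcp-project.my_dataset.my_table", 5)
    print(df)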
diff --git a/raster_loader/io/snowflake.py b/raster_loader/io/snowflake.py
index 3ee46e1..c3ca5f8 100644
--- a/raster_loader/io/snowflake.py
+++ b/raster_loader/io/snowflake.py
@@ -28,19 +28,30 @@
 else:
     _has_snowflake = True
-

 class SnowflakeConnection(DataWarehouseConnection):
-    def __init__(self, username, password, account, database, schema):
+    def __init__(self, username, password, account, database, schema, token, role):
         if not _has_snowflake:
             import_error_snowflake()

-        self.client = snowflake.connector.connect(
-            user=username,
-            password=password,
-            account=account,
-            database=database.upper(),
-            schema=schema.upper(),
-        )
+        # TODO: Write a proper static factory for this
+        if token is None:
+            self.client = snowflake.connector.connect(
+                user=username,
+                password=password,
+                account=account,
+                database=database.upper(),
+                schema=schema.upper(),
+                role=role.upper() if role is not None else None,
+            )
+        else:
+            self.client = snowflake.connector.connect(
+                authenticator="oauth",
+                token=token,
+                account=account,
+                database=database.upper(),
+                schema=schema.upper(),
+                role=role.upper() if role is not None else None,
+            )

     def band_rename_function(self, band_name: str):
         return band_name
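Note: with the branch above, SnowflakeConnection opens an OAuth session when a token is passed (authenticator="oauth") and keeps the original username/password path otherwise; role is optional on both paths and upper-cased before connecting. A hedged construction sketch; every identifier below is a placeholder, and either call will actually try to open a Snowflake session.

    # Hedged sketch: the two ways the updated SnowflakeConnection is constructed.
    # All account, database and credential values are placeholders.
    from raster_loader.io.snowflake import SnowflakeConnection

    # OAuth access token; username/password are ignored on this path.
    oauth_connector = SnowflakeConnection(
        username=None,
        password=None,
        token="placeholder-oauth-access-token",
        account="my_account",
        database="my_database",
        schema="raster_data",
        role="raster_loader_role",  # optional, upper-cased by the connector
    )

    # Classic username/password authentication, no role override.
    password_connector = SnowflakeConnection(
        username="my_user",
        password="my_password",
        token=None,
        account="my_account",
        database="my_database",
        schema="raster_data",
        role=None,
    )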
"raster_loader.io.snowflake.SnowflakeConnection.upload_raster", return_value=None +) +@patch("raster_loader.io.snowflake.SnowflakeConnection.__init__", return_value=None) +def test_snowflake_credentials_validation(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "snowflake", + "upload", + "--file_path", + f"{tiff}", + "--database", + "database", + "--schema", + "schema", + "--table", + "table", + "--account", + "account", + "--username", + "username", + "--chunk_size", + 1, + "--band", + 1, + ], + ) + assert result.exit_code == 1 + assert "Either --token or --username and --password must be provided." in result.output + + result = runner.invoke( + main, + [ + "snowflake", + "upload", + "--file_path", + f"{tiff}", + "--database", + "database", + "--schema", + "schema", + "--table", + "table", + "--account", + "account", + "--username", + "username", + "--password", + "password", + "--token", + "token", + "--chunk_size", + 1, + "--band", + 1, + ], + ) + assert result.exit_code == 1 + assert "Either --token or --username and --password must be provided." in result.output + +@patch( + "raster_loader.io.snowflake.SnowflakeConnection.upload_raster", return_value=None +) +@patch("raster_loader.io.snowflake.SnowflakeConnection.__init__", return_value=None) +def test_snowflake_file_path_or_url_check(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "snowflake", + "upload", + "--database", + "database", + "--schema", + "schema", + "--table", + "table", + "--account", + "account", + "--username", + "username", + "--password", + "password", + "--chunk_size", + 1, + "--band", + 1, + ], + ) + assert result.exit_code == 1 + assert "Either --file_path or --file_url must be provided" in result.output + + result = runner.invoke( + main, + [ + "snowflake", + "upload", + "--file_path", + f"{tiff}", + "--file_url", + "http://example.com/raster.tif", + "--database", + "database", + "--schema", + "schema", + "--table", + "table", + "--account", + "account", + "--username", + "username", + "--password", + "password", + "--chunk_size", + 1, + "--band", + 1, + ], + ) + assert result.exit_code == 1 + assert "Only one of --file_path or --file_url must be provided" in result.output @patch( "raster_loader.io.snowflake.SnowflakeConnection.upload_raster", return_value=None @@ -177,8 +300,6 @@ def test_snowflake_upload_no_table_name(*args, **kwargs): "database", "--schema", "schema", - "--table", - "table", "--account", "account", "--username", @@ -192,6 +313,7 @@ def test_snowflake_upload_no_table_name(*args, **kwargs): ], ) assert result.exit_code == 0 + assert "Table: mosaic_cog_band__1___" in result.output @patch( diff --git a/raster_loader/tests/snowflake/test_io.py b/raster_loader/tests/snowflake/test_io.py index 1197dca..61b0708 100644 --- a/raster_loader/tests/snowflake/test_io.py +++ b/raster_loader/tests/snowflake/test_io.py @@ -32,10 +32,11 @@ SF_PASSWORD = os.environ.get("SF_PASSWORD") SF_DATABASE = os.environ.get("SF_DATABASE") SF_SCHEMA = os.environ.get("SF_SCHEMA") +SF_ROLE = os.environ.get("SF_ROLE") def check_integration_config(): - if not all([SF_ACCOUNT, SF_USERNAME, SF_PASSWORD, SF_DATABASE, SF_SCHEMA]): + if not all([SF_ACCOUNT, SF_USERNAME, SF_PASSWORD, SF_DATABASE, SF_SCHEMA, SF_ROLE]): raise Exception( "You need to copy tests/.env.sample to test/.env and set your configuration" "before running the tests" @@ -50,7 +51,13 @@ def test_rasterio_to_snowflake_with_raster_default_band_name(): fqn = f"{SF_DATABASE}.{SF_SCHEMA}.{table_name}" 
diff --git a/raster_loader/tests/snowflake/test_io.py b/raster_loader/tests/snowflake/test_io.py
index 1197dca..61b0708 100644
--- a/raster_loader/tests/snowflake/test_io.py
+++ b/raster_loader/tests/snowflake/test_io.py
@@ -32,10 +32,11 @@ SF_PASSWORD = os.environ.get("SF_PASSWORD")
 SF_DATABASE = os.environ.get("SF_DATABASE")
 SF_SCHEMA = os.environ.get("SF_SCHEMA")
+SF_ROLE = os.environ.get("SF_ROLE")


 def check_integration_config():
-    if not all([SF_ACCOUNT, SF_USERNAME, SF_PASSWORD, SF_DATABASE, SF_SCHEMA]):
+    if not all([SF_ACCOUNT, SF_USERNAME, SF_PASSWORD, SF_DATABASE, SF_SCHEMA, SF_ROLE]):
         raise Exception(
             "You need to copy tests/.env.sample to test/.env and set your configuration"
             "before running the tests"
         )
@@ -50,7 +51,13 @@ def test_rasterio_to_snowflake_with_raster_default_band_name():
     fqn = f"{SF_DATABASE}.{SF_SCHEMA}.{table_name}"

     connector = SnowflakeConnection(
-        SF_USERNAME, SF_PASSWORD, SF_ACCOUNT, SF_DATABASE, SF_SCHEMA
+        username=SF_USERNAME,
+        password=SF_PASSWORD,
+        account=SF_ACCOUNT,
+        database=SF_DATABASE,
+        schema=SF_SCHEMA,
+        role=SF_ROLE,
+        token=None,
     )

     connector.upload_raster(
@@ -94,7 +101,13 @@ def test_rasterio_to_snowflake_appending_rows():
     fqn = f"{SF_DATABASE}.{SF_SCHEMA}.{table_name}"

     connector = SnowflakeConnection(
-        SF_USERNAME, SF_PASSWORD, SF_ACCOUNT, SF_DATABASE, SF_SCHEMA
+        username=SF_USERNAME,
+        password=SF_PASSWORD,
+        account=SF_ACCOUNT,
+        database=SF_DATABASE,
+        schema=SF_SCHEMA,
+        role=SF_ROLE,
+        token=None,
     )

     connector.upload_raster(
@@ -172,7 +185,13 @@ def test_rasterio_to_snowflake_with_raster_custom_band_column():
     fqn = f"{SF_DATABASE}.{SF_SCHEMA}.{table_name}"

     connector = SnowflakeConnection(
-        SF_USERNAME, SF_PASSWORD, SF_ACCOUNT, SF_DATABASE, SF_SCHEMA
+        username=SF_USERNAME,
+        password=SF_PASSWORD,
+        account=SF_ACCOUNT,
+        database=SF_DATABASE,
+        schema=SF_SCHEMA,
+        role=SF_ROLE,
+        token=None,
     )

     connector.upload_raster(
@@ -220,7 +239,13 @@ def test_rasterio_to_snowflake_with_raster_multiple_default():
     fqn = f"{SF_DATABASE}.{SF_SCHEMA}.{table_name}"

     connector = SnowflakeConnection(
-        SF_USERNAME, SF_PASSWORD, SF_ACCOUNT, SF_DATABASE, SF_SCHEMA
+        username=SF_USERNAME,
+        password=SF_PASSWORD,
+        account=SF_ACCOUNT,
+        database=SF_DATABASE,
+        schema=SF_SCHEMA,
+        role=SF_ROLE,
+        token=None,
     )

     connector.upload_raster(
@@ -273,7 +298,13 @@ def test_rasterio_to_snowflake_with_raster_multiple_custom():
     fqn = f"{SF_DATABASE}.{SF_SCHEMA}.{table_name}"

     connector = SnowflakeConnection(
-        SF_USERNAME, SF_PASSWORD, SF_ACCOUNT, SF_DATABASE, SF_SCHEMA
+        username=SF_USERNAME,
+        password=SF_PASSWORD,
+        account=SF_ACCOUNT,
+        database=SF_DATABASE,
+        schema=SF_SCHEMA,
+        role=SF_ROLE,
+        token=None,
     )

     connector.upload_raster(
diff --git a/raster_loader/utils.py b/raster_loader/utils.py
index 257410e..e3adf85 100644
--- a/raster_loader/utils.py
+++ b/raster_loader/utils.py
@@ -1,4 +1,7 @@
 from itertools import islice
+import os
+import re
+import uuid


 def ask_yes_no_question(question: str) -> bool:
@@ -25,3 +28,8 @@ def batched(iterable, n):
     it = iter(iterable)
     while batch := tuple(islice(it, n)):  # noqa
         yield batch
+
+def get_default_table_name(base_path: str, band):
+    table = os.path.basename(base_path).split(".")[0]
+    table = "_".join([table, "band", str(band), str(uuid.uuid4())])
+    return re.sub(r"[^a-zA-Z0-9_-]", "_", table)
diff --git a/setup.cfg b/setup.cfg
index 66888a8..f75233b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -57,6 +57,7 @@ test =
     pytest-cov>=3.0.0
 bigquery =
     google-cloud-bigquery>=3.13.0
+    google-auth>=2.28.0
 snowflake =
     snowflake-connector-python>=2.6.0
 all =
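Note: the "Table: mosaic_cog_band__1___" assertions in the CLI tests follow directly from the get_default_table_name helper added in raster_loader/utils.py above: the file stem is joined with the band value and a UUID, then every character outside [a-zA-Z0-9_-] becomes an underscore. A hedged sketch, assuming the fixture is named mosaic_cog.tif and that click hands the band values to the command as a tuple:

    # Hedged sketch: why the CLI tests can assert on a stable table-name prefix.
    # The fixture path and the band tuple are assumptions mirroring the tests.
    from raster_loader.utils import get_default_table_name

    name = get_default_table_name("/tmp/mosaic_cog.tif", (1,))

    # stem -> "mosaic_cog"; str((1,)) -> "(1,)" -> "_1__" after the substitution;
    # the UUID suffix keeps its dashes because "-" is in the allowed set.
    assert name.startswith("mosaic_cog_band__1___")
    print(name)  # e.g. mosaic_cog_band__1___3f2b9c0e-...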