Skip to content

Commit

Permalink
Support input URLs and more connection credential types (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
thedae authored Mar 19, 2024
1 parent fc77b3c commit f8a9be0
Show file tree
Hide file tree
Showing 9 changed files with 347 additions and 49 deletions.
57 changes: 43 additions & 14 deletions raster_loader/cli/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import uuid
from urllib.parse import urlparse

import click
from functools import wraps, partial

from raster_loader.io.bigquery import BigQueryConnection
from raster_loader.utils import get_default_table_name
from raster_loader.io.bigquery import BigQueryConnection, AccessTokenCredentials


def catch_exception(func=None, *, handle=Exception):
Expand All @@ -30,8 +31,14 @@ def bigquery(args=None):


@bigquery.command(help="Upload a raster file to Google BigQuery.")
@click.option("--file_path", help="The path to the raster file.", required=True)
@click.option(
"--file_path", help="The path to the raster file.", required=False, default=None
)
@click.option(
"--file_url", help="The path to the raster file.", required=False, default=None
)
@click.option("--project", help="The name of the Google Cloud project.", required=True)
@click.option("--token", help="An access token to authenticate with.", default=None)
@click.option("--dataset", help="The name of the dataset.", required=True)
@click.option("--table", help="The name of the table.", default=None)
@click.option(
Expand Down Expand Up @@ -67,7 +74,9 @@ def bigquery(args=None):
@catch_exception()
def upload(
file_path,
file_url,
project,
token,
dataset,
table,
band,
Expand All @@ -82,6 +91,14 @@ def upload(
get_block_dims,
)

if file_path is None and file_url is None:
raise ValueError("Either --file_path or --file_url must be provided.")

if file_path and file_url:
raise ValueError("Only one of --file_path or --file_url must be provided.")

is_local_file = file_path is not None

# check that band and band_name are the same length
# if band_name provided
if band_name != (None,):
Expand All @@ -95,23 +112,30 @@ def upload(

# create default table name if not provided
if table is None:
table = os.path.basename(file_path).split(".")[0]
table = "_".join([table, "band", str(band), str(uuid.uuid4())])
table = get_default_table_name(file_path if is_local_file else urlparse(file_url).path, band)

credentials = None
if token is not None:
credentials = AccessTokenCredentials(token)

connector = BigQueryConnection(project)
connector = BigQueryConnection(project, credentials)

source = file_path if is_local_file else file_url

# introspect raster file
num_blocks = get_number_of_blocks(file_path)
file_size_mb = os.path.getsize(file_path) / 1024 / 1024
num_blocks = get_number_of_blocks(source)
file_size_mb = 0
if is_local_file:
file_size_mb = os.path.getsize(file_path) / 1024 / 1024

click.echo("Preparing to upload raster file to BigQuery...")
click.echo("File Path: {}".format(file_path))
click.echo("File Path: {}".format(source))
click.echo("File Size: {} MB".format(file_size_mb))
print_band_information(file_path)
print_band_information(source)
click.echo("Source Band: {}".format(band))
click.echo("Band Name: {}".format(band_name))
click.echo("Number of Blocks: {}".format(num_blocks))
click.echo("Block Dims: {}".format(get_block_dims(file_path)))
click.echo("Block Dims: {}".format(get_block_dims(source)))
click.echo("Project: {}".format(project))
click.echo("Dataset: {}".format(dataset))
click.echo("Table: {}".format(table))
Expand All @@ -121,7 +145,7 @@ def upload(

fqn = f"{project}.{dataset}.{table}"
connector.upload_raster(
file_path,
source,
fqn,
bands_info,
chunk_size,
Expand All @@ -138,8 +162,13 @@ def upload(
@click.option("--dataset", help="The name of the dataset.", required=True)
@click.option("--table", help="The name of the table.", required=True)
@click.option("--limit", help="Limit number of rows returned", default=10)
def describe(project, dataset, table, limit):
connector = BigQueryConnection(project)
@click.option("--token", help="An access token to authenticate with.", required=False, default=None)
def describe(project, dataset, table, limit, token):
credentials = None
if token is not None:
credentials = AccessTokenCredentials(token)

connector = BigQueryConnection(project, credentials)

fqn = f"{project}.{dataset}.{table}"
df = connector.get_records(fqn, limit)
Expand Down
63 changes: 47 additions & 16 deletions raster_loader/cli/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import uuid
from urllib.parse import urlparse

import click
from functools import wraps, partial

from raster_loader.utils import get_default_table_name
from raster_loader.io.snowflake import SnowflakeConnection


Expand Down Expand Up @@ -31,9 +32,12 @@ def snowflake(args=None):

@snowflake.command(help="Upload a raster file to Snowflake.")
@click.option("--account", help="The Swnoflake account.", required=True)
@click.option("--username", help="The username.", required=True)
@click.option("--password", help="The password.", required=True)
@click.option("--file_path", help="The path to the raster file.", required=True)
@click.option("--username", help="The username.", required=False, default=None)
@click.option("--password", help="The password.", required=False, default=None)
@click.option("--token", help="An access token to authenticate with.", required=False, default=None)
@click.option("--role", help="The role to use for the file upload.", default=None)
@click.option("--file_path", help="The path to the raster file.", required=False, default=None)
@click.option("--file_url", help="The path to the raster file.", required=False, default=None)
@click.option("--database", help="The name of the database.", required=True)
@click.option("--schema", help="The name of the schema.", required=True)
@click.option("--table", help="The name of the table.", default=None)
Expand Down Expand Up @@ -72,7 +76,10 @@ def upload(
account,
username,
password,
token,
role,
file_path,
file_url,
database,
schema,
table,
Expand All @@ -88,6 +95,17 @@ def upload(
get_block_dims,
)

if (token is None and (username is None or password is None)) or all(v is not None for v in [token, username, password]):
raise ValueError("Either --token or --username and --password must be provided.")

if file_path is None and file_url is None:
raise ValueError("Either --file_path or --file_url must be provided.")

if file_path and file_url:
raise ValueError("Only one of --file_path or --file_url must be provided.")

is_local_file = file_path is not None

# check that band and band_name are the same length
# if band_name provided
if band_name != (None,):
Expand All @@ -101,29 +119,34 @@ def upload(

# create default table name if not provided
if table is None:
table = os.path.basename(file_path).split(".")[0]
table = "_".join([table, "band", str(band), str(uuid.uuid4())])
table = get_default_table_name(file_path if is_local_file else urlparse(file_url).path, band)

connector = SnowflakeConnection(
username=username,
password=password,
token=token,
account=account,
database=database,
schema=schema,
role=role,
)

source = file_path if is_local_file else file_url

# introspect raster file
num_blocks = get_number_of_blocks(file_path)
file_size_mb = os.path.getsize(file_path) / 1024 / 1024
num_blocks = get_number_of_blocks(source)
file_size_mb = 0
if is_local_file:
file_size_mb = os.path.getsize(file_path) / 1024 / 1024

click.echo("Preparing to upload raster file to Snowflake...")
click.echo("File Path: {}".format(file_path))
click.echo("File Path: {}".format(source))
click.echo("File Size: {} MB".format(file_size_mb))
print_band_information(file_path)
print_band_information(source)
click.echo("Source Band: {}".format(band))
click.echo("Band Name: {}".format(band_name))
click.echo("Number of Blocks: {}".format(num_blocks))
click.echo("Block Dims: {}".format(get_block_dims(file_path)))
click.echo("Block Dims: {}".format(get_block_dims(source)))
click.echo("Database: {}".format(database))
click.echo("Schema: {}".format(schema))
click.echo("Table: {}".format(table))
Expand All @@ -133,7 +156,7 @@ def upload(

fqn = f"{database}.{schema}.{table}"
connector.upload_raster(
file_path,
source,
fqn,
bands_info,
chunk_size,
Expand All @@ -147,20 +170,28 @@ def upload(

@snowflake.command(help="Load and describe a table from Snowflake")
@click.option("--account", help="The Swnoflake account.", required=True)
@click.option("--username", help="The username.", required=True)
@click.option("--password", help="The password.", required=True)
@click.option("--username", help="The username.", required=False, default=None)
@click.option("--password", help="The password.", required=False, default=None)
@click.option("--token", help="An access token to authenticate with.", required=False, default=None)
@click.option("--role", help="The role to use for the file upload.", default=None)
@click.option("--database", help="The name of the database.", required=True)
@click.option("--schema", help="The name of the schema.", required=True)
@click.option("--table", help="The name of the table.", default=None)
@click.option("--table", help="The name of the table.", required=True)
@click.option("--limit", help="Limit number of rows returned", default=10)
def describe(account, username, password, database, schema, table, limit):
def describe(account, username, password, token, role, database, schema, table, limit):

if (token is None and (username is None or password is None)) or all(v is not None for v in [token, username, password]):
raise ValueError("Either --token or --username and --password must be provided.")

fqn = f"{database}.{schema}.{table}"
connector = SnowflakeConnection(
username=username,
password=password,
token=token,
account=account,
database=database,
schema=schema,
role=role,
)
df = connector.get_records(fqn, limit)
print(f"Table: {fqn}")
Expand Down
19 changes: 17 additions & 2 deletions raster_loader/io/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

try:
from google.cloud import bigquery
from google.auth.credentials import Credentials

except ImportError: # pragma: no cover
_has_bigquery = False
else:
Expand All @@ -28,11 +30,24 @@
from raster_loader.io.datawarehouse import DataWarehouseConnection


class AccessTokenCredentials(Credentials):
def __init__(self, access_token):
super(AccessTokenCredentials, self).__init__()
self._access_token = access_token

def refresh(self, request):
pass

def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self._access_token}"


class BigQueryConnection(DataWarehouseConnection):
def __init__(self, project):
def __init__(self, project, credentials: Credentials = None):
if not _has_bigquery: # pragma: no cover
import_error_bigquery()
self.client = bigquery.Client(project=project)

self.client = bigquery.Client(project=project, credentials=credentials)

def execute(self, sql):
return self.client.query(sql).result()
Expand Down
29 changes: 20 additions & 9 deletions raster_loader/io/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,30 @@
else:
_has_snowflake = True


class SnowflakeConnection(DataWarehouseConnection):
def __init__(self, username, password, account, database, schema):
def __init__(self, username, password, account, database, schema, token, role):
if not _has_snowflake:
import_error_snowflake()

self.client = snowflake.connector.connect(
user=username,
password=password,
account=account,
database=database.upper(),
schema=schema.upper(),
)
# TODO: Write a proper static factory for this
if token is None:
self.client = snowflake.connector.connect(
user=username,
password=password,
account=account,
database=database.upper(),
schema=schema.upper(),
role=role.upper() if role is not None else None,
)
else:
self.client = snowflake.connector.connect(
authenticator="oauth",
token=token,
account=account,
database=database.upper(),
schema=schema.upper(),
role=role.upper() if role is not None else None,
)

def band_rename_function(self, band_name: str):
return band_name
Expand Down
50 changes: 50 additions & 0 deletions raster_loader/tests/bigquery/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,55 @@ def test_bigquery_upload(*args, **kwargs):
assert result.exit_code == 0


@patch("raster_loader.cli.bigquery.BigQueryConnection.upload_raster", return_value=None)
@patch("raster_loader.cli.bigquery.BigQueryConnection.__init__", return_value=None)
def test_bigquery_file_path_or_url_check(*args, **kwargs):
runner = CliRunner()
result = runner.invoke(
main,
[
"bigquery",
"upload",
"--project",
"project",
"--dataset",
"dataset",
"--table",
"table",
"--chunk_size",
1,
"--band",
1,
],
)
assert result.exit_code == 1
assert "Either --file_path or --file_url must be provided" in result.output

result = runner.invoke(
main,
[
"bigquery",
"upload",
"--file_path",
f"{tiff}",
"--file_url",
"http://example.com/raster.tif",
"--project",
"project",
"--dataset",
"dataset",
"--table",
"table",
"--chunk_size",
1,
"--band",
1,
],
)
assert result.exit_code == 1
assert "Only one of --file_path or --file_url must be provided" in result.output


@patch("raster_loader.cli.bigquery.BigQueryConnection.upload_raster", return_value=None)
@patch("raster_loader.cli.bigquery.BigQueryConnection.__init__", return_value=None)
def test_bigquery_upload_multiple_bands(*args, **kwargs):
Expand Down Expand Up @@ -153,6 +202,7 @@ def test_bigquery_upload_no_table_name(*args, **kwargs):
],
)
assert result.exit_code == 0
assert "Table: mosaic_cog_band__1___" in result.output


@patch(
Expand Down
Loading

0 comments on commit f8a9be0

Please sign in to comment.