Skip to content

Commit

Permalink
SNOW-136583 adding multipart PUT threshold support (#627)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-mkeller authored Feb 23, 2021
1 parent 9699c75 commit 2b41eaa
Show file tree
Hide file tree
Showing 19 changed files with 1,116 additions and 1,034 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def _get_arrow_lib_as_linker_input(self):
'chardet>=3.0.2,<4',
'idna>=2.5,<3',
'certifi>=2017.4.17',
'dataclasses<1.0;python_version=="3.6"',
],

namespace_packages=['snowflake'],
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/connector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
NotSupportedError,
OperationalError,
ProgrammingError,
Warning,
_Warning,
)
from .version import VERSION

Expand All @@ -57,7 +57,7 @@ def Connect(**kwargs):

__all__ = [
# Error handling
'Error', 'Warning',
'Error', '_Warning',
'InterfaceError', 'DatabaseError',
'NotSupportedError', 'DataError', 'IntegrityError', 'ProgrammingError',
'OperationalError', 'InternalError',
Expand Down
150 changes: 77 additions & 73 deletions src/snowflake/connector/azure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
import os
from collections import namedtuple
from logging import getLogger
from typing import TYPE_CHECKING, Any, Dict

from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.storage.blob import BlobServiceClient, ContentSettings, ExponentialRetry

from .constants import HTTP_HEADER_VALUE_OCTET_STREAM, SHA256_DIGEST, FileHeader, ResultStatus
from .constants import HTTP_HEADER_VALUE_OCTET_STREAM, FileHeader, ResultStatus
from .encryption_util import EncryptionMetadata

if TYPE_CHECKING: # pragma: no cover
from .file_transfer_agent import SnowflakeFileMeta

logger = getLogger(__name__)

"""
Expand All @@ -31,11 +35,9 @@
class SnowflakeAzureUtil(object):
"""Azure Utility class."""

# max_connections works over this size
DATA_SIZE_THRESHOLD = 67108864

@staticmethod
def create_client(stage_info, use_accelerate_endpoint: bool = False):
def create_client(stage_info: Dict[str, Any],
use_accelerate_endpoint: bool = False) -> BlobServiceClient:
"""Creates a client object with a stage credential.
Args:
Expand All @@ -53,10 +55,7 @@ def create_client(stage_info, use_accelerate_endpoint: bool = False):
if end_point.startswith('blob.'):
end_point = end_point[len('blob.'):]
client = BlobServiceClient(
account_url="https://{}.blob.{}".format(
stage_info['storageAccount'],
end_point
),
account_url=f"https://{stage_info['storageAccount']}.blob.{end_point}",
credential=sas_token)
client._config.retry_policy = ExponentialRetry(
initial_backoff=1,
Expand Down Expand Up @@ -85,39 +84,34 @@ def extract_container_name_and_path(stage_location):
path=path)

@staticmethod
def get_file_header(meta, filename):
def get_file_header(meta: 'SnowflakeFileMeta', filename):
"""Gets Azure file properties."""
client = meta['client']
azure_location = SnowflakeAzureUtil.extract_container_name_and_path(
meta['stage_info']['location'])
client: BlobServiceClient = meta.client
azure_location = SnowflakeAzureUtil.extract_container_name_and_path(meta.stage_info['location'])
try:
# HTTP HEAD request
blob = client.get_blob_client(azure_location.container_name,
azure_location.path + filename)
blob_details = blob.get_blob_properties()
except ResourceNotFoundError:
meta['result_status'] = ResultStatus.NOT_FOUND_FILE
meta.result_status = ResultStatus.NOT_FOUND_FILE
return FileHeader(
digest=None,
content_length=None,
encryption_metadata=None
)
except HttpResponseError as err:
logger.debug("Caught exception's status code: {status_code} and message: {ex_representation}".format(
status_code=err.status_code,
ex_representation=str(err)
))
logger.debug(f"Caught exception's status code: {err.status_code} and message: {str(err)}")
if err.status_code == 403 and SnowflakeAzureUtil._detect_azure_token_expire_error(err):
logger.debug("AZURE Token expired. Renew and retry")
meta['result_status'] = ResultStatus.RENEW_TOKEN
meta.result_status = ResultStatus.RENEW_TOKEN
else:
logger.debug('Unexpected Azure error: %s'
'container: %s, path: %s',
err, azure_location.container_name,
azure_location.path)
meta['result_status'] = ResultStatus.ERROR
logger.debug(f'Unexpected Azure error: {err} '
f'container: {azure_location.container_name}, path: {azure_location.path}')
meta.result_status = ResultStatus.ERROR

return
meta['result_status'] = ResultStatus.UPLOADED
meta.result_status = ResultStatus.UPLOADED
encryptiondata = json.loads(blob_details.metadata.get('encryptiondata', 'null'))
encryption_metadata = EncryptionMetadata(
key=encryptiondata['WrappedContentKey']['EncryptedKey'],
Expand All @@ -140,9 +134,30 @@ def _detect_azure_token_expire_error(err):
"Server failed to authenticate the request." in errstr

@staticmethod
def upload_file(data_file, meta, encryption_metadata, max_concurrency):
def upload_file(data_file: str,
meta: 'SnowflakeFileMeta',
encryption_metadata: 'EncryptionMetadata',
max_concurrency: int,
multipart_threshold: int,
):
"""Uploads the local file to Azure's Blob Storage.
Args:
data_file: File path on local system.
meta: The File meta object (contains credentials and remote location).
encryption_metadata: Encryption metadata to be set on object.
max_concurrency: Not applicable to Azure.
multipart_threshold: The number of bytes after which size a file should be uploaded concurrently in chunks.
Not applicable to Azure.
Raises:
HTTPError if some http errors occurred.
Returns:
None.
"""
azure_metadata = {
'sfcdigest': meta[SHA256_DIGEST],
'sfcdigest': meta.sha256_digest,
}
if encryption_metadata:
azure_metadata.update({
Expand All @@ -165,37 +180,36 @@ def upload_file(data_file, meta, encryption_metadata, max_concurrency):
'matdesc': encryption_metadata.matdesc
})
azure_location = SnowflakeAzureUtil.extract_container_name_and_path(
meta['stage_info']['location'])
path = azure_location.path + meta['dst_file_name'].lstrip('/')
meta.stage_info['location'])
path = azure_location.path + meta.dst_file_name.lstrip('/')

client = meta['client']
client: BlobServiceClient = meta.client
callback = None
upload_src = None
upload_size = None

if 'src_stream' not in meta:
if meta.src_stream is None:
upload_size = os.path.getsize(data_file)
upload_src = open(data_file, 'rb')
else:
upload_src = meta.get('real_src_stream', meta['src_stream'])
upload_src = meta.real_src_stream or meta.src_stream
upload_size = upload_src.seek(0, os.SEEK_END)
upload_src.seek(0)

if meta['put_azure_callback']:
callback = meta['put_azure_callback'](
if meta.put_azure_callback:
callback = meta.put_azure_callback(
data_file,
upload_size,
output_stream=meta['put_callback_output_stream'],
show_progress_bar=meta['show_progress_bar'])
output_stream=meta.put_callback_output_stream,
show_progress_bar=meta.show_progress_bar)

def azure_callback(response):
current = response.context['upload_stream_current']
total = response.context['data_stream_total']
if current is not None:
callback(current)
logger.debug("data transfer progress from sdk callback. "
"current: %s, total: %s",
current, total)
f"current: {current}, total: {total}")

try:
blob = client.get_blob_client(
Expand All @@ -207,58 +221,51 @@ def azure_callback(response):
metadata=azure_metadata,
overwrite=True,
max_concurrency=max_concurrency,
raw_response_hook=azure_callback if meta['put_azure_callback'] else None,
raw_response_hook=azure_callback if meta.put_azure_callback else None,
content_settings=ContentSettings(
content_type=HTTP_HEADER_VALUE_OCTET_STREAM,
content_encoding='utf-8',
)
)
except HttpResponseError as err:
logger.debug("Caught exception's status code: {status_code} and message: {ex_representation}".format(
status_code=err.status_code,
ex_representation=str(err)
))
logger.debug(f"Caught exception's status code: {err.status_code} and message: {err}")
if err.status_code == 403 and SnowflakeAzureUtil._detect_azure_token_expire_error(err):
logger.debug("AZURE Token expired. Renew and retry")
meta['result_status'] = ResultStatus.RENEW_TOKEN
meta.result_status = ResultStatus.RENEW_TOKEN
else:
meta['last_error'] = err
meta['result_status'] = ResultStatus.NEED_RETRY
meta.last_error = err
meta.result_status = ResultStatus.NEED_RETRY
return
finally:
if 'src_stream' not in meta:
if meta.src_stream is None:
upload_src.close()

logger.debug('DONE putting a file')
meta['dst_file_size'] = meta['upload_size']
meta['result_status'] = ResultStatus.UPLOADED
meta.dst_file_size = meta.upload_size
meta.result_status = ResultStatus.UPLOADED
# Comparing with s3, azure haven't experienced OpenSSL.SSL.SysCallError,
# so we will add logic to catch it only when it happens

@staticmethod
def _native_download_file(meta, full_dst_file_name, max_concurrency):
azure_location = SnowflakeAzureUtil.extract_container_name_and_path(
meta['stage_info']['location'])
path = azure_location.path + meta['src_file_name'].lstrip('/')
client = meta['client']
def _native_download_file(meta: 'SnowflakeFileMeta', full_dst_file_name, max_concurrency):
azure_location = SnowflakeAzureUtil.extract_container_name_and_path(meta.stage_info['location'])
path = azure_location.path + meta.src_file_name.lstrip('/')
client: BlobServiceClient = meta.client

callback = None
if meta['get_azure_callback']:
callback = meta['get_azure_callback'](
meta['src_file_name'],
meta['src_file_size'],
output_stream=meta['get_callback_output_stream'],
show_progress_bar=meta['show_progress_bar'])
if meta.get_azure_callback:
callback = meta.get_azure_callback(
meta.src_file_name,
meta.src_file_size,
output_stream=meta.get_callback_output_stream,
show_progress_bar=meta.show_progress_bar)

def azure_callback(response):
current = response.context['download_stream_current']
total = response.context['data_stream_total']
if current is not None:
callback(current)
logger.debug("data transfer progress from sdk callback. "
"current: %s, total: %s",
current, total)

logger.debug(f"data transfer progress from sdk callback. current: {current}, total: {total}")
try:
blob = client.get_blob_client(
azure_location.container_name,
Expand All @@ -267,20 +274,17 @@ def azure_callback(response):
with open(full_dst_file_name, 'wb') as download_f:
download = blob.download_blob(
max_concurrency=max_concurrency,
raw_response_hook=azure_callback if meta['put_azure_callback'] else None,
raw_response_hook=azure_callback if meta.put_azure_callback else None,
)
download.readinto(download_f)

except HttpResponseError as err:
logger.debug("Caught exception's status code: {status_code} and message: {ex_representation}".format(
status_code=err.status_code,
ex_representation=str(err)
))
logger.debug(f"Caught exception's status code: {err.status_code} and message: {str(err)}")
if err.status_code == 403 and SnowflakeAzureUtil._detect_azure_token_expire_error(err):
logger.debug("AZURE Token expired. Renew and retry")
meta['result_status'] = ResultStatus.RENEW_TOKEN
meta.result_status = ResultStatus.RENEW_TOKEN
else:
meta['last_error'] = err
meta['result_status'] = ResultStatus.NEED_RETRY
meta.last_error = err
meta.result_status = ResultStatus.NEED_RETRY
return
meta['result_status'] = ResultStatus.DOWNLOADED
meta.result_status = ResultStatus.DOWNLOADED
10 changes: 10 additions & 0 deletions src/snowflake/connector/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,13 @@ def PRINT(msg):

def INPUT(prompt):
return input(prompt)


# The `dataclasses` module is in the standard library from Python 3.7
# onward; on Python 3.6 the very same import name is provided by the
# `dataclasses` backport package (declared as a conditional dependency
# in setup.py: 'dataclasses<1.0;python_version=="3.6"'), so one import
# statement covers both interpreter versions.
# NOTE(review): the previous try/except imported from `dataclass`
# (no trailing "s"), a module that does not exist, so the ImportError
# fallback always ran; the try/except added nothing but confusion.
from dataclasses import dataclass  # NOQA
from dataclasses import field  # NOQA
6 changes: 4 additions & 2 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def __init__(self, **kwargs):
self.heartbeat_thread = None

self.converter = None
self.__set_error_attributes()
self.connect(**kwargs)
self._telemetry = TelemetryClient(self._rest)
self.incident = IncidentAPI(self._rest)
Expand Down Expand Up @@ -419,7 +420,6 @@ def connect(self, **kwargs):
self.__config(**kwargs)
TelemetryService.get_instance().update_context(kwargs)

self.__set_error_attributes()
self.__open_connection()

def close(self, retry=True):
Expand Down Expand Up @@ -533,7 +533,9 @@ def execute_stream(self, stream: StringIO,
def __set_error_attributes(self):
for m in [method for method in dir(errors) if
callable(getattr(errors, method))]:
setattr(self, m, getattr(errors, m))
# If name starts with _ then ignore that
name = m if not m.startswith('_') else m[1:]
setattr(self, name, getattr(errors, m))

@staticmethod
def setup_ocsp_privatelink(app, hostname):
Expand Down
5 changes: 3 additions & 2 deletions src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,6 @@ def execute(self,
logger.debug('SUCCESS')
data = ret['data']

# logger.debug(ret)
logger.debug("PUT OR GET: %s", self.is_file_transfer)
if self.is_file_transfer:
sf_file_transfer_agent = SnowflakeFileTransferAgent(
Expand All @@ -580,7 +579,9 @@ def execute(self,
show_progress_bar=_show_progress_bar,
raise_put_get_error=_raise_put_get_error,
force_put_overwrite=_force_put_overwrite or data.get('overwrite', False),
source_from_stream=file_stream)
source_from_stream=file_stream,
multipart_threshold=data.get('threshold'),
)
sf_file_transfer_agent.execute()
data = sf_file_transfer_agent.result()
self._total_rowcount = len(data['rowset']) if \
Expand Down
Loading

0 comments on commit 2b41eaa

Please sign in to comment.