From fe3986bd46af4d49abf858cef65fd11d6c0fdfc6 Mon Sep 17 00:00:00 2001
From: Drew Banin
Date: Tue, 6 Aug 2019 13:45:09 -0400
Subject: [PATCH] fix for unpickleable datetime tzs set by snowflake

---
 core/dbt/adapters/factory.py                  |  4 +-
 core/dbt/adapters/sql/connections.py          |  9 ++--
 .../dbt/adapters/snowflake/connections.py     | 47 +++++++++++++++++
 .../integration/048_rpc_test/sql/bigquery.sql | 26 ++++++++++
 .../integration/048_rpc_test/sql/redshift.sql | 15 ++++++
 .../048_rpc_test/sql/snowflake.sql            | 17 +++++++
 .../test_execute_fetch_and_serialize.py       | 50 +++++++++++++++++++
 7 files changed, 164 insertions(+), 4 deletions(-)
 create mode 100644 test/integration/048_rpc_test/sql/bigquery.sql
 create mode 100644 test/integration/048_rpc_test/sql/redshift.sql
 create mode 100644 test/integration/048_rpc_test/sql/snowflake.sql
 create mode 100644 test/integration/048_rpc_test/test_execute_fetch_and_serialize.py

diff --git a/core/dbt/adapters/factory.py b/core/dbt/adapters/factory.py
index 39ba9d070c8..a2c584df90e 100644
--- a/core/dbt/adapters/factory.py
+++ b/core/dbt/adapters/factory.py
@@ -1,6 +1,7 @@
 import dbt.exceptions
 from importlib import import_module
 from dbt.include.global_project import PACKAGES
+from dbt.logger import GLOBAL_LOGGER as logger
 
 import threading
 
@@ -29,7 +30,8 @@ def get_relation_class_by_name(adapter_name):
 def load_plugin(adapter_name):
     try:
         mod = import_module('.' + adapter_name, 'dbt.adapters')
-    except ImportError:
+    except ImportError as e:
+        logger.info("Error importing adapter: {}".format(e))
         raise dbt.exceptions.RuntimeException(
             "Could not find adapter type {}!".format(adapter_name)
         )
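A standalone sketch of the pattern the factory.py change above applies, using stand-in names rather than dbt's real modules (RuntimeError stands in for dbt.exceptions.RuntimeException): log the underlying ImportError before raising the friendlier "not found" error, so a broken dependency inside an installed plugin is distinguishable from a plugin that is missing entirely.

    import logging
    from importlib import import_module

    logger = logging.getLogger(__name__)


    def load_plugin(adapter_name):
        try:
            return import_module('.' + adapter_name, 'dbt.adapters')
        except ImportError as e:
            # Without this log line, an ImportError raised *inside* the
            # plugin module is masked by the generic "not found" message.
            logger.info("Error importing adapter: {}".format(e))
            raise RuntimeError(
                "Could not find adapter type {}!".format(adapter_name)
            )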
diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py
index a6db10d1215..f90c958bebf 100644
--- a/core/dbt/adapters/sql/connections.py
+++ b/core/dbt/adapters/sql/connections.py
@@ -76,6 +76,10 @@ def get_status(cls, cursor):
             '`get_status` is not implemented for this adapter!'
         )
 
+    @classmethod
+    def process_results(cls, column_names, rows):
+        return [dict(zip(column_names, row)) for row in rows]
+
     @classmethod
     def get_result_from_cursor(cls, cursor):
         data = []
@@ -83,9 +87,8 @@ def get_result_from_cursor(cls, cursor):
 
         if cursor.description is not None:
             column_names = [col[0] for col in cursor.description]
-            raw_results = cursor.fetchall()
-            data = [dict(zip(column_names, row))
-                    for row in raw_results]
+            rows = cursor.fetchall()
+            data = cls.process_results(column_names, rows)
 
         return dbt.clients.agate_helper.table_from_data(data, column_names)
diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py
index f78690797a7..417f2ce1636 100644
--- a/plugins/snowflake/dbt/adapters/snowflake/connections.py
+++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py
@@ -1,6 +1,7 @@
 import re
 from io import StringIO
 from contextlib import contextmanager
+import datetime
 
 import snowflake.connector
 import snowflake.connector.errors
@@ -14,6 +15,29 @@
 from dbt.logger import GLOBAL_LOGGER as logger
 
 
+# Provide a sane Timezone that can be pickled
+# The datetime objects returned by snowflake-connector-python
+# are not currently pickleable, so they can't be shared across
+# process boundaries (ie. in the dbt rpc server)
+# See: https://github.com/snowflakedb/snowflake-connector-python/pull/188
+class OffsetTimezone(datetime.tzinfo):
+    def __init__(self, name, tzoffset_seconds):
+        self.name = name
+        self.tzoffset = datetime.timedelta(seconds=tzoffset_seconds)
+
+    def utcoffset(self, dt, is_dst=False):
+        return self.tzoffset
+
+    def tzname(self, dt):
+        return self.name
+
+    def dst(self, dt):
+        return datetime.timedelta(0)
+
+    def __repr__(self):
+        return self.name
+
+
 SNOWFLAKE_CREDENTIALS_CONTRACT = {
     'type': 'object',
     'additionalProperties': False,
@@ -197,6 +221,29 @@ def _get_private_key(cls, private_key_path, private_key_passphrase):
             format=serialization.PrivateFormat.PKCS8,
             encryption_algorithm=serialization.NoEncryption())
 
+    @classmethod
+    def process_results(cls, column_names, rows):
+        # Override for Snowflake. The datetime objects returned by
+        # snowflake-connector-python are not pickleable, so we need
+        # to replace them with sane timezones
+        fixed = []
+        for row in rows:
+            fixed_row = []
+            for col in row:
+                if isinstance(col, datetime.datetime) and col.tzinfo:
+                    offset = col.utcoffset()
+                    offset_seconds = offset.total_seconds()
+                    new_timezone = OffsetTimezone(
+                        dbt.compat.to_native_string(col.tzinfo.tzname(col)),
+                        offset_seconds)
+                    col = col.astimezone(tz=new_timezone)
+                fixed_row.append(col)
+
+            fixed.append(fixed_row)
+
+        return super(SnowflakeConnectionManager, cls).process_results(
+            column_names, fixed)
+
     def add_query(self, sql, auto_begin=True, bindings=None,
                   abridge_sql_log=False):
diff --git a/test/integration/048_rpc_test/sql/bigquery.sql b/test/integration/048_rpc_test/sql/bigquery.sql
new file mode 100644
index 00000000000..c7de4deb4d3
--- /dev/null
+++ b/test/integration/048_rpc_test/sql/bigquery.sql
@@ -0,0 +1,26 @@
+
+select
+    cast(1 as int64) as test_int64,
+    cast(1 as numeric) as test_numeric,
+    cast(1 as float64) as test_float64,
+    cast('inf' as float64) as test_float64_inf,
+    cast('+inf' as float64) as test_float64_pos_inf,
+    cast('-inf' as float64) as test_float64_neg_inf,
+    cast('NaN' as float64) as test_float64_nan,
+    cast(true as boolean) as test_boolean,
+    cast('abc123' as string) as test_string,
+    cast('abc123' as bytes) as test_bytes,
+    cast('2019-01-01' as date) as test_date,
+    cast('12:00:00' as time) as test_time,
+    cast('2019-01-01 12:00:00' as timestamp) as test_timestamp,
+    timestamp('2019-01-01T12:00:00+04:00') as test_timestamp_tz,
+    st_geogfromgeojson('{ "type": "LineString", "coordinates": [ [1, 1], [3, 2] ] }') as test_geo,
+    [
+        struct(1 as val_1, 2 as val_2),
+        struct(3 as val_1, 4 as val_2)
+    ] as test_array,
+    struct(
+        cast('Fname' as string) as fname,
+        cast('Lname' as string) as lname
+    ) as test_struct,
+    cast(null as int64) as test_null
diff --git a/test/integration/048_rpc_test/sql/redshift.sql b/test/integration/048_rpc_test/sql/redshift.sql
new file mode 100644
index 00000000000..2bd3b2be5ce
--- /dev/null
+++ b/test/integration/048_rpc_test/sql/redshift.sql
@@ -0,0 +1,15 @@
+
+select
+    cast(1 as smallint) as test_smallint,
+    cast(1 as int) as test_int,
+    cast(1 as bigint) as test_bigint,
+    cast(1 as decimal) as test_decimal,
+    cast(1 as numeric(12,2)) as test_numeric,
+    cast(true as boolean) as test_boolean,
+    cast('abc123' as char) as test_char,
+    cast('abc123' as varchar) as test_varchar,
+    cast('abc123' as text) as test_text,
+    cast('2019-01-01' as date) as test_date,
+    cast('2019-01-01 12:00:00' as timestamp) as test_timestamp,
+    cast('2019-01-01 12:00:00+04:00' as timestamptz) as test_timestamptz,
+    cast(null as int) as test_null
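A minimal standalone illustration of the normalization that the Snowflake process_results override above performs: capture the UTC offset of each tz-aware datetime, rebuild the value on a plain fixed-offset tzinfo, and confirm it survives pickling. The stdlib datetime.timezone stands in here for the patch's OffsetTimezone; in practice it is the connector-supplied tzinfo that fails to pickle.

    import datetime
    import pickle

    # An aware datetime; imagine its tzinfo came from
    # snowflake-connector-python and could not be pickled.
    original = datetime.datetime(
        2019, 1, 1, 12, 0, 0,
        tzinfo=datetime.timezone(datetime.timedelta(hours=4)))

    # Mirror the override: capture the offset, rebuild a plain
    # fixed-offset tzinfo, and re-anchor the value on it.
    offset_seconds = original.utcoffset().total_seconds()
    new_tz = datetime.timezone(datetime.timedelta(seconds=offset_seconds))
    fixed = original.astimezone(tz=new_tz)

    # The normalized value round-trips through pickle with its
    # wall time and offset intact.
    assert pickle.loads(pickle.dumps(fixed)) == fixed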
diff --git a/test/integration/048_rpc_test/sql/snowflake.sql b/test/integration/048_rpc_test/sql/snowflake.sql
new file mode 100644
index 00000000000..319e0b1985c
--- /dev/null
+++ b/test/integration/048_rpc_test/sql/snowflake.sql
@@ -0,0 +1,17 @@
+
+select
+    cast(1 as number(12, 2)) as test_number,
+    cast(1 as int) as test_int,
+    cast(1 as float) as test_float,
+    cast('abc123' as varchar) as test_varchar,
+    cast('abc123' as char(6)) as test_char,
+    cast('C0FF33' as binary) as test_binary,
+    cast('2019-01-01' as date) as test_date,
+    cast('2019-01-01 12:00:00' as datetime) as test_datetime,
+    cast('12:00:00' as time) as test_time,
+    cast('2019-01-01 12:00:00' as timestamp_ltz) as test_timestamp_ltz,
+    cast('2019-01-01 12:00:00' as timestamp_ntz) as test_timestamp_ntz,
+    cast('2019-01-01 12:00:00' as timestamp_tz) as test_timestamp_tz,
+    cast(parse_json('{"a": 1, "b": 2}') as variant) as test_variant,
+    cast(parse_json('{"a": 1, "b": 2}') as object) as test_object,
+    cast(parse_json('[{"a": 1, "b": 2}]') as array) as test_array
diff --git a/test/integration/048_rpc_test/test_execute_fetch_and_serialize.py b/test/integration/048_rpc_test/test_execute_fetch_and_serialize.py
new file mode 100644
index 00000000000..73d140a84d0
--- /dev/null
+++ b/test/integration/048_rpc_test/test_execute_fetch_and_serialize.py
@@ -0,0 +1,50 @@
+from test.integration.base import DBTIntegrationTest, use_profile
+import pickle
+import os
+
+class TestRpcExecuteReturnsResults(DBTIntegrationTest):
+
+    @property
+    def schema(self):
+        return "rpc_test_048"
+
+    @property
+    def models(self):
+        return "models"
+
+    @property
+    def project_config(self):
+        return {
+            'macro-paths': ['macros'],
+        }
+
+    def do_test_pickle(self, agate_table):
+        # Helper, not a test case: named do_test_* so the test runner
+        # does not collect it and fail on the extra argument.
+        table = {
+            'column_names': list(agate_table.column_names),
+            'rows': [list(row) for row in agate_table]
+        }
+
+        pickle.dumps(table)
+
+    def do_test_file(self, filename):
+        file_path = os.path.join("sql", filename)
+        with open(file_path) as fh:
+            query = fh.read()
+
+        status, table = self.adapter.execute(query, auto_begin=False, fetch=True)
+        self.assertTrue(len(table.columns) > 0, "agate table had no columns")
+        self.assertTrue(len(table.rows) > 0, "agate table had no rows")
+
+        self.do_test_pickle(table)
+
+    @use_profile('bigquery')
+    def test__bigquery_fetch_and_serialize(self):
+        self.do_test_file('bigquery.sql')
+
+    @use_profile('snowflake')
+    def test__snowflake_fetch_and_serialize(self):
+        self.do_test_file('snowflake.sql')
+
+    @use_profile('redshift')
+    def test__redshift_fetch_and_serialize(self):
+        self.do_test_file('redshift.sql')
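A rough standalone version of the pickle check that do_test_pickle performs, assuming only the agate library; the literal table below is a stand-in for what self.adapter.execute(..., fetch=True) returns in the integration test.

    import pickle

    import agate

    # Stand-in for the agate table returned by the adapter.
    table = agate.Table(
        [(1, 'abc123')],
        column_names=['test_int', 'test_varchar'])

    # Flatten to plain Python containers, as the test does, then pickle;
    # this raises if any cell value (e.g. a datetime carrying an
    # unpickleable tzinfo) cannot cross a process boundary.
    payload = {
        'column_names': list(table.column_names),
        'rows': [list(row) for row in table.rows],
    }
    pickle.dumps(payload)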