Commit fe3986b: fix for unpickleable datetime tzs set by snowflake

drewbanin committed Aug 9, 2019
1 parent 3e3c69e
Showing 7 changed files with 164 additions and 4 deletions.
4 changes: 3 additions & 1 deletion core/dbt/adapters/factory.py
@@ -1,6 +1,7 @@
import dbt.exceptions
from importlib import import_module
from dbt.include.global_project import PACKAGES
from dbt.logger import GLOBAL_LOGGER as logger

import threading

@@ -29,7 +30,8 @@ def get_relation_class_by_name(adapter_name):
def load_plugin(adapter_name):
try:
mod = import_module('.' + adapter_name, 'dbt.adapters')
except ImportError:
except ImportError as e:
logger.info("Error importing adapter: {}".format(e))
raise dbt.exceptions.RuntimeException(
"Could not find adapter type {}!".format(adapter_name)
)
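
For illustration, a minimal sketch (not part of the commit, written against the function shown above): a failed adapter import still raises RuntimeException, but the new logger.info call records the underlying ImportError first.

import dbt.exceptions
from dbt.adapters.factory import load_plugin

try:
    plugin = load_plugin('snowflake')
except dbt.exceptions.RuntimeException:
    # the root cause (e.g. a missing snowflake-connector-python install)
    # now appears in the dbt log before this exception surfaces
    pass
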
9 changes: 6 additions & 3 deletions core/dbt/adapters/sql/connections.py
@@ -76,16 +76,19 @@ def get_status(cls, cursor):
'`get_status` is not implemented for this adapter!'
)

@classmethod
def process_results(cls, column_names, rows):
return [dict(zip(column_names, row)) for row in rows]

@classmethod
def get_result_from_cursor(cls, cursor):
data = []
column_names = []

if cursor.description is not None:
column_names = [col[0] for col in cursor.description]
raw_results = cursor.fetchall()
data = [dict(zip(column_names, row))
for row in raw_results]
rows = cursor.fetchall()
data = cls.process_results(column_names, rows)

return dbt.clients.agate_helper.table_from_data(data, column_names)

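
As a quick illustration (not part of the commit), the default process_results hook simply zips each row with the column names to build a list of dicts; adapters like Snowflake below override it to rewrite individual values first.

column_names = ['id', 'created_at']
rows = [(1, '2019-01-01 12:00:00'), (2, '2019-01-02 12:00:00')]

data = [dict(zip(column_names, row)) for row in rows]
# [{'id': 1, 'created_at': '2019-01-01 12:00:00'},
#  {'id': 2, 'created_at': '2019-01-02 12:00:00'}]
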
47 changes: 47 additions & 0 deletions plugins/snowflake/dbt/adapters/snowflake/connections.py
@@ -1,6 +1,7 @@
import re
from io import StringIO
from contextlib import contextmanager
import datetime

import snowflake.connector
import snowflake.connector.errors
@@ -14,6 +15,29 @@
from dbt.logger import GLOBAL_LOGGER as logger


# Provide a sane Timezone that can be pickled
# The datetime objects returned by snowflake-connector-python
# are not currently pickleable, so they can't be shared across
# process boundaries (ie. in the dbt rpc server)
# See: https://github.com/snowflakedb/snowflake-connector-python/pull/188
class OffsetTimezone(datetime.tzinfo):
def __init__(self, name, tzoffset_seconds):
self.name = name
self.tzoffset = datetime.timedelta(seconds=tzoffset_seconds)

def utcoffset(self, dt, is_dst=False):
return self.tzoffset

def tzname(self, dt):
return self.name

def dst(self, dt):
return datetime.timedelta(0)

def __repr__(self):
return self.name


SNOWFLAKE_CREDENTIALS_CONTRACT = {
'type': 'object',
'additionalProperties': False,
@@ -197,6 +221,29 @@ def _get_private_key(cls, private_key_path, private_key_passphrase):
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption())

@classmethod
def process_results(cls, column_names, rows):
# Override for Snowflake. The datetime objects returned by
# snowflake-connector-python are not pickleable, so we need
# to replace them with sane timezones
fixed = []
for row in rows:
fixed_row = []
for col in row:
if isinstance(col, datetime.datetime) and col.tzinfo:
offset = col.utcoffset()
offset_seconds = offset.total_seconds()
new_timezone = OffsetTimezone(
dbt.compat.to_native_string(col.tzinfo.tzname(col)),
offset_seconds)
col = col.astimezone(tz=new_timezone)
fixed_row.append(col)

fixed.append(fixed_row)

return super(SnowflakeConnectionManager, cls).process_results(
column_names, fixed)

def add_query(self, sql, auto_begin=True,
bindings=None, abridge_sql_log=False):

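
A minimal sketch (not part of the commit) of what the override buys, using the OffsetTimezone class defined above: a timezone-aware datetime rebuilt onto OffsetTimezone serializes with pickle.dumps, which is what the dbt rpc server needs to move query results across process boundaries. The +04:00 offset is an arbitrary example value.

import datetime
import pickle

# stand-in for an aware datetime like those returned by snowflake-connector-python
aware = datetime.datetime(2019, 1, 1, 12, 0,
                          tzinfo=datetime.timezone(datetime.timedelta(hours=4)))

offset_seconds = aware.utcoffset().total_seconds()
fixed = aware.astimezone(tz=OffsetTimezone('UTC+04:00', offset_seconds))

# dumping succeeds, so the value can cross a process boundary
payload = pickle.dumps(fixed)
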
26 changes: 26 additions & 0 deletions test/integration/048_rpc_test/sql/bigquery.sql
@@ -0,0 +1,26 @@

select
cast(1 as int64) as test_int64,
cast(1 as numeric) as test_numeric,
cast(1 as float64) as test_float64,
cast('inf' as float64) as test_float64_inf,
cast('+inf' as float64) as test_float64_pos_inf,
cast('-inf' as float64) as test_float64_neg_inf,
cast('NaN' as float64) as test_float64_nan,
cast(true as boolean) as test_boolean,
cast('abc123' as string) as test_string,
cast('abc123' as bytes) as test_bytes,
cast('2019-01-01' as date) as test_date,
cast('12:00:00' as time) as test_time,
cast('2019-01-01 12:00:00' as timestamp) as test_timestamp,
timestamp('2019-01-01T12:00:00+04:00') as test_timestamp_tz,
st_geogfromgeojson('{ "type": "LineString", "coordinates": [ [1, 1], [3, 2] ] }') as test_geo,
[
struct(1 as val_1, 2 as val_2),
struct(3 as val_1, 4 as val_2)
] as test_array,
struct(
cast('Fname' as string) as fname,
cast('Lname' as string) as lname
) as test_struct,
cast(null as int64) as test_null
15 changes: 15 additions & 0 deletions test/integration/048_rpc_test/sql/redshift.sql
@@ -0,0 +1,15 @@

select
cast(1 as smallint) as test_smallint,
cast(1 as int) as test_int,
cast(1 as bigint) as test_bigint,
cast(1 as decimal) as test_decimal,
cast(1 as numeric(12,2)) as test_numeric,
cast(true as boolean) as test_boolean,
cast('abc123' as char) as test_char,
cast('abc123' as varchar) as test_varchar,
cast('abc123' as text) as test_text,
cast('2019-01-01' as date) as test_date,
cast('2019-01-01 12:00:00' as timestamp) as test_timestamp,
cast('2019-01-01 12:00:00+04:00' as timestamptz) as test_timestamptz,
cast(null as int) as test_null
17 changes: 17 additions & 0 deletions test/integration/048_rpc_test/sql/snowflake.sql
@@ -0,0 +1,17 @@

select
cast(1 as number(12, 2)) as test_number,
cast(1 as int) as test_int,
cast(1 as float) as test_float,
cast('abc123' as varchar) as test_varchar,
cast('abc123' as char(6)) as test_char,
cast('C0FF33' as binary) as test_binary,
cast('2019-01-01' as date) as test_date,
cast('2019-01-01 12:00:00' as datetime) as test_datetime,
cast('12:00:00' as time) as test_time,
cast('2019-01-01 12:00:00' as timestamp_ltz) as test_timestamp_ltz,
cast('2019-01-01 12:00:00' as timestamp_ntz) as test_timestamp_ntz,
cast('2019-01-01 12:00:00' as timestamp_tz) as test_timestamp_tz,
cast(parse_json('{"a": 1, "b": 2}') as variant) as test_variant,
cast(parse_json('{"a": 1, "b": 2}') as object) as test_object,
cast(parse_json('[{"a": 1, "b": 2}]') as array) as test_array
50 changes: 50 additions & 0 deletions test/integration/048_rpc_test/test_execute_fetch_and_serialize.py
@@ -0,0 +1,50 @@
from test.integration.base import DBTIntegrationTest, use_profile
import pickle
import os

class TestRpcExecuteReturnsResults(DBTIntegrationTest):

@property
def schema(self):
return "rpc_test_048"

@property
def models(self):
return "models"

@property
def project_config(self):
return {
'macro-paths': ['macros'],
}

def test_pickle(self, agate_table):
table = {
'column_names': list(agate_table.column_names),
'rows': [list(row) for row in agate_table]
}

pickle.dumps(table)

def test_file(self, filename):
file_path = os.path.join("sql", filename)
with open(file_path) as fh:
query = fh.read()

status, table = self.adapter.execute(query, auto_begin=False, fetch=True)
self.assertTrue(len(table.columns) > 0, "agate table had no columns")
self.assertTrue(len(table.rows) > 0, "agate table had no rows")

self.test_pickle(table)

@use_profile('bigquery')
def test__bigquery_fetch_and_serialize(self):
self.test_file('bigquery.sql')

@use_profile('snowflake')
def test__snowflake_fetch_and_serialize(self):
self.test_file('snowflake.sql')

@use_profile('redshift')
def test__redshift_fetch_and_serialize(self):
self.test_file('redshift.sql')
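
As a small aside (not part of the commit), the round trip the test exercises can be reproduced directly against table_from_data, which the SQL connection manager uses above: build an agate table from plain dicts and confirm its column names and rows pickle cleanly.

import pickle
import dbt.clients.agate_helper

data = [{'id': 1, 'ts': '2019-01-01 12:00:00'}]
table = dbt.clients.agate_helper.table_from_data(data, ['id', 'ts'])

payload = pickle.dumps({
    'column_names': list(table.column_names),
    'rows': [list(row) for row in table.rows],
})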
