Skip to content

Commit

Permalink
Merge pull request #1224 from convoyinc/adriank/add_snowflake_sso_sup…
Browse files Browse the repository at this point in the history
…port

Add support for Snowflake SSO authentication round 2
  • Loading branch information
cmcarthur authored Jan 9, 2019
2 parents 1e5308d + c01caef commit e359a69
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 13 deletions.
16 changes: 12 additions & 4 deletions plugins/snowflake/dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
'password': {
'type': 'string',
},
'authenticator': {
'type': 'string',
'description': "Either 'externalbrowser', or a valid Okta url"
},
'database': {
'type': 'string',
},
Expand All @@ -41,7 +45,7 @@
'type': 'boolean',
}
},
'required': ['account', 'user', 'password', 'database', 'schema'],
'required': ['account', 'user', 'database', 'schema'],
}


Expand Down Expand Up @@ -95,17 +99,21 @@ def open(cls, connection):

try:
credentials = connection.credentials
# Pull all of the optional authentication args for the connector,
# let connector handle the actual arg validation
auth_args = {auth_key: credentials[auth_key]
for auth_key in ['user', 'password', 'authenticator']
if auth_key in credentials}
handle = snowflake.connector.connect(
account=credentials.account,
user=credentials.user,
password=credentials.password,
database=credentials.database,
schema=credentials.schema,
warehouse=credentials.warehouse,
role=credentials.get('role', None),
autocommit=False,
client_session_keep_alive=credentials.get(
'client_session_keep_alive', False)
'client_session_keep_alive', False),
**auth_args
)

connection.handle = handle
Expand Down
69 changes: 60 additions & 9 deletions test/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .utils import config_from_parts_or_dicts


class TestSnowflakeAdapter(unittest.TestCase):
def setUp(self):
flags.STRICT_MODE = False
Expand All @@ -21,7 +22,6 @@ def setUp(self):
'type': 'snowflake',
'account': 'test_account',
'user': 'test_user',
'password': 'test_password',
'database': 'test_databse',
'warehouse': 'test_warehouse',
'schema': 'public',
Expand All @@ -42,10 +42,12 @@ def setUp(self):
}
self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)

self.handle = mock.MagicMock(spec=snowflake_connector.SnowflakeConnection)
self.handle = mock.MagicMock(
spec=snowflake_connector.SnowflakeConnection)
self.cursor = self.handle.cursor.return_value
self.mock_execute = self.cursor.execute
self.patcher = mock.patch('dbt.adapters.snowflake.connections.snowflake.connector.connect')
self.patcher = mock.patch(
'dbt.adapters.snowflake.connections.snowflake.connector.connect')
self.snowflake = self.patcher.start()

self.snowflake.return_value = self.handle
Expand Down Expand Up @@ -74,7 +76,9 @@ def test_quoting_on_drop(self):
)
self.adapter.drop_relation(relation)
self.mock_execute.assert_has_calls([
mock.call('drop table if exists "test_schema".test_table cascade', None)
mock.call(
'drop table if exists "test_schema".test_table cascade',
None)
])

def test_quoting_on_truncate(self):
Expand Down Expand Up @@ -108,7 +112,9 @@ def test_quoting_on_rename(self):
to_relation=to_relation
)
self.mock_execute.assert_has_calls([
mock.call('alter table "test_schema".table_a rename to table_b', None)
mock.call(
'alter table "test_schema".table_a rename to table_b',
None)
])

def test_cancel_open_connections_empty(self):
Expand All @@ -131,18 +137,20 @@ def test_cancel_open_connections_single(self):
query_result = mock.MagicMock()
add_query.return_value = (None, query_result)

self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1)
self.assertEqual(
len(list(self.adapter.cancel_open_connections())), 1)

add_query.assert_called_once_with('select system$abort_session(42)', 'master')
add_query.assert_called_once_with(
'select system$abort_session(42)', 'master')

def test_client_session_keep_alive_false_by_default(self):
self.adapter.connections.get(name='new_connection_with_new_config')
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_databse',
password='test_password', role=None, schema='public',
user='test_user', warehouse='test_warehouse')
role=None, schema='public', user='test_user',
warehouse='test_warehouse')
])

def test_client_session_keep_alive_true(self):
Expand All @@ -155,6 +163,49 @@ def test_client_session_keep_alive_true(self):
mock.call(
account='test_account', autocommit=False,
client_session_keep_alive=True, database='test_databse',
role=None, schema='public', user='test_user',
warehouse='test_warehouse')
])

def test_user_pass_authentication(self):
self.config.credentials = self.config.credentials.incorporate(
password='test_password')
self.adapter = SnowflakeAdapter(self.config)
self.adapter.connections.get(name='new_connection_with_new_config')

self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_databse',
password='test_password', role=None, schema='public',
user='test_user', warehouse='test_warehouse')
])

def test_authenticator_user_pass_authentication(self):
self.config.credentials = self.config.credentials.incorporate(
password='test_password', authenticator='test_sso_url')
self.adapter = SnowflakeAdapter(self.config)
self.adapter.connections.get(name='new_connection_with_new_config')

self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_databse',
password='test_password', role=None, schema='public',
user='test_user', warehouse='test_warehouse',
authenticator='test_sso_url')
])

def test_authenticator_externalbrowser_authentication(self):
self.config.credentials = self.config.credentials.incorporate(
authenticator='externalbrowser')
self.adapter = SnowflakeAdapter(self.config)
self.adapter.connections.get(name='new_connection_with_new_config')

self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_databse',
role=None, schema='public', user='test_user',
warehouse='test_warehouse', authenticator='externalbrowser')
])

0 comments on commit e359a69

Please sign in to comment.