diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index 1b517f90661..93a91866570 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -25,6 +25,10 @@ 'password': { 'type': 'string', }, + 'authenticator': { + 'type': 'string', + 'description': "Either 'externalbrowser', or a valid Okta url" + }, 'database': { 'type': 'string', }, @@ -41,7 +45,7 @@ 'type': 'boolean', } }, - 'required': ['account', 'user', 'password', 'database', 'schema'], + 'required': ['account', 'user', 'database', 'schema'], } @@ -95,17 +99,21 @@ def open(cls, connection): try: credentials = connection.credentials + # Pull all of the optional authentication args for the connector, + # let connector handle the actual arg validation + auth_args = {auth_key: credentials[auth_key] + for auth_key in ['user', 'password', 'authenticator'] + if auth_key in credentials} handle = snowflake.connector.connect( account=credentials.account, - user=credentials.user, - password=credentials.password, database=credentials.database, schema=credentials.schema, warehouse=credentials.warehouse, role=credentials.get('role', None), autocommit=False, client_session_keep_alive=credentials.get( - 'client_session_keep_alive', False) + 'client_session_keep_alive', False), + **auth_args ) connection.handle = handle diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index 0a791f13217..741b4a0b57b 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -11,6 +11,7 @@ from .utils import config_from_parts_or_dicts + class TestSnowflakeAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = False @@ -21,7 +22,6 @@ def setUp(self): 'type': 'snowflake', 'account': 'test_account', 'user': 'test_user', - 'password': 'test_password', 'database': 'test_databse', 'warehouse': 'test_warehouse', 'schema': 'public', @@ -42,10 +42,12 @@ def setUp(self): } self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) - self.handle = mock.MagicMock(spec=snowflake_connector.SnowflakeConnection) + self.handle = mock.MagicMock( + spec=snowflake_connector.SnowflakeConnection) self.cursor = self.handle.cursor.return_value self.mock_execute = self.cursor.execute - self.patcher = mock.patch('dbt.adapters.snowflake.connections.snowflake.connector.connect') + self.patcher = mock.patch( + 'dbt.adapters.snowflake.connections.snowflake.connector.connect') self.snowflake = self.patcher.start() self.snowflake.return_value = self.handle @@ -74,7 +76,9 @@ def test_quoting_on_drop(self): ) self.adapter.drop_relation(relation) self.mock_execute.assert_has_calls([ - mock.call('drop table if exists "test_schema".test_table cascade', None) + mock.call( + 'drop table if exists "test_schema".test_table cascade', + None) ]) def test_quoting_on_truncate(self): @@ -108,7 +112,9 @@ def test_quoting_on_rename(self): to_relation=to_relation ) self.mock_execute.assert_has_calls([ - mock.call('alter table "test_schema".table_a rename to table_b', None) + mock.call( + 'alter table "test_schema".table_a rename to table_b', + None) ]) def test_cancel_open_connections_empty(self): @@ -131,9 +137,11 @@ def test_cancel_open_connections_single(self): query_result = mock.MagicMock() add_query.return_value = (None, query_result) - self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1) + self.assertEqual( + len(list(self.adapter.cancel_open_connections())), 1) - add_query.assert_called_once_with('select system$abort_session(42)', 'master') + add_query.assert_called_once_with( + 'select system$abort_session(42)', 'master') def test_client_session_keep_alive_false_by_default(self): self.adapter.connections.get(name='new_connection_with_new_config') @@ -141,8 +149,8 @@ def test_client_session_keep_alive_false_by_default(self): mock.call( account='test_account', autocommit=False, client_session_keep_alive=False, database='test_databse', - password='test_password', role=None, schema='public', - user='test_user', warehouse='test_warehouse') + role=None, schema='public', user='test_user', + warehouse='test_warehouse') ]) def test_client_session_keep_alive_true(self): @@ -155,6 +163,49 @@ def test_client_session_keep_alive_true(self): mock.call( account='test_account', autocommit=False, client_session_keep_alive=True, database='test_databse', + role=None, schema='public', user='test_user', + warehouse='test_warehouse') + ]) + + def test_user_pass_authentication(self): + self.config.credentials = self.config.credentials.incorporate( + password='test_password') + self.adapter = SnowflakeAdapter(self.config) + self.adapter.connections.get(name='new_connection_with_new_config') + + self.snowflake.assert_has_calls([ + mock.call( + account='test_account', autocommit=False, + client_session_keep_alive=False, database='test_databse', password='test_password', role=None, schema='public', user='test_user', warehouse='test_warehouse') ]) + + def test_authenticator_user_pass_authentication(self): + self.config.credentials = self.config.credentials.incorporate( + password='test_password', authenticator='test_sso_url') + self.adapter = SnowflakeAdapter(self.config) + self.adapter.connections.get(name='new_connection_with_new_config') + + self.snowflake.assert_has_calls([ + mock.call( + account='test_account', autocommit=False, + client_session_keep_alive=False, database='test_databse', + password='test_password', role=None, schema='public', + user='test_user', warehouse='test_warehouse', + authenticator='test_sso_url') + ]) + + def test_authenticator_externalbrowser_authentication(self): + self.config.credentials = self.config.credentials.incorporate( + authenticator='externalbrowser') + self.adapter = SnowflakeAdapter(self.config) + self.adapter.connections.get(name='new_connection_with_new_config') + + self.snowflake.assert_has_calls([ + mock.call( + account='test_account', autocommit=False, + client_session_keep_alive=False, database='test_databse', + role=None, schema='public', user='test_user', + warehouse='test_warehouse', authenticator='externalbrowser') + ])