diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index 9862f393afc..e6b60d1fafb 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -1,22 +1,21 @@ -import re -from io import StringIO -from contextlib import contextmanager import datetime import pytz +import re +from contextlib import contextmanager +from dataclasses import dataclass +from io import StringIO +from typing import Optional +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization import snowflake.connector import snowflake.connector.errors import dbt.exceptions -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization from dbt.adapters.base import Credentials from dbt.adapters.sql import SQLConnectionManager from dbt.logger import GLOBAL_LOGGER as logger -from dataclasses import dataclass -from typing import Optional - @dataclass class SnowflakeCredentials(Credentials): @@ -28,6 +27,7 @@ class SnowflakeCredentials(Credentials): authenticator: Optional[str] private_key_path: Optional[str] private_key_passphrase: Optional[str] + token: Optional[str] client_session_keep_alive: bool = False @property @@ -46,6 +46,8 @@ def auth_args(self): result['password'] = self.password if self.authenticator: result['authenticator'] = self.authenticator + if self.authenticator == 'oauth': + result['token'] = self.token result['private_key'] = self._get_private_key() return result diff --git a/plugins/snowflake/setup.py b/plugins/snowflake/setup.py index 96c0c527f92..bf0addc7e79 100644 --- a/plugins/snowflake/setup.py +++ b/plugins/snowflake/setup.py @@ -44,6 +44,8 @@ 'azure-storage-blob~=2.1', 'azure-storage-common~=2.1', 'urllib3<1.25.0', + # this seems sufficiently broad + 'cryptography>=2,<3', ], zip_safe=False, classifiers=[ diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index a54d6921289..7c10b11ef96 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -314,6 +314,25 @@ def test_authenticator_externalbrowser_authentication(self): private_key=None, application='dbt') ]) + def test_authenticator_oauth_authentication(self): + self.config.credentials = self.config.credentials.replace( + authenticator='oauth', + token='my-oauth-token', + ) + self.adapter = SnowflakeAdapter(self.config) + conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + + self.snowflake.assert_not_called() + conn.handle + self.snowflake.assert_has_calls([ + mock.call( + account='test_account', autocommit=False, + client_session_keep_alive=False, database='test_database', + role=None, schema='public', user='test_user', + warehouse='test_warehouse', authenticator='oauth', token='my-oauth-token', + private_key=None, application='dbt') + ]) + @mock.patch('dbt.adapters.snowflake.SnowflakeCredentials._get_private_key', return_value='test_key') def test_authenticator_private_key_authentication(self, mock_get_private_key): self.config.credentials = self.config.credentials.replace(