diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index 93a91866570..c7f117a060c 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -7,6 +7,8 @@ import dbt.compat import dbt.exceptions +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization from dbt.adapters.base import Credentials from dbt.adapters.sql import SQLConnectionManager from dbt.logger import GLOBAL_LOGGER as logger @@ -29,6 +31,12 @@ 'type': 'string', 'description': "Either 'externalbrowser', or a valid Okta url" }, + 'private_key_path': { + 'type': 'string', + }, + 'private_key_passphrase': { + 'type': 'string', + }, 'database': { 'type': 'string', }, @@ -104,6 +112,11 @@ def open(cls, connection): auth_args = {auth_key: credentials[auth_key] for auth_key in ['user', 'password', 'authenticator'] if auth_key in credentials} + + auth_args['private_key'] = cls._get_private_key( + credentials.get('private_key_path'), + credentials.get('private_key_passphrase')) + handle = snowflake.connector.connect( account=credentials.account, database=credentials.database, @@ -163,6 +176,23 @@ def _split_queries(cls, sql): split_query = snowflake.connector.util_text.split_statements(sql_buf) return [part[0] for part in split_query] + @classmethod + def _get_private_key(cls, private_key_path, private_key_passphrase): + """Get Snowflake private key by path or None.""" + if private_key_path is None or private_key_passphrase is None: + return None + + with open(private_key_path, 'rb') as key: + p_key = serialization.load_pem_private_key( + key.read(), + password=private_key_passphrase.encode(), + backend=default_backend()) + + return p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption()) + def add_query(self, sql, model_name=None, auto_begin=True, bindings=None, abridge_sql_log=False): diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index 741b4a0b57b..5381efc1d3e 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -1,3 +1,5 @@ +from mock import patch + import mock import unittest @@ -150,7 +152,7 @@ def test_client_session_keep_alive_false_by_default(self): account='test_account', autocommit=False, client_session_keep_alive=False, database='test_databse', role=None, schema='public', user='test_user', - warehouse='test_warehouse') + warehouse='test_warehouse', private_key=None) ]) def test_client_session_keep_alive_true(self): @@ -164,7 +166,7 @@ def test_client_session_keep_alive_true(self): account='test_account', autocommit=False, client_session_keep_alive=True, database='test_databse', role=None, schema='public', user='test_user', - warehouse='test_warehouse') + warehouse='test_warehouse', private_key=None) ]) def test_user_pass_authentication(self): @@ -178,7 +180,7 @@ def test_user_pass_authentication(self): account='test_account', autocommit=False, client_session_keep_alive=False, database='test_databse', password='test_password', role=None, schema='public', - user='test_user', warehouse='test_warehouse') + user='test_user', warehouse='test_warehouse', private_key=None) ]) def test_authenticator_user_pass_authentication(self): @@ -193,7 +195,7 @@ def test_authenticator_user_pass_authentication(self): client_session_keep_alive=False, database='test_databse', password='test_password', role=None, schema='public', user='test_user', warehouse='test_warehouse', - authenticator='test_sso_url') + authenticator='test_sso_url', private_key=None) ]) def test_authenticator_externalbrowser_authentication(self): @@ -207,5 +209,23 @@ def test_authenticator_externalbrowser_authentication(self): account='test_account', autocommit=False, client_session_keep_alive=False, database='test_databse', role=None, schema='public', user='test_user', - warehouse='test_warehouse', authenticator='externalbrowser') + warehouse='test_warehouse', authenticator='externalbrowser', + private_key=None) + ]) + + @patch('dbt.adapters.snowflake.SnowflakeConnectionManager._get_private_key', return_value='test_key') + def test_authenticator_private_key_authentication(self, mock_get_private_key): + self.config.credentials = self.config.credentials.incorporate( + private_key_path='/tmp/test_key.p8', + private_key_passphrase='p@ssphr@se') + + self.adapter = SnowflakeAdapter(self.config) + self.adapter.connections.get(name='new_connection_with_new_config') + + self.snowflake.assert_has_calls([ + mock.call( + account='test_account', autocommit=False, + client_session_keep_alive=False, database='test_databse', + role=None, schema='public', user='test_user', + warehouse='test_warehouse', private_key='test_key') ])