Skip to content

Commit

Permalink
actually update the connection
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob Beck committed Feb 26, 2020
1 parent 9081303 commit 6e8d4a6
Showing 1 changed file with 86 additions and 10 deletions.
96 changes: 86 additions & 10 deletions plugins/snowflake/dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import datetime
import pytz
import re
Expand All @@ -8,15 +9,22 @@

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
import requests
import snowflake.connector
import snowflake.connector.errors

import dbt.exceptions
from dbt.exceptions import (
InternalException, RuntimeException, FailedToConnectException,
DatabaseException, warn_or_error
)
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.logger import GLOBAL_LOGGER as logger


_TOKEN_REQUEST_URL = 'https://{}.snowflakecomputing.com/oauth/token-request'


@dataclass
class SnowflakeCredentials(Credentials):
account: str
Expand All @@ -28,15 +36,30 @@ class SnowflakeCredentials(Credentials):
private_key_path: Optional[str]
private_key_passphrase: Optional[str]
token: Optional[str]
oauth_client_id: Optional[str]
oauth_client_secret: Optional[str]
client_session_keep_alive: bool = False

def __post_init__(self):
if (
self.authenticator != 'oauth' and
(self.oauth_client_secret or self.oauth_client_id or self.token)
):
# the user probably forgot to set 'authenticator' like I keep doing
warn_or_error(
'Authenticator is not set to oauth, but an oauth-only '
'parameter is set! Did you mean to set authenticator: oauth?'
)

@property
def type(self):
return 'snowflake'

def _connection_keys(self):
return ('account', 'user', 'database', 'schema', 'warehouse', 'role',
'client_session_keep_alive')
return (
'account', 'user', 'database', 'schema', 'warehouse', 'role',
'client_session_keep_alive'
)

def auth_args(self):
# Pull all of the optional authentication args for the connector,
Expand All @@ -47,10 +70,63 @@ def auth_args(self):
if self.authenticator:
result['authenticator'] = self.authenticator
if self.authenticator == 'oauth':
result['token'] = self.token
token = self.token
# if we have a client ID/client secret, the token is a refresh
# token, not an access token
if self.oauth_client_id and self.oauth_client_secret:
token = self._get_access_token()
elif self.oauth_client_id:
warn_or_error(
'Invalid profile: got an oauth_client_id, but not an '
'oauth_client_secret!'
)
elif self.oauth_client_secret:
warn_or_error(
'Invalid profile: got an oauth_client_secret, but not '
'an oauth_client_id!'
)

result['token'] = token
result['private_key'] = self._get_private_key()
return result

def _get_access_token(self) -> str:
if self.authenticator != 'oauth':
raise InternalException('Can only get access tokens for oauth')
missing = any(
x is None for x in
(self.oauth_client_id, self.oauth_client_secret, self.token)
)
if missing:
raise InternalException(
'need a client ID a client secret, and a refresh token to get '
'an access token'
)
# should the full url be a config item?
token_url = _TOKEN_REQUEST_URL.format(self.account)
# I think this is only used to redirect on success, which we ignore
# (it does not have to match the integration's settings in snowflake)
redirect_uri = 'http://localhost:9999'
data = {
'grant_type': 'refresh_token',
'refresh_token': self.token,
'redirect_uri': redirect_uri
}

auth = base64.b64encode(
f'{self.oauth_client_id}:{self.oauth_client_secret}'
.encode('ascii')
).decode('ascii')
headers = {
'Authorization': f'Basic {auth}',
'Content-type': 'application/x-www-form-urlencoded;charset=utf-8'
}
result = requests.post(token_url, headers=headers, data=data)
result_json = result.json()
if 'access_token' not in result_json:
raise DatabaseException(f'Did not get a token: {result_json}')
return result_json['access_token']

def _get_private_key(self):
"""Get Snowflake private key by path or None."""
if not self.private_key_path or self.private_key_passphrase is None:
Expand Down Expand Up @@ -84,25 +160,25 @@ def exception_handler(self, sql):
logger.debug("got empty sql statement, moving on")
elif 'This session does not have a current database' in msg:
self.release()
raise dbt.exceptions.FailedToConnectException(
raise FailedToConnectException(
('{}\n\nThis error sometimes occurs when invalid '
'credentials are provided, or when your default role '
'does not have access to use the specified database. '
'Please double check your profile and try again.')
.format(msg))
else:
self.release()
raise dbt.exceptions.DatabaseException(msg)
raise DatabaseException(msg)
except Exception as e:
logger.debug("Error running SQL: {}", sql)
logger.debug("Rolling back transaction.")
self.release()
if isinstance(e, dbt.exceptions.RuntimeException):
if isinstance(e, RuntimeException):
# during a sql query, an internal to dbt exception was raised.
# this sounds a lot like a signal handler and probably has
# useful information, so raise it without modification.
raise
raise dbt.exceptions.RuntimeException(str(e)) from e
raise RuntimeException(str(e)) from e

@classmethod
def open(cls, connection):
Expand Down Expand Up @@ -136,7 +212,7 @@ def open(cls, connection):
connection.handle = None
connection.state = 'fail'

raise dbt.exceptions.FailedToConnectException(str(e))
raise FailedToConnectException(str(e))

def cancel(self, connection):
handle = connection.handle
Expand Down Expand Up @@ -228,7 +304,7 @@ def add_query(self, sql, auto_begin=True,
else:
conn_name = conn.name

raise dbt.exceptions.RuntimeException(
raise RuntimeException(
"Tried to run an empty query on model '{}'. If you are "
"conditionally running\nsql, eg. in a model hook, make "
"sure your `else` clause contains valid sql!\n\n"
Expand Down

0 comments on commit 6e8d4a6

Please sign in to comment.