Skip to content

Commit

Permalink
Add AzureCliCredential to login with CLI
Browse files Browse the repository at this point in the history
See [this closed PR](microsoft/dbt-synapse#35)
for the full commit history
  • Loading branch information
JCZuurmond committed Jan 6, 2021
1 parent ef4f8da commit 339debb
Showing 1 changed file with 147 additions and 66 deletions.
213 changes: 147 additions & 66 deletions dbt/adapters/sqlserver/connections.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,21 @@
from contextlib import contextmanager

import pyodbc
import os
import time
import struct
import time
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import chain, repeat
from typing import Callable, Mapping, Optional

import dbt.exceptions
import pyodbc
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from azure.identity import DefaultAzureCredential

from dbt.logger import GLOBAL_LOGGER as logger

from dataclasses import dataclass
from typing import Optional
from azure.core.credentials import AccessToken
from azure.identity import AzureCliCredential, DefaultAzureCredential


def create_token(tenant_id, client_id, client_secret):
# bc DefaultAzureCredential will look in env variables
os.environ["AZURE_TENANT_ID"] = tenant_id
os.environ["AZURE_CLIENT_ID"] = client_id
os.environ["AZURE_CLIENT_SECRET"] = client_secret

token = DefaultAzureCredential().get_token("https://database.windows.net//.default")
# convert to byte string interspersed with the 1-byte
# TODO decide which is cleaner?
# exptoken=b''.join([bytes({i})+bytes(1) for i in bytes(token.token, "UTF-8")])
exptoken = bytes(1).join([bytes(i, "UTF-8") for i in token.token]) + bytes(1)
# make c object with bytestring length prefix
tokenstruct = struct.pack("=i", len(exptoken)) + exptoken

return tokenstruct
AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default"


@dataclass
Expand All @@ -49,8 +34,8 @@ class SQLServerCredentials(Credentials):
# "sql", "ActiveDirectoryPassword" or "ActiveDirectoryInteractive", or
# "ServicePrincipal"
authentication: Optional[str] = "sql"
encrypt: Optional[bool] = False
trust_cert: Optional[bool] = False
encrypt: Optional[bool] = True
trust_cert: Optional[bool] = True

_ALIASES = {
"user": "UID",
Expand All @@ -74,7 +59,7 @@ def _connection_keys(self):
# raise NotImplementedError
if self.windows_login is True:
self.authentication = "Windows Login"

return (
"server",
"database",
Expand All @@ -84,10 +69,101 @@ def _connection_keys(self):
"client_id",
"authentication",
"encrypt",
"trust_cert"
"trust_cert",
)


def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes:
"""
Convert bytes to a Microsoft windows byte string.
Parameters
----------
value : bytes
The bytes.
Returns
-------
out : bytes
The Microsoft byte string.
"""
encoded_bytes = bytes(chain.from_iterable(zip(value, repeat(0))))
return struct.pack("<i", len(encoded_bytes)) + encoded_bytes


def convert_access_token_to_mswindows_byte_string(token: AccessToken) -> bytes:
"""
Convert an access token to a Microsoft windows byte string.
Parameters
----------
token : AccessToken
The token.
Returns
-------
out : bytes
The Microsoft byte string.
"""
value = bytes(token.token, "UTF-8")
return convert_bytes_to_mswindows_byte_string(value)


def get_cli_access_token(credentials: SQLServerCredentials) -> AccessToken:
"""
Get an Azure access token using the CLI credentials
First login with:
```bash
az login
```
Parameters
----------
credentials: SQLServerConnectionManager
The credentials.
Returns
-------
out : AccessToken
Access token.
"""
_ = credentials
token = AzureCliCredential().get_token(AZURE_CREDENTIAL_SCOPE)
return token


def get_sp_access_token(credentials: SQLServerCredentials) -> AccessToken:
"""
Get an Azure access token using the SP credentials.
Parameters
----------
credentials : SQLServerCredentials
Credentials.
Returns
-------
out : AccessToken
The access token.
"""
# bc DefaultAzureCredential will look in env variables
os.environ["AZURE_TENANT_ID"] = credentials.tenant_id
os.environ["AZURE_CLIENT_ID"] = credentials.client_id
os.environ["AZURE_CLIENT_SECRET"] = credentials.client_secret

token = DefaultAzureCredential().get_token(AZURE_CREDENTIAL_SCOPE)
return token


AZURE_AUTH_FUNCTION_TYPE = Callable[[SQLServerCredentials], AccessToken]
AZURE_AUTH_FUNCTIONS: Mapping[str, AZURE_AUTH_FUNCTION_TYPE] = {
"ServicePrincipal": get_sp_access_token,
"CLI": get_cli_access_token,
}


class SQLServerConnectionManager(SQLConnectionManager):
TYPE = "sqlserver"
TOKEN = None
Expand Down Expand Up @@ -135,8 +211,9 @@ def open(cls, connection):
con_str.append(f"DRIVER={{{credentials.driver}}}")

if "\\" in credentials.host:
# if there is a backslash \ in the host name the host is a sql-server named instance
# in this case then port number has to be omitted
# if there is a backslash \ in the host name the host is a
# sql-server named instance in this case then port number has
# to be omitted
con_str.append(f"SERVER={credentials.host}")
else:
con_str.append(f"SERVER={credentials.host},{credentials.port}")
Expand All @@ -159,55 +236,53 @@ def open(cls, connection):
elif type_auth == "ActiveDirectoryMsi":
raise ValueError("ActiveDirectoryMsi is not supported yet")

elif type_auth == "ServicePrincipal":
app_id = getattr(credentials, "AppId", None)
app_secret = getattr(credentials, "AppSecret", None)

elif getattr(credentials, "windows_login", False):
con_str.append(f"trusted_connection=yes")
con_str.append("trusted_connection=yes")
elif type_auth == "sql":
#con_str.append("Authentication=SqlPassword")
con_str.append("Authentication=SqlPassword")
con_str.append(f"UID={{{credentials.UID}}}")
con_str.append(f"PWD={{{credentials.PWD}}}")

# still confused whether to use "Yes", "yes", "True", or "true"
# to learn more visit
# https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15
if getattr(credentials, "encrypt", False) is True:
con_str.append(f"Encrypt=Yes")
if getattr(credentials, "trust_cert", False) is True:
con_str.append(f"TrustServerCertificate=Yes")
if not getattr(credentials, "encrypt", False):
con_str.append("Encrypt=yes")
if not getattr(credentials, "trust_cert", False):
con_str.append("TrustServerCertificate=yes")

con_str_concat = ";".join(con_str)

con_str_concat = ';'.join(con_str)

index = []
for i, elem in enumerate(con_str):
if 'pwd=' in elem.lower():
if "pwd=" in elem.lower():
index.append(i)

if len(index) !=0 :
con_str[index[0]]="PWD=***"

con_str_display = ';'.join(con_str)

logger.debug(f'Using connection string: {con_str_display}')
if len(index) != 0:
con_str[index[0]] = "PWD=***"

if type_auth != "ServicePrincipal":
handle = pyodbc.connect(con_str_concat, autocommit=True)
con_str_display = ";".join(con_str)

elif type_auth == "ServicePrincipal":
logger.debug(f"Using connection string: {con_str_display}")

# create token if it does not exist
if type_auth in AZURE_AUTH_FUNCTIONS.keys():
if cls.TOKEN is None:
tenant_id = getattr(credentials, "tenant_id", None)
client_id = getattr(credentials, "client_id", None)
client_secret = getattr(credentials, "client_secret", None)
azure_auth_function = AZURE_AUTH_FUNCTIONS[type_auth]
token = azure_auth_function(credentials)
cls.TOKEN = convert_access_token_to_mswindows_byte_string(
token
)

cls.TOKEN = create_token(tenant_id, client_id, client_secret)
# Source:
# https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory?view=sql-server-ver15#authenticating-with-an-access-token
SQL_COPT_SS_ACCESS_TOKEN = 1256

handle = pyodbc.connect(
con_str_concat, attrs_before={1256: cls.TOKEN}, autocommit=True
)
attrs_before = {SQL_COPT_SS_ACCESS_TOKEN: cls.TOKEN}
else:
attrs_before = {}

handle = pyodbc.connect(
con_str_concat,
attrs_before=attrs_before,
autocommit=True,
)

connection.state = "open"
connection.handle = handle
Expand Down Expand Up @@ -235,18 +310,24 @@ def add_commit_query(self):
# return self.add_query('COMMIT TRANSACTION', auto_begin=False)
pass

def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
def add_query(
self, sql, auto_begin=True, bindings=None, abridge_sql_log=False
):

connection = self.get_thread_connection()

if auto_begin and connection.transaction_open is False:
self.begin()

logger.debug('Using {} connection "{}".'.format(self.TYPE, connection.name))
logger.debug(
'Using {} connection "{}".'.format(self.TYPE, connection.name)
)

with self.exception_handler(sql):
if abridge_sql_log:
logger.debug("On {}: {}....".format(connection.name, sql[0:512]))
logger.debug(
"On {}: {}....".format(connection.name, sql[0:512])
)
else:
logger.debug("On {}: {}".format(connection.name, sql))
pre = time.time()
Expand Down

0 comments on commit 339debb

Please sign in to comment.