diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index 6c5f751dab5..ea2f39c97f7 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -11,7 +11,7 @@ import dbt.exceptions import dbt.flags from dbt.contracts.connection import ( - Connection, Identifier, ConnectionState, AdapterRequiredConfig + Connection, Identifier, ConnectionState, AdapterRequiredConfig, LazyHandle ) from dbt.adapters.base.query_headers import ( QueryStringSetter, MacroQueryStringSetter, @@ -61,6 +61,14 @@ def get_thread_connection(self) -> Connection: ) return self.thread_connections[key] + def set_thread_connection(self, conn): + key = self.get_thread_identifier() + if key in self.thread_connections: + raise dbt.exceptions.InternalException( + 'In set_thread_connection, existing connection exists for {}' + ) + self.thread_connections[key] = conn + def get_if_exists(self) -> Optional[Connection]: key = self.get_thread_identifier() with self.lock: @@ -109,8 +117,6 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection: conn_name = name conn = self.get_if_exists() - thread_id_key = self.get_thread_identifier() - if conn is None: conn = Connection( type=Identifier(self.TYPE), @@ -120,7 +126,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection: handle=None, credentials=self.profile.credentials ) - self.thread_connections[thread_id_key] = conn + self.set_thread_connection(conn) if conn.name == conn_name and conn.state == 'open': return conn @@ -138,7 +144,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection: 'Opening a new connection, currently in state {}' .format(conn.state) ) - self.open(conn) + conn.handle = LazyHandle(type(self)) conn.name = conn_name return conn diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 15b534cb5aa..837d32ed766 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -234,11 +234,11 @@ def nice_connection_name(self): @contextmanager def connection_named( self, name: str, node: Optional[CompileResultNode] = None - ): + ) -> Iterator[None]: try: self.connections.query_header.set(name, node) - conn = self.acquire_connection(name) - yield conn + self.acquire_connection(name) + yield finally: self.release_connection() self.connections.query_header.reset() @@ -246,9 +246,9 @@ def connection_named( @contextmanager def connection_for( self, node: CompileResultNode - ) -> Iterator[Connection]: - with self.connection_named(node.unique_id, node) as conn: - yield conn + ) -> Iterator[None]: + with self.connection_named(node.unique_id, node): + yield @available.parse(lambda *a, **k: ('', empty_table())) def execute( diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index dc1a2423309..4628c5b77a1 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -2,7 +2,7 @@ import itertools from dataclasses import dataclass, field from typing import ( - Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List + Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Type ) from typing_extensions import Protocol @@ -12,6 +12,7 @@ ) from dbt.contracts.util import Replaceable +from dbt.exceptions import InternalException from dbt.utils import translate_aliases @@ -26,6 +27,23 @@ class ConnectionState(StrEnum): FAIL = 'fail' +class ConnectionOpenerProtocol(Protocol): + @classmethod + def open(cls, connection: 'Connection') -> Any: + raise NotImplementedError(f'open() not implemented for {cls.__name__}') + + +class LazyHandle: + """Opener must be a callable that takes a Connection object and opens the + connection, updating the handle on the Connection. + """ + def __init__(self, opener: Type[ConnectionOpenerProtocol]): + self.opener = opener + + def resolve(self, connection: 'Connection') -> Any: + return self.opener.open(connection) + + @dataclass(init=False) class Connection(ExtensibleJsonSchemaMixin, Replaceable): type: Identifier @@ -62,6 +80,15 @@ def credentials(self, value): @property def handle(self): + if isinstance(self._handle, LazyHandle): + try: + # this will actually change 'self._handle'. + self._handle.resolve(self) + except RecursionError as exc: + raise InternalException( + "A connection's open() method attempted to read the " + "handle value" + ) from exc return self._handle @handle.setter diff --git a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py index 7d185cb0458..eaa7047c3a7 100644 --- a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py +++ b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py @@ -41,11 +41,12 @@ def project_config(self): def run_select_and_check(self, rel, sql): connection_name = '__test_{}'.format(id(threading.current_thread())) try: - with self._secret_adapter.connection_named(connection_name) as conn: + with self._secret_adapter.connection_named(connection_name): + conn = self._secret_adapter.connections.get_thread_connection() res = self.run_sql_common(self.transform_sql(sql), 'one', conn) # The result is the output of f_sleep(), which is True - if res[0] == True: + if res[0]: self.query_state[rel] = 'good' else: self.query_state[rel] = 'bad' diff --git a/test/integration/base.py b/test/integration/base.py index 11308cd366d..3ebf503b4cb 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -753,7 +753,8 @@ def get_connection(self, name=None): if name is None: name = '__test' with patch.object(common, 'get_adapter', return_value=self.adapter): - with self.adapter.connection_named(name) as conn: + with self.adapter.connection_named(name): + conn = self.adapter.connections.get_thread_connection() yield conn def get_relation_columns(self, relation): diff --git a/test/rpc/util.py b/test/rpc/util.py index 856d8261168..999f2beecfa 100644 --- a/test/rpc/util.py +++ b/test/rpc/util.py @@ -519,7 +519,8 @@ def __init__(self, profiles_dir, which='run-operation', kwargs={}): def execute(adapter, sql): - with adapter.connection_named('rpc-tests') as conn: + with adapter.connection_named('rpc-tests'): + conn = adapter.connections.get_thread_connection() with conn.handle.cursor() as cursor: try: cursor.execute(sql) diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index c78755c2027..06d9545993a 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -100,6 +100,8 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection): except BaseException as e: raise + mock_open_connection.assert_not_called() + connection.handle mock_open_connection.assert_called_once() @patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn()) @@ -115,6 +117,8 @@ def test_acquire_connection_service_account_validations(self, mock_open_connecti except BaseException as e: raise + mock_open_connection.assert_not_called() + connection.handle mock_open_connection.assert_called_once() @patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn()) @@ -128,9 +132,8 @@ def test_acquire_connection_priority(self, mock_open_connection): except dbt.exceptions.ValidationException as e: self.fail('got ValidationException: {}'.format(str(e))) - except BaseException as e: - raise - + mock_open_connection.assert_not_called() + connection.handle mock_open_connection.assert_called_once() def test_cancel_open_connections_empty(self): @@ -158,8 +161,11 @@ def test_location_value(self, mock_bq, mock_auth_default): mock_auth_default.return_value = (creds, MagicMock()) adapter = self.get_adapter('loc') - adapter.acquire_connection('dummy') + connection = adapter.acquire_connection('dummy') mock_client = mock_bq.Client + + mock_client.assert_not_called() + connection.handle mock_client.assert_called_once_with('dbt-unit-000000', creds, location='Luna Station') diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index 621da633f56..5b5c6795dce 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -60,12 +60,17 @@ def test_acquire_connection_validations(self, psycopg2): self.fail('acquiring connection failed with unknown exception: {}' .format(str(e))) self.assertEqual(connection.type, 'postgres') + + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once() @mock.patch('dbt.adapters.postgres.connections.psycopg2') def test_acquire_connection(self, psycopg2): connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle self.assertEqual(connection.state, 'open') self.assertNotEqual(connection.handle, None) psycopg2.connect.assert_called_once() @@ -101,6 +106,8 @@ def test_cancel_open_connections_single(self): def test_default_keepalive(self, psycopg2): connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once_with( dbname='postgres', user='root', @@ -114,6 +121,8 @@ def test_changed_keepalive(self, psycopg2): self.config.credentials = self.config.credentials.replace(keepalives_idle=256) connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once_with( dbname='postgres', user='root', @@ -128,6 +137,8 @@ def test_search_path(self, psycopg2): self.config.credentials = self.config.credentials.replace(search_path="test") connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once_with( dbname='postgres', user='root', @@ -142,6 +153,8 @@ def test_schema_with_space(self, psycopg2): self.config.credentials = self.config.credentials.replace(search_path="test test") connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once_with( dbname='postgres', user='root', @@ -156,6 +169,8 @@ def test_set_zero_keepalive(self, psycopg2): self.config.credentials = self.config.credentials.replace(keepalives_idle=0) connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once_with( dbname='postgres', user='root', diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py index be54a82e475..4d509c631c9 100644 --- a/test/unit/test_redshift_adapter.py +++ b/test/unit/test_redshift_adapter.py @@ -153,6 +153,8 @@ def test_cancel_open_connections_single(self): def test_default_keepalive(self, psycopg2): connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once_with( dbname='redshift', user='root', @@ -168,6 +170,8 @@ def test_changed_keepalive(self, psycopg2): self.config.credentials = self.config.credentials.replace(keepalives_idle=256) connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once_with( dbname='redshift', user='root', @@ -182,6 +186,8 @@ def test_search_path(self, psycopg2): self.config.credentials = self.config.credentials.replace(search_path="test") connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once_with( dbname='redshift', user='root', @@ -197,6 +203,8 @@ def test_search_path_with_space(self, psycopg2): self.config.credentials = self.config.credentials.replace(search_path="test test") connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once_with( dbname='redshift', user='root', @@ -212,6 +220,8 @@ def test_set_zero_keepalive(self, psycopg2): self.config.credentials = self.config.credentials.replace(keepalives_idle=0) connection = self.adapter.acquire_connection('dummy') + psycopg2.connect.assert_not_called() + connection.handle psycopg2.connect.assert_called_once_with( dbname='redshift', user='root', diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index ff53a6a704d..a54d6921289 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -230,7 +230,10 @@ def test_cancel_open_connections_single(self): add_query.assert_called_once_with('select system$abort_session(42)') def test_client_session_keep_alive_false_by_default(self): - self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + + self.snowflake.assert_not_called() + conn.handle self.snowflake.assert_has_calls([ mock.call( account='test_account', autocommit=False, @@ -243,8 +246,10 @@ def test_client_session_keep_alive_true(self): self.config.credentials = self.config.credentials.replace( client_session_keep_alive=True) self.adapter = SnowflakeAdapter(self.config) - self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + self.snowflake.assert_not_called() + conn.handle self.snowflake.assert_has_calls([ mock.call( account='test_account', autocommit=False, @@ -258,8 +263,10 @@ def test_user_pass_authentication(self): password='test_password', ) self.adapter = SnowflakeAdapter(self.config) - self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + self.snowflake.assert_not_called() + conn.handle self.snowflake.assert_has_calls([ mock.call( account='test_account', autocommit=False, @@ -275,8 +282,10 @@ def test_authenticator_user_pass_authentication(self): authenticator='test_sso_url', ) self.adapter = SnowflakeAdapter(self.config) - self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + self.snowflake.assert_not_called() + conn.handle self.snowflake.assert_has_calls([ mock.call( account='test_account', autocommit=False, @@ -292,8 +301,10 @@ def test_authenticator_externalbrowser_authentication(self): authenticator='externalbrowser' ) self.adapter = SnowflakeAdapter(self.config) - self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + self.snowflake.assert_not_called() + conn.handle self.snowflake.assert_has_calls([ mock.call( account='test_account', autocommit=False, @@ -311,8 +322,10 @@ def test_authenticator_private_key_authentication(self, mock_get_private_key): ) self.adapter = SnowflakeAdapter(self.config) - self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + self.snowflake.assert_not_called() + conn.handle self.snowflake.assert_has_calls([ mock.call( account='test_account', autocommit=False,