Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: close connections after use #2650

Merged
merged 6 commits into from
Jul 28, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
### Fixes
- fast-fail option with adapters that don't support cancelling queries will now pass through the original error messages ([#2644](https://github.com/fishtown-analytics/dbt/issues/2644), [#2646](https://github.com/fishtown-analytics/dbt/pull/2646))
- `dbt clean` no longer requires a profile ([#2620](https://github.com/fishtown-analytics/dbt/issues/2620), [#2649](https://github.com/fishtown-analytics/dbt/pull/2649))
- Close all connections so Snowflake's keepalive thread will exit. ([#2645](https://github.com/fishtown-analytics/dbt/issues/2645), [#2650](https://github.com/fishtown-analytics/dbt/pull/2650))

Contributors:
- [@joshpeng-quibi](https://github.com/joshpeng-quibi) ([#2646](https://github.com/fishtown-analytics/dbt/pull/2646))
Expand Down
26 changes: 15 additions & 11 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ def clear_transaction(self) -> None:
self.begin()
self.commit()

def rollback_if_open(self) -> None:
    """Roll back the current thread's connection, but only if one exists,
    it has a live handle, and a transaction is actually open.

    Safe to call unconditionally from error paths: when any of those
    conditions fails, this is a no-op.
    """
    connection = self.get_if_exists()
    if connection is None:
        return
    if connection.handle and connection.transaction_open:
        self._rollback(connection)

@abc.abstractmethod
def exception_handler(self, sql: str) -> ContextManager:
"""Create a context manager that handles exceptions caused by database
Expand Down Expand Up @@ -176,11 +181,9 @@ def release(self) -> None:
return

try:
if conn.state == 'open':
if conn.transaction_open is True:
self._rollback(conn)
else:
self.close(conn)
# always close the connection. close() calls _rollback() if there
# is an open transaction
self.close(conn)
except Exception:
# if rollback or close failed, remove our busted connection
self.clear_thread_connection()
Expand Down Expand Up @@ -230,11 +233,10 @@ def _close_handle(cls, connection: Connection) -> None:
"""Perform the actual close operation."""
# On windows, sometimes connection handles don't have a close() attr.
if hasattr(connection.handle, 'close'):
logger.debug('On {}: Close'.format(connection.name))
logger.debug(f'On {connection.name}: Close')
connection.handle.close()
else:
logger.debug('On {}: No close available on handle'
.format(connection.name))
logger.debug(f'On {connection.name}: No close available on handle')

@classmethod
def _rollback(cls, connection: Connection) -> None:
Expand All @@ -247,10 +249,11 @@ def _rollback(cls, connection: Connection) -> None:

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
'Tried to rollback transaction on connection "{}", but '
'it does not have one open!'.format(connection.name))
f'Tried to rollback transaction on connection '
f'"{connection.name}", but it does not have one open!'
)

logger.debug('On {}: ROLLBACK'.format(connection.name))
logger.debug(f'On {connection.name}: ROLLBACK')
cls._rollback_handle(connection)

connection.transaction_open = False
Expand All @@ -268,6 +271,7 @@ def close(cls, connection: Connection) -> Connection:
return connection

if connection.transaction_open and connection.handle:
logger.debug('On {}: ROLLBACK'.format(connection.name))
cls._rollback_handle(connection)
connection.transaction_open = False

Expand Down
72 changes: 38 additions & 34 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,6 @@ def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
# databases
return info_schema_name_map

def _list_relations_get_connection(
self, schema_relation: BaseRelation
) -> List[BaseRelation]:
name = f'list_{schema_relation.database}_{schema_relation.schema}'
with self.connection_named(name):
return self.list_relations_without_caching(schema_relation)

def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
"""Populate the relations cache for the given schemas. Returns an
iterable of the schemas populated, as strings.
Expand All @@ -328,10 +321,16 @@ def _relations_cache_for_schemas(self, manifest: Manifest) -> None:

cache_schemas = self._get_cache_schemas(manifest)
with executor(self.config) as tpe:
futures: List[Future[List[BaseRelation]]] = [
tpe.submit(self._list_relations_get_connection, cache_schema)
for cache_schema in cache_schemas
]
futures: List[Future[List[BaseRelation]]] = []
for cache_schema in cache_schemas:
fut = tpe.submit_connected(
self,
f'list_{cache_schema.database}_{cache_schema.schema}',
self.list_relations_without_caching,
cache_schema
)
futures.append(fut)

for future in as_completed(futures):
# if we can't read the relations we need to just raise anyway,
# so just call future.result() and let that raise on failure
Expand Down Expand Up @@ -1001,24 +1000,18 @@ def _get_one_catalog(
manifest: Manifest,
) -> agate.Table:

name = '.'.join([
str(information_schema.database),
'information_schema'
])

with self.connection_named(name):
kwargs = {
'information_schema': information_schema,
'schemas': schemas
}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
release=True,
# pass in the full manifest so we get any local project
# overrides
manifest=manifest,
)
kwargs = {
'information_schema': information_schema,
'schemas': schemas
}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
release=False,
# pass in the full manifest so we get any local project
# overrides
manifest=manifest,
)

results = self._catalog_filter_table(table, manifest)
return results
Expand All @@ -1029,10 +1022,21 @@ def get_catalog(
schema_map = self._get_catalog_schemas(manifest)

with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = [
tpe.submit(self._get_one_catalog, info, schemas, manifest)
for info, schemas in schema_map.items() if len(schemas) > 0
]
futures: List[Future[agate.Table]] = []
for info, schemas in schema_map.items():
if len(schemas) == 0:
continue
name = '.'.join([
str(info.database),
'information_schema'
])

fut = tpe.submit_connected(
self, name,
self._get_one_catalog, info, schemas, manifest
)
futures.append(fut)

catalogs, exceptions = catch_as_completed(futures)

return catalogs, exceptions
Expand All @@ -1059,7 +1063,7 @@ def calculate_freshness(
table = self.execute_macro(
FRESHNESS_MACRO_NAME,
kwargs=kwargs,
release=True,
release=False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the release arg used anymore and should it stick around if not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not used in core anymore, but it is used in some plugins. Maybe anything using it needs to be changed anyway, I just hate to cause more plugin churn.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, sounds good

Copy link
Contributor Author

@beckjake beckjake Jul 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've played around with this some, and I think we have to remove the release arg.

@jtcohen6 I think this will break any dependent adapters. I am mostly fine with that happening in a patch release since it fixes a pretty crippling bug. Ours all are locked to minor versions so we'll be fine... let me know what you think!

edit: I've decided to deprecate it. I don't think it's a great fix, but I think it's an okay fix. Some adapters that are currently using the release argument may not have to make changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok! I'm on board. This is a tiny breaking change to resolve a much more significant regression. We'll want to check that our plugins (spark + presto) can work with 0.17.2-rc1.

manifest=manifest
)
# now we have a 1-row table of the maximum `loaded_at_field` value and
Expand Down
7 changes: 5 additions & 2 deletions core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import dbt.clients.agate_helper
import dbt.exceptions
from dbt.contracts.connection import Connection
from dbt.contracts.connection import Connection, ConnectionState
from dbt.adapters.base import BaseConnectionManager
from dbt.logger import GLOBAL_LOGGER as logger

Expand Down Expand Up @@ -37,7 +37,10 @@ def cancel_open(self) -> List[str]:

# if the connection failed, the handle will be None so we have
# nothing to cancel.
if connection.handle is not None:
if (
connection.handle is not None and
connection.state == ConnectionState.OPEN
):
self.cancel(connection)
if connection.name is not None:
names.append(connection.name)
Expand Down
48 changes: 27 additions & 21 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,15 @@ def _cancel_connections(self, pool):
dbt.ui.printer.print_timestamped_line(msg, yellow)

else:
for conn_name in adapter.cancel_open_connections():
if self.manifest is not None:
node = self.manifest.nodes.get(conn_name)
if node is not None and node.is_ephemeral_model:
continue
# if we don't have a manifest/don't have a node, print anyway.
dbt.ui.printer.print_cancel_line(conn_name)
with adapter.connection_named('master'):
for conn_name in adapter.cancel_open_connections():
if self.manifest is not None:
node = self.manifest.nodes.get(conn_name)
if node is not None and node.is_ephemeral_model:
continue
# if we don't have a manifest/don't have a node, print
# anyway.
dbt.ui.printer.print_cancel_line(conn_name)

pool.join()

Expand Down Expand Up @@ -457,18 +459,15 @@ def list_schemas(
db_lowercase = dbt.utils.lowercase(db_only.database)
if db_only.database is None:
database_quoted = None
conn_name = 'list_schemas'
else:
database_quoted = str(db_only)
conn_name = f'list_{db_only.database}'

with adapter.connection_named(conn_name):
# we should never create a null schema, so just filter them out
return [
(db_lowercase, s.lower())
for s in adapter.list_schemas(database_quoted)
if s is not None
]
# we should never create a null schema, so just filter them out
return [
(db_lowercase, s.lower())
for s in adapter.list_schemas(database_quoted)
if s is not None
]

def create_schema(relation: BaseRelation) -> None:
db = relation.database or ''
Expand All @@ -480,9 +479,13 @@ def create_schema(relation: BaseRelation) -> None:
create_futures = []

with dbt.utils.executor(self.config) as tpe:
list_futures = [
tpe.submit(list_schemas, db) for db in required_databases
]
for req in required_databases:
if req.database is None:
name = 'list_schemas'
else:
name = f'list_{req.database}'
fut = tpe.submit_connected(adapter, name, list_schemas, req)
list_futures.append(fut)

for ls_future in as_completed(list_futures):
existing_schemas_lowered.update(ls_future.result())
Expand All @@ -499,9 +502,12 @@ def create_schema(relation: BaseRelation) -> None:
db_schema = (db_lower, schema.lower())
if db_schema not in existing_schemas_lowered:
existing_schemas_lowered.add(db_schema)
create_futures.append(
tpe.submit(create_schema, info)

fut = tpe.submit_connected(
adapter, f'create_{info.database or ""}_{info.schema}',
create_schema, info
)
create_futures.append(fut)

for create_future in as_completed(create_futures):
# trigger/re-raise any exceptions while creating schemas
Expand Down
31 changes: 26 additions & 5 deletions core/dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import itertools
import json
import os
from contextlib import contextmanager
from enum import Enum
from typing_extensions import Protocol
from typing import (
Expand Down Expand Up @@ -518,8 +519,16 @@ def format_bytes(num_bytes):
return "> 1024 TB"


class ConnectingExecutor(concurrent.futures.Executor):
    """Executor base class whose tasks run inside a named connection.

    Subclasses supply ``connection_named(adapter, name)`` (a context
    manager) and ``submit``; ``submit_connected`` wraps the submitted
    callable so the connection is acquired and released on the worker
    that actually runs it.
    """

    def submit_connected(self, adapter, conn_name, func, *args, **kwargs):
        """Submit ``func(*args, **kwargs)`` to run within
        ``self.connection_named(adapter, conn_name)`` and return the Future.
        """
        def wrapped(name, fn, *fn_args, **fn_kwargs):
            # entered on the worker, so the connection belongs to the
            # thread that executes the task, not the submitter
            with self.connection_named(adapter, name):
                return fn(*fn_args, **fn_kwargs)

        return self.submit(wrapped, conn_name, func, *args, **kwargs)


# a little concurrent.futures.Executor for single-threaded mode
class SingleThreadedExecutor(concurrent.futures.Executor):
class SingleThreadedExecutor(ConnectingExecutor):
def submit(*args, **kwargs):
# this basic pattern comes from concurrent.futures.Executor itself,
# but without handling the `fn=` form.
Expand All @@ -544,6 +553,20 @@ def submit(*args, **kwargs):
fut.set_result(result)
return fut

@contextmanager
def connection_named(self, adapter, name):
    # Single-threaded mode runs tasks in the calling thread, which
    # already holds its own connection, so there is nothing to
    # acquire or release here — intentionally a no-op.
    yield


class MultiThreadedExecutor(
    ConnectingExecutor,
    concurrent.futures.ThreadPoolExecutor,
):
    """Thread-pool executor whose submitted tasks each run inside a
    connection named via the adapter (see ConnectingExecutor).
    """

    @contextmanager
    def connection_named(self, adapter, name):
        # Delegate to the adapter so the connection gets bound to the
        # worker thread executing the task.
        cm = adapter.connection_named(name)
        with cm:
            yield


class ThreadedArgs(Protocol):
single_threaded: bool
Expand All @@ -554,13 +577,11 @@ class HasThreadingConfig(Protocol):
threads: Optional[int]


def executor(config: HasThreadingConfig) -> ConnectingExecutor:
    """Build the executor used for concurrent adapter work.

    :param config: anything exposing ``args.single_threaded`` and
        ``threads`` (see HasThreadingConfig).
    :returns: a SingleThreadedExecutor when single-threaded mode is
        requested, otherwise a MultiThreadedExecutor sized by the
        configured thread count. Both expose ``submit_connected``.
    """
    if config.args.single_threaded:
        return SingleThreadedExecutor()
    else:
        return MultiThreadedExecutor(max_workers=config.threads)


def fqn_search(
Expand Down
3 changes: 2 additions & 1 deletion plugins/bigquery/dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from dbt.utils import format_bytes
from dbt.clients import agate_helper, gcloud
from dbt.contracts.connection import ConnectionState
from dbt.exceptions import (
FailedToConnectException, RuntimeException, DatabaseException
)
Expand Down Expand Up @@ -111,7 +112,7 @@ def cancel_open(self) -> None:

@classmethod
def close(cls, connection):
    """Mark the connection closed and return it.

    NOTE(review): only the state flag is updated here — presumably the
    underlying handle needs no explicit close for this adapter; confirm
    against the handle's API.
    """
    connection.state = ConnectionState.CLOSED

    return connection

Expand Down
17 changes: 13 additions & 4 deletions plugins/postgres/dbt/adapters/postgres/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def exception_handler(self, sql):
logger.debug('Postgres error: {}'.format(str(e)))

try:
# attempt to release the connection
self.release()
self.rollback_if_open()
except psycopg2.Error:
logger.debug("Failed to release connection!")
pass
Expand All @@ -60,7 +59,7 @@ def exception_handler(self, sql):
except Exception as e:
logger.debug("Error running SQL: {}", sql)
logger.debug("Rolling back transaction.")
self.release()
self.rollback_if_open()
if isinstance(e, dbt.exceptions.RuntimeException):
# during a sql query, an internal to dbt exception was raised.
# this sounds a lot like a signal handler and probably has
Expand Down Expand Up @@ -122,7 +121,17 @@ def open(cls, connection):

def cancel(self, connection):
connection_name = connection.name
pid = connection.handle.get_backend_pid()
try:
pid = connection.handle.get_backend_pid()
except psycopg2.InterfaceError as exc:
# if the connection is already closed, not much to cancel!
if 'already closed' in str(exc):
logger.debug(
f'Connection {connection_name} was already closed'
)
return
# probably bad, re-raise it
raise

sql = "select pg_terminate_backend({})".format(pid)

Expand Down
Loading