diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef7ba64c17b..9657efb0a8d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,9 @@
 - Fix an issue where dbt rendered source test args, fix issue where dbt ran an extra compile pass over the wrapped SQL. ([#2114](https://github.com/fishtown-analytics/dbt/issues/2114), [#2150](https://github.com/fishtown-analytics/dbt/pull/2150))
 - Set more upper bounds for jinja2,requests, and idna dependencies, upgrade snowflake-connector-python ([#2147](https://github.com/fishtown-analytics/dbt/issues/2147), [#2151](https://github.com/fishtown-analytics/dbt/pull/2151))
 
+### Under the hood
+- Parallelize filling the cache and listing schemas in each database during startup ([#2127](https://github.com/fishtown-analytics/dbt/issues/2127), [#2157](https://github.com/fishtown-analytics/dbt/pull/2157))
+
 Contributors:
 - [@bubbomb](https://github.com/bubbomb) ([#2080](https://github.com/fishtown-analytics/dbt/pull/2080))
 - [@sonac](https://github.com/sonac) ([#2078](https://github.com/fishtown-analytics/dbt/pull/2078))
diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py
index 3af51b1d24e..76908738328 100644
--- a/core/dbt/adapters/base/impl.py
+++ b/core/dbt/adapters/base/impl.py
@@ -1,6 +1,5 @@
 import abc
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from concurrent.futures import Future  # noqa - we use this for typing only
+from concurrent.futures import as_completed, Future
 from contextlib import contextmanager
 from datetime import datetime
 from typing import (
@@ -27,7 +26,7 @@
 from dbt.exceptions import warn_or_error
 from dbt.node_types import NodeType
 from dbt.logger import GLOBAL_LOGGER as logger
-from dbt.utils import filter_null_values
+from dbt.utils import filter_null_values, executor
 
 from dbt.adapters.base.connections import BaseConnectionManager, Connection
 from dbt.adapters.base.meta import AdapterMeta, available
@@ -358,6 +357,12 @@ def _get_cache_schemas(
         # databases
         return info_schema_name_map
 
+    def _list_relations_get_connection(
+        self, db: BaseRelation, schema: str
+    ) -> List[BaseRelation]:
+        with self.connection_named(f'list_{db.database}_{schema}'):
+            return self.list_relations_without_caching(db, schema)
+
     def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
         """Populate the relations cache for the given schemas. Returns an
         iterable of the schemas populated, as strings.
@@ -365,16 +370,22 @@ def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
         if not dbt.flags.USE_CACHE:
             return
 
-        info_schema_name_map = self._get_cache_schemas(manifest,
-                                                       exec_only=True)
-        for db, schema in info_schema_name_map.search():
-            for relation in self.list_relations_without_caching(db, schema):
-                self.cache.add(relation)
+        schema_map = self._get_cache_schemas(manifest, exec_only=True)
+        with executor(self.config) as tpe:
+            futures: List[Future[List[BaseRelation]]] = [
+                tpe.submit(self._list_relations_get_connection, db, schema)
+                for db, schema in schema_map.search()
+            ]
+            for future in as_completed(futures):
+                # if we can't read the relations we need to just raise anyway,
+                # so just call future.result() and let that raise on failure
+                for relation in future.result():
+                    self.cache.add(relation)
 
         # it's possible that there were no relations in some schemas. We want
         # to insert the schemas we query into the cache's `.schemas` attribute
         # so we can check it later
-        self.cache.update_schemas(info_schema_name_map.schemas_searched())
+        self.cache.update_schemas(schema_map.schemas_searched())
 
     def set_relations_cache(
         self, manifest: Manifest, clear: bool = False
@@ -1047,13 +1058,11 @@ def _get_one_catalog(
     def get_catalog(
         self, manifest: Manifest
     ) -> Tuple[agate.Table, List[Exception]]:
-        # snowflake is super slow. split it out into the specified threads
-        num_threads = self.config.threads
         schema_map = self._get_cache_schemas(manifest)
-        with ThreadPoolExecutor(max_workers=num_threads) as executor:
-            futures = [
-                executor.submit(self._get_one_catalog, info, schemas, manifest)
+        with executor(self.config) as tpe:
+            futures: List[Future[agate.Table]] = [
+                tpe.submit(self._get_one_catalog, info, schemas, manifest)
                 for info, schemas in schema_map.items()
                 if len(schemas) > 0
             ]
         catalogs, exceptions = catch_as_completed(futures)
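The cache fill above is the standard concurrent.futures fan-out: submit one job per (database, schema) pair, then drain the futures with as_completed() so the first worker exception is re-raised promptly via future.result(). A minimal, runnable sketch of that pattern follows; fetch_relations and the pair list are illustrative stand-ins, not dbt APIs:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def fetch_relations(db, schema):
        # stand-in for adapter.list_relations_without_caching(db, schema)
        return ['{}.{}.table_{}'.format(db, schema, i) for i in range(3)]

    pairs = [('analytics', 'staging'), ('analytics', 'marts')]
    cache = set()

    with ThreadPoolExecutor(max_workers=4) as tpe:
        futures = [tpe.submit(fetch_relations, db, schema) for db, schema in pairs]
        for future in as_completed(futures):
            # result() re-raises any exception from the worker thread
            cache.update(future.result())

    print(sorted(cache))
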
diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py
index 7abf5fba799..a7b4b06aefe 100644
--- a/core/dbt/task/runnable.py
+++ b/core/dbt/task/runnable.py
@@ -1,8 +1,9 @@
 import os
 import time
+from concurrent.futures import as_completed
 from datetime import datetime
 from multiprocessing.dummy import Pool as ThreadPool
-from typing import Optional, Dict, List, Set, Tuple
+from typing import Optional, Dict, List, Set, Tuple, Iterable
 
 from dbt.task.base import ConfiguredTask
 from dbt.adapters.factory import get_adapter
@@ -374,7 +375,9 @@ def interpret_results(self, results):
         failures = [r for r in results if r.error or r.fail]
         return len(failures) == 0
 
-    def get_model_schemas(self, selected_uids):
+    def get_model_schemas(
+        self, selected_uids: Iterable[str]
+    ) -> Set[Tuple[str, str]]:
         if self.manifest is None:
             raise InternalException('manifest was None in get_model_schemas')
 
@@ -387,19 +390,46 @@ def get_model_schemas(self, selected_uids):
 
         return schemas
 
-    def create_schemas(self, adapter, selected_uids):
+    def create_schemas(self, adapter, selected_uids: Iterable[str]):
         required_schemas = self.get_model_schemas(selected_uids)
         required_databases = set(db for db, _ in required_schemas)
 
         existing_schemas_lowered: Set[Tuple[str, str]] = set()
-        for db in required_databases:
-            existing_schemas_lowered.update(
-                (db.lower(), s.lower()) for s in adapter.list_schemas(db))
 
-        for db, schema in required_schemas:
-            if (db.lower(), schema.lower()) not in existing_schemas_lowered:
+        def list_schemas(db: str) -> List[Tuple[str, str]]:
+            with adapter.connection_named(f'list_{db}'):
+                return [
+                    (db.lower(), s.lower())
+                    for s in adapter.list_schemas(db)
+                ]
+
+        def create_schema(db: str, schema: str) -> None:
+            with adapter.connection_named(f'create_{db}_{schema}'):
                 adapter.create_schema(db, schema)
 
+        list_futures = []
+        create_futures = []
+
+        with dbt.utils.executor(self.config) as tpe:
+            list_futures = [
+                tpe.submit(list_schemas, db) for db in required_databases
+            ]
+
+            for ls_future in as_completed(list_futures):
+                existing_schemas_lowered.update(ls_future.result())
+
+            for db, schema in required_schemas:
+                db_schema = (db.lower(), schema.lower())
+                if db_schema not in existing_schemas_lowered:
+                    existing_schemas_lowered.add(db_schema)
+                    create_futures.append(
+                        tpe.submit(create_schema, db, schema)
+                    )
+
+            for create_future in as_completed(create_futures):
+                # trigger/re-raise any exceptions while creating schemas
+                create_future.result()
+
     def get_result(self, results, elapsed_time, generated_at):
         return ExecutionResult(
             results=results,
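One subtlety in create_schemas above: each db_schema key is added to existing_schemas_lowered before its create future is submitted, so if required_schemas yields the same pair twice (or the same schema under different casing), only one create-schema job is queued. A small runnable sketch of that guard, with a stand-in create function in place of the adapter call:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    required = [('DB', 'Staging'), ('db', 'staging'), ('db', 'marts')]
    existing = {('db', 'marts')}  # pretend list_schemas() returned this
    created = []

    def create_schema(db, schema):
        # stand-in for adapter.create_schema(db, schema)
        created.append((db.lower(), schema.lower()))

    with ThreadPoolExecutor(max_workers=4) as tpe:
        futures = []
        for db, schema in required:
            key = (db.lower(), schema.lower())
            if key not in existing:
                existing.add(key)  # blocks the duplicate ('db', 'staging') submit
                futures.append(tpe.submit(create_schema, db, schema))
        for future in as_completed(futures):
            future.result()  # re-raise any worker exception

    print(created)  # [('db', 'staging')] - one create despite two spellings
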
diff --git a/core/dbt/utils.py b/core/dbt/utils.py
index 48e562338cf..a38a517b870 100644
--- a/core/dbt/utils.py
+++ b/core/dbt/utils.py
@@ -1,4 +1,5 @@
 import collections
+import concurrent.futures
 import copy
 import datetime
 import decimal
@@ -8,6 +9,7 @@
 import json
 import os
 from enum import Enum
+from typing_extensions import Protocol
 from typing import (
     Tuple, Type, Any, Optional, TypeVar, Dict, Union, Callable
 )
@@ -489,3 +491,48 @@ def format_bytes(num_bytes):
         num_bytes /= 1024.0
 
     return "> 1024 TB"
+
+
+# a little concurrent.futures.Executor for single-threaded mode
+class SingleThreadedExecutor(concurrent.futures.Executor):
+    def submit(*args, **kwargs):
+        # this basic pattern comes from concurrent.futures.Executor itself,
+        # but without handling the `fn=` form.
+        if len(args) >= 2:
+            self, fn, *args = args
+        elif not args:
+            raise TypeError(
+                "descriptor 'submit' of 'SingleThreadedExecutor' object needs "
+                "an argument"
+            )
+        else:
+            raise TypeError(
+                'submit expected at least 1 positional argument, '
+                'got %d' % (len(args) - 1)
+            )
+        fut = concurrent.futures.Future()
+        try:
+            result = fn(*args, **kwargs)
+        except Exception as exc:
+            fut.set_exception(exc)
+        else:
+            fut.set_result(result)
+        return fut
+
+
+class ThreadedArgs(Protocol):
+    single_threaded: bool
+
+
+class HasThreadingConfig(Protocol):
+    args: ThreadedArgs
+    threads: Optional[int]
+
+
+def executor(config: HasThreadingConfig) -> concurrent.futures.Executor:
+    if config.args.single_threaded:
+        return SingleThreadedExecutor()
+    else:
+        return concurrent.futures.ThreadPoolExecutor(
+            max_workers=config.threads
+        )
diff --git a/test/integration/001_simple_copy_test/test_simple_copy.py b/test/integration/001_simple_copy_test/test_simple_copy.py
index d50d54d11b0..92149c18bce 100644
--- a/test/integration/001_simple_copy_test/test_simple_copy.py
+++ b/test/integration/001_simple_copy_test/test_simple_copy.py
@@ -71,7 +71,6 @@ def test__postgres__simple_copy_with_materialized_views(self):
                 select * from {schema}.unrelated_materialized_view
             )
         '''.format(schema=self.unique_schema()))
-
         results = self.run_dbt(["seed"])
         self.assertEqual(len(results), 1)
         results = self.run_dbt()
diff --git a/test/unit/utils.py b/test/unit/utils.py
index e4ed58bdeee..affb6c375f6 100644
--- a/test/unit/utils.py
+++ b/test/unit/utils.py
@@ -25,6 +25,7 @@ def normalize(path):
 
 class Obj:
     which = 'blah'
+    single_threaded = False
 
 
 def mock_connection(name):
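For anyone trying the new helper locally: dbt.utils.executor only needs an object satisfying HasThreadingConfig, i.e. an args attribute with single_threaded and a threads count. A hedged usage sketch with hypothetical stand-in config classes (not real dbt config types), assuming a checkout with this patch applied:

    from dataclasses import dataclass
    from typing import Optional

    from dbt.utils import executor  # added in this diff

    @dataclass
    class FakeArgs:
        single_threaded: bool

    @dataclass
    class FakeConfig:
        args: FakeArgs
        threads: Optional[int]

    # threaded mode: a regular ThreadPoolExecutor with max_workers=4
    with executor(FakeConfig(FakeArgs(single_threaded=False), threads=4)) as tpe:
        print(tpe.submit(sum, [1, 2, 3]).result())  # 6

    # single-threaded mode: SingleThreadedExecutor runs the function inline
    # at submit() time and returns an already-resolved Future
    with executor(FakeConfig(FakeArgs(single_threaded=True), threads=None)) as tpe:
        print(tpe.submit(sum, [1, 2, 3]).result())  # 6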