From 0f9c4a8e9a04471e71659c73f2c5ea9fa8a87707 Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Thu, 3 Jan 2019 13:43:08 -0700
Subject: [PATCH] Move SQL previously embedded into adapters into macros

Adapters now store an internal manifest that only has the dbt internal projects
Adapters use that manifest if none is provided to execute_macro
The internal manifest is lazy-loaded to avoid recursion issues
Moved declared plugin paths down one level
Connection management changes to accommodate calling macro -> adapter -> macro
Split up precision and scale when describing number columns so agate doesn't eat commas
Manifest building now happens in the RunManager instead of the compiler
Now macros:
  create/drop schema
  get_columns_in_relation
  alter column type
  rename/drop/truncate
  list_schemas/check_schema_exists
  list_relations_without_caching
---
 core/dbt/adapters/base/connections.py | 2 +-
 core/dbt/adapters/base/impl.py | 81 +++++--
 core/dbt/adapters/base/plugin.py | 3 +-
 core/dbt/adapters/factory.py | 3 +-
 core/dbt/adapters/sql/impl.py | 225 ++++++++++--------
 core/dbt/compilation.py | 73 +-----
 core/dbt/context/common.py | 25 +-
 core/dbt/context/parser.py | 7 +
 core/dbt/context/runtime.py | 7 +
 core/dbt/exceptions.py | 4 +
 core/dbt/include/global_project/__init__.py | 4 +-
 .../global_project/macros/adapters/common.sql | 150 +++++++++++-
 core/dbt/loader.py | 121 ++++++++--
 core/dbt/parser/base.py | 4 +-
 core/dbt/runner.py | 85 ++++++-
 core/dbt/schema.py | 15 +-
 core/dbt/task/generate.py | 7 +-
 core/dbt/utils.py | 30 ---
 .../bigquery/dbt/include/bigquery/__init__.py | 2 +-
 .../dbt/include/bigquery/macros/adapters.sql | 19 +-
 .../postgres/dbt/adapters/postgres/impl.py | 123 ++--------
 .../postgres/dbt/include/postgres/__init__.py | 2 +-
 .../dbt/include/postgres/macros/adapters.sql | 64 +++++
 .../dbt/include/postgres/macros/catalog.sql | 1 +
 .../dbt/include/postgres/macros/relations.sql | 4 +-
 .../redshift/dbt/adapters/redshift/impl.py | 69 ------
 .../redshift/dbt/include/redshift/__init__.py | 2 +-
 .../dbt/include/redshift/macros/adapters.sql | 96 ++++++++
 .../snowflake/dbt/adapters/snowflake/impl.py | 88 -------
 .../dbt/include/snowflake/__init__.py | 2 +-
 .../dbt/include/snowflake/macros/adapters.sql | 57 +++++
 .../test_docs_generate.py | 2 +-
 .../test_concurrent_transaction.py | 3 +-
 .../038_caching_test/test_caching.py | 4 +-
 test/integration/base.py | 21 +-
 test/unit/test_bigquery_adapter.py | 6 +-
 test/unit/test_graph.py | 72 +++---
 test/unit/test_postgres_adapter.py | 4 +-
 test/unit/test_schema.py | 6 +-
 test/unit/test_snowflake_adapter.py | 6 +-
 test/unit/utils.py | 9 +
 41 files changed, 896 insertions(+), 612 deletions(-)
 create mode 100644 plugins/postgres/dbt/include/postgres/macros/adapters.sql
diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index 547f365eba9..0f4a49ec211 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -331,7 +331,7 @@ def commit_if_has_connection(self, name): :param str name: The name of the connection to use.
""" - connection = self.get_if_exists(name) + connection = self.in_use.get(name) if connection: self.commit(connection) diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 3929735f1c2..6e15be3522e 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -13,6 +13,7 @@ from dbt.compat import abstractclassmethod, classmethod from dbt.contracts.connection import Connection +from dbt.loader import GraphLoader from dbt.logger import GLOBAL_LOGGER as logger from dbt.schema import Column from dbt.utils import filter_null_values, translate_aliases @@ -21,6 +22,7 @@ from dbt.adapters.base import BaseRelation from dbt.adapters.cache import RelationsCache + GET_CATALOG_MACRO_NAME = 'get_catalog' @@ -68,11 +70,14 @@ def test(row): class BaseAdapter(object): """The BaseAdapter provides an abstract base class for adapters. - Adapters must implement the following methods. Some of these methods can be - safely overridden as a noop, where it makes sense (transactions on - databases that don't support them, for instance). Those methods are marked - with a (passable) in their docstrings. Check docstrings for type - information, etc. + Adapters must implement the following methods and macros. Some of the + methods can be safely overridden as a noop, where it makes sense + (transactions on databases that don't support them, for instance). Those + methods are marked with a (passable) in their docstrings. Check docstrings + for type information, etc. + + To implement a macro, implement "${adapter_type}__${macro_name}". in the + adapter's internal project. Methods: - exception_handler @@ -94,6 +99,9 @@ class BaseAdapter(object): - convert_datetime_type - convert_date_type - convert_time_type + + Macros: + - get_catalog """ requires = {} @@ -106,6 +114,7 @@ def __init__(self, config): self.config = config self.cache = RelationsCache() self.connections = self.ConnectionManager(config) + self._internal_manifest_lazy = None ### # Methods that pass through to the connection manager @@ -159,6 +168,19 @@ def type(cls): """ return cls.ConnectionManager.TYPE + @property + def _internal_manifest(self): + if self._internal_manifest_lazy is None: + manifest = GraphLoader.load_internal(self.config) + self._internal_manifest_lazy = manifest + return self._internal_manifest_lazy + + def check_internal_manifest(self): + """Return the internal manifest (used for executing macros) if it's + been initialized, otherwise return None. + """ + return self._internal_manifest_lazy + ### # Caching methods ### @@ -667,21 +689,38 @@ def convert_agate_type(cls, agate_table, col_idx): ### # Operations involving the manifest ### - def execute_macro(self, manifest, macro_name, project=None, - context_override=None): + def execute_macro(self, macro_name, manifest=None, project=None, + context_override=None, kwargs=None, release=False, + connection_name=None): """Look macro_name up in the manifest and execute its results. - :param Manifest manifest: The manifest to use for generating the base - macro execution context. :param str macro_name: The name of the macro to execute. + :param Optional[Manifest] manifest: The manifest to use for generating + the base macro execution context. If none is provided, use the + internal manifest. :param Optional[str] project: The name of the project to search in, or None for the first match. :param Optional[dict] context_override: An optional dict to update() the macro execution context. 
+ :param Optional[dict] kwargs: An optional dict of keyword args used to + pass to the macro. + :param bool release: If True, release the connection after executing. + :param Optional[str] connection_name: The connection name to use, or + use the macro name. Return an an AttrDict with three attributes: 'table', 'data', and 'status'. 'table' is an agate.Table. """ + if kwargs is None: + kwargs = {} + if context_override is None: + context_override = {} + if connection_name is None: + connection_name = macro_name + + if manifest is None: + manifest = self._internal_manifest + macro = manifest.find_macro_by_name(macro_name, project) if macro is None: raise dbt.exceptions.RuntimeException( @@ -692,15 +731,21 @@ def execute_macro(self, manifest, macro_name, project=None, # This causes a reference cycle, as dbt.context.runtime.generate() # ends up calling get_adapter, so the import has to be here. import dbt.context.runtime - macro_context = dbt.context.runtime.generate( + macro_context = dbt.context.runtime.generate_macro( macro, self.config, - manifest + manifest, + connection_name ) - if context_override: - macro_context.update(context_override) + macro_context.update(context_override) + + macro_function = macro.generator(macro_context) - result = macro.generator(macro_context)() + try: + result = macro_function(**kwargs) + finally: + if release: + self.release_connection(connection_name) return result @classmethod @@ -716,11 +761,9 @@ def get_catalog(self, manifest): """ # make it a list so macros can index into it. context = {'databases': list(manifest.get_used_databases())} - try: - table = self.execute_macro(manifest, GET_CATALOG_MACRO_NAME, - context_override=context) - finally: - self.release_connection(GET_CATALOG_MACRO_NAME) + table = self.execute_macro(GET_CATALOG_MACRO_NAME, + context_override=context, + release=True) results = self._catalog_filter_table(table, manifest) return results diff --git a/core/dbt/adapters/base/plugin.py b/core/dbt/adapters/base/plugin.py index c07c6062a6d..523b8a43fa9 100644 --- a/core/dbt/adapters/base/plugin.py +++ b/core/dbt/adapters/base/plugin.py @@ -18,8 +18,7 @@ def __init__(self, adapter, credentials, include_path, dependencies=None): self.adapter = adapter self.credentials = credentials self.include_path = include_path - project_path = os.path.join(self.include_path, adapter.type()) - project = Project.from_project_root(project_path, {}) + project = Project.from_project_root(include_path, {}) self.project_name = project.project_name if dependencies is None: dependencies = [] diff --git a/core/dbt/adapters/factory.py b/core/dbt/adapters/factory.py index f3145c690f1..2cbe2dc7ac6 100644 --- a/core/dbt/adapters/factory.py +++ b/core/dbt/adapters/factory.py @@ -81,5 +81,6 @@ def reset_adapters(): """Clear the adapters. This is useful for tests, which change configs. 
""" with _ADAPTER_LOCK: + for adapter in _ADAPTERS.values(): + adapter.cleanup_connections() _ADAPTERS.clear() - ADAPTER_TYPES.clear() diff --git a/core/dbt/adapters/sql/impl.py b/core/dbt/adapters/sql/impl.py index 28185da1600..fe56aae6204 100644 --- a/core/dbt/adapters/sql/impl.py +++ b/core/dbt/adapters/sql/impl.py @@ -12,20 +12,35 @@ from dbt.compat import abstractclassmethod +LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching' +GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation' +LIST_SCHEMAS_MACRO_NAME = 'list_schemas' +CHECK_SCHEMA_EXISTS_MACRO_NAME = 'check_schema_exists' +CREATE_SCHEMA_MACRO_NAME = 'create_schema' +DROP_SCHEMA_MACRO_NAME = 'drop_schema' +RENAME_RELATION_MACRO_NAME = 'rename_relation' +TRUNCATE_RELATION_MACRO_NAME = 'truncate_relation' +DROP_RELATION_MACRO_NAME = 'drop_relation' +ALTER_COLUMN_TYPE_MACRO_NAME = 'alter_column_type' + + class SQLAdapter(BaseAdapter): """The default adapter with the common agate conversions and some SQL - methods implemented. This adapter has a different (much shorter) list of - methods to implement, but it may not be possible to implement all of them - on all databases. + methods implemented. This adapter has a different much shorter list of + methods to implement, but some more macros that must be implemented. + + To implement a macro, implement "${adapter_type}__${macro_name}". in the + adapter's internal project. Methods to implement: - exception_handler - type - date_function - - list_schemas - - list_relations_without_caching - - get_columns_in_relation_sql + Macros to implement: + - get_catalog + - list_relations_without_caching + - get_columns_in_relation """ @available def add_query(self, sql, model_name=None, auto_begin=True, bindings=None, @@ -78,12 +93,12 @@ def is_cancelable(cls): def expand_column_types(self, goal, current, model_name=None): reference_columns = { c.name: c for c in - self.get_columns_in_relation(goal, model_name) + self.get_columns_in_relation(goal, model_name=model_name) } target_columns = { c.name: c for c - in self.get_columns_in_relation(current, model_name) + in self.get_columns_in_relation(current, model_name=model_name) } for column_name, reference_column in reference_columns.items(): @@ -96,20 +111,11 @@ def expand_column_types(self, goal, current, model_name=None): logger.debug("Changing col type from %s to %s in table %s", target_column.data_type, new_type, current) - self.alter_column_type(current, column_name, - new_type, model_name) - - def drop_relation(self, relation, model_name=None): - if dbt.flags.USE_CACHE: - self.cache.drop(relation) - if relation.type is None: - dbt.exceptions.raise_compiler_error( - 'Tried to drop relation {}, but its type is null.' - .format(relation)) - - sql = 'drop {} if exists {} cascade'.format(relation.type, relation) + self.alter_column_type(current, column_name, new_type, + model_name=model_name) - connection, cursor = self.add_query(sql, model_name, auto_begin=False) + if model_name is None: + self.release_connection('master') def alter_column_type(self, relation, column_name, new_column_type, model_name=None): @@ -119,105 +125,124 @@ def alter_column_type(self, relation, column_name, new_column_type, 3. Drop the existing column (cascade!) 4. 
Rename the new column to existing column """ - - opts = { - "relation": relation, - "old_column": column_name, - "tmp_column": "{}__dbt_alter".format(column_name), - "dtype": new_column_type + kwargs = { + 'relation': relation, + 'column_name': column_name, + 'new_column_type': new_column_type, } + self.execute_macro( + ALTER_COLUMN_TYPE_MACRO_NAME, + kwargs=kwargs, + connection_name=model_name + ) - sql = """ - alter table {relation} add column "{tmp_column}" {dtype}; - update {relation} set "{tmp_column}" = "{old_column}"; - alter table {relation} drop column "{old_column}" cascade; - alter table {relation} rename column "{tmp_column}" to "{old_column}"; - """.format(**opts).strip() # noqa - - connection, cursor = self.add_query(sql, model_name) + def drop_relation(self, relation, model_name=None): + if dbt.flags.USE_CACHE: + self.cache.drop(relation) + if relation.type is None: + dbt.exceptions.raise_compiler_error( + 'Tried to drop relation {}, but its type is null.' + .format(relation)) - return connection, cursor + self.execute_macro( + DROP_RELATION_MACRO_NAME, + kwargs={'relation': relation}, + connection_name=model_name + ) def truncate_relation(self, relation, model_name=None): - sql = 'truncate table {}'.format(relation) - - connection, cursor = self.add_query(sql, model_name) + self.execute_macro( + TRUNCATE_RELATION_MACRO_NAME, + kwargs={'relation': relation}, + connection_name=model_name + ) def rename_relation(self, from_relation, to_relation, model_name=None): if dbt.flags.USE_CACHE: self.cache.rename(from_relation, to_relation) - sql = 'alter table {} rename to {}'.format( - from_relation, to_relation.include(database=False, schema=False)) - - connection, cursor = self.add_query(sql, model_name) - - @abstractclassmethod - def get_columns_in_relation_sql(cls, relation): - """Return the sql string to execute on this adapter that will return - information about the columns in this relation. The query should result - in a table with the following type information: - - column_name: text - data_type: text - character_maximum_length: number - numeric_precision: text - numeric_precision should be two integers separated by a comma, - representing the precision and the scale, respectively. - - :param self.Relation relation: The relation to get columns for. - :return: The column information query - :rtype: str - """ - raise dbt.exceptions.NotImplementedException( - '`get_columns_in_relation_sql` is not implemented for this ' - 'adapter!' 
+ kwargs = {'from_relation': from_relation, 'to_relation': to_relation} + self.execute_macro( + RENAME_RELATION_MACRO_NAME, + kwargs=kwargs, + connection_name=model_name ) def get_columns_in_relation(self, relation, model_name=None): - sql = self.get_columns_in_relation_sql(relation) - connection, cursor = self.add_query(sql, model_name) - - data = cursor.fetchall() - columns = [] - - for row in data: - name, data_type, char_size, numeric_size = row - column = self.Column(name, data_type, char_size, numeric_size) - columns.append(column) - - return columns - - def _create_schema_sql(self, database, schema): - schema = self.quote_as_configured(schema, 'schema') - database = self.quote_as_configured(database, 'database') - return 'create schema if not exists {database}.{schema}'.format( - database=database, schema=schema - ) - - def _drop_schema_sql(self, database, schema): - schema = self.quote_as_configured(schema, 'schema') - database = self.quote_as_configured(database, 'database') - return 'drop schema if exists {database}.{schema} cascade'.format( - database=database, schema=schema + return self.execute_macro( + GET_COLUMNS_IN_RELATION_MACRO_NAME, + kwargs={'relation': relation}, + connection_name=model_name ) def create_schema(self, database, schema, model_name=None): - logger.debug('Creating schema "%s".', schema) - - sql = self._create_schema_sql(database, schema) - res = self.add_query(sql, model_name) - + logger.debug('Creating schema "%s"."%s".', database, schema) + if model_name is None: + model_name = 'master' + kwargs = { + 'database_name': self.quote_as_configured(database, 'database'), + 'schema_name': self.quote_as_configured(schema, 'schema'), + } + self.execute_macro(CREATE_SCHEMA_MACRO_NAME, + kwargs=kwargs, + connection_name=model_name) self.commit_if_has_connection(model_name) - return res - def drop_schema(self, database, schema, model_name=None): - logger.debug('Dropping schema "%s".', schema) - - sql = self._drop_schema_sql(database, schema) + logger.debug('Dropping schema "%s"."%s".', database, schema) + kwargs = { + 'database_name': self.quote_as_configured(database, 'database'), + 'schema_name': self.quote_as_configured(schema, 'schema'), + } + self.execute_macro(DROP_SCHEMA_MACRO_NAME, + kwargs=kwargs, + connection_name=model_name) + + def list_relations_without_caching(self, database, schema, + model_name=None): + assert database is not None + assert schema is not None + results = self.execute_macro( + LIST_RELATIONS_MACRO_NAME, + kwargs={'database': database, 'schema': schema}, + connection_name=model_name, + release=True + ) - return self.add_query(sql, model_name) + relations = [] + quote_policy = { + 'schema': True, + 'identifier': True + } + for _database, name, _schema, _type in results: + relations.append(self.Relation.create( + database=_database, + schema=_schema, + identifier=name, + quote_policy=quote_policy, + type=_type + )) + return relations def quote(cls, identifier): return '"{}"'.format(identifier) + + def list_schemas(self, database, model_name=None): + results = self.execute_macro( + LIST_SCHEMAS_MACRO_NAME, + kwargs={'database': database}, + connection_name=model_name, + # release when the model_name is none, as that implies we were + # called by node_runners.py. 
+ release=(model_name is None) + ) + + return [row[0] for row in results] + + def check_schema_exists(self, database, schema, model_name=None): + results = self.execute_macro( + CHECK_SCHEMA_EXISTS_MACRO_NAME, + kwargs={'database': database, 'schema': schema}, + connection_name=model_name + ) + return results[0] > 0 diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 9563a381cc9..96a446ceb59 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -8,8 +8,8 @@ import dbt.include import dbt.tracking +from dbt import deprecations from dbt.utils import get_materialization, NodeType, is_type - from dbt.linker import Linker import dbt.compat @@ -25,7 +25,6 @@ from dbt.logger import GLOBAL_LOGGER as logger graph_file_name = 'graph.gpickle' -manifest_file_name = 'manifest.json' def print_compile_stats(stats): @@ -160,15 +159,6 @@ def compile_node(self, node, manifest, extra_context=None): return injected_node - def write_manifest_file(self, manifest): - """Write the manifest file to disk. - - manifest should be a Manifest. - """ - filename = manifest_file_name - manifest_path = os.path.join(self.config.target_path, filename) - write_json(manifest_path, manifest.serialize()) - def write_graph_file(self, linker, manifest): filename = graph_file_name graph_path = os.path.join(self.config.target_path, filename) @@ -195,66 +185,9 @@ def link_graph(self, linker, manifest): if cycle: raise RuntimeError("Found a cycle: {}".format(cycle)) - def get_all_projects(self): - all_projects = {self.config.project_name: self.config} - dependency_projects = dbt.utils.dependency_projects(self.config) - - for project_cfg in dependency_projects: - name = project_cfg.project_name - all_projects[name] = project_cfg - - if dbt.flags.STRICT_MODE: - dbt.contracts.project.ProjectList(**all_projects) - - return all_projects - - def _check_resource_uniqueness(cls, manifest): - names_resources = {} - alias_resources = {} - - for resource, node in manifest.nodes.items(): - if node.resource_type not in NodeType.refable(): - continue - - name = node.name - alias = "{}.{}".format(node.schema, node.alias) - - existing_node = names_resources.get(name) - if existing_node is not None: - dbt.exceptions.raise_duplicate_resource_name( - existing_node, node) - - existing_alias = alias_resources.get(alias) - if existing_alias is not None: - dbt.exceptions.raise_ambiguous_alias( - existing_alias, node) - - names_resources[name] = node - alias_resources[alias] = node - - def warn_for_deprecated_configs(self, manifest): - for unique_id, node in manifest.nodes.items(): - is_model = node.resource_type == NodeType.Model - if is_model and 'sql_where' in node.config: - dbt.deprecations.warn('sql_where') - - def compile(self): + def compile(self, manifest): linker = Linker() - all_projects = self.get_all_projects() - - manifest = dbt.loader.GraphLoader.load_all(self.config, all_projects) - - self.write_manifest_file(manifest) - - self._check_resource_uniqueness(manifest) - - resource_fqns = manifest.get_resource_fqns() - disabled_fqns = [n.fqn for n in manifest.disabled] - self.config.warn_for_unused_resource_config_paths(resource_fqns, - disabled_fqns) - self.warn_for_deprecated_configs(manifest) - self.link_graph(linker, manifest) stats = defaultdict(int) @@ -267,4 +200,4 @@ def compile(self): self.write_graph_file(linker, manifest) print_compile_stats(stats) - return manifest, linker + return linker diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py index 07549398751..598de08c53b 100644 --- 
a/core/dbt/context/common.py +++ b/core/dbt/context/common.py @@ -48,8 +48,8 @@ class DatabaseWrapper(object): """ Wrapper for runtime database interaction. Mostly a compatibility layer now. """ - def __init__(self, model, adapter): - self.model = model + def __init__(self, connection_name, adapter): + self.connection_name = connection_name self.adapter = adapter self.Relation = RelationProxy(adapter) @@ -58,7 +58,7 @@ def wrap(self, name): @functools.wraps(func) def wrapped(*args, **kwargs): - kwargs['model_name'] = self.model.get('name') + kwargs['model_name'] = self.connection_name return func(*args, **kwargs) return wrapped @@ -83,7 +83,7 @@ def type(self): return self.adapter.type() def commit(self): - return self.adapter.commit_if_has_connection(self.model.get('name')) + return self.adapter.commit_if_has_connection(self.connection_name) def _add_macro_map(context, package_name, macro_map): @@ -337,7 +337,7 @@ def get_this_relation(db_wrapper, config, model): def generate_base(model, model_dict, config, manifest, source_config, - provider): + provider, connection_name): """Generate the common aspects of the config dict.""" if provider is None: raise dbt.exceptions.InternalException( @@ -357,7 +357,7 @@ def generate_base(model, model_dict, config, manifest, source_config, pre_hooks = None post_hooks = None - db_wrapper = DatabaseWrapper(model_dict, adapter) + db_wrapper = DatabaseWrapper(connection_name, adapter) context = dbt.utils.merge(context, { "adapter": db_wrapper, @@ -415,7 +415,7 @@ def modify_generated_context(context, model, model_dict, config, manifest): return context -def generate_execute_macro(model, config, manifest, provider): +def generate_execute_macro(model, config, manifest, provider, connection_name): """Internally, macros can be executed like nodes, with some restrictions: - they don't have have all values available that nodes do: @@ -425,7 +425,7 @@ def generate_execute_macro(model, config, manifest, provider): """ model_dict = model.serialize() context = generate_base(model, model_dict, config, manifest, - None, provider) + None, provider, connection_name) return modify_generated_context(context, model, model_dict, config, manifest) @@ -434,7 +434,7 @@ def generate_execute_macro(model, config, manifest, provider): def generate_model(model, config, manifest, source_config, provider): model_dict = model.to_dict() context = generate_base(model, model_dict, config, manifest, - source_config, provider) + source_config, provider, model.get('name')) # operations (hooks) don't get a 'this' if model.resource_type != NodeType.Operation: this = get_this_relation(context['adapter'], config, model_dict) @@ -459,8 +459,5 @@ def generate(model, config, manifest, source_config=None, provider=None): or dbt.context.runtime.generate """ - if isinstance(model, ParsedMacro): - return generate_execute_macro(model, config, manifest, provider) - else: - return generate_model(model, config, manifest, source_config, - provider) + return generate_model(model, config, manifest, source_config, + provider) diff --git a/core/dbt/context/parser.py b/core/dbt/context/parser.py index 84bb69295a8..9d35546ab6f 100644 --- a/core/dbt/context/parser.py +++ b/core/dbt/context/parser.py @@ -91,3 +91,10 @@ def get(self, name, validator=None, default=None): def generate(model, runtime_config, manifest, source_config): return dbt.context.common.generate( model, runtime_config, manifest, source_config, dbt.context.parser) + + +def generate_macro(model, runtime_config, manifest, connection_name): + return 
dbt.context.common.generate_execute_macro( + model, runtime_config, manifest, dbt.context.parser, + connection_name + ) diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py index ca2b95de502..0f26726c6b8 100644 --- a/core/dbt/context/runtime.py +++ b/core/dbt/context/runtime.py @@ -97,3 +97,10 @@ def get(self, name, validator=None, default=None): def generate(model, runtime_config, manifest): return dbt.context.common.generate( model, runtime_config, manifest, None, dbt.context.runtime) + + +def generate_macro(model, runtime_config, manifest, connection_name): + return dbt.context.common.generate_execute_macro( + model, runtime_config, manifest, dbt.context.runtime, + connection_name + ) diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index 2e04a7942fb..21be85976ba 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -563,3 +563,7 @@ def raise_unrecognized_credentials_type(typename, supported_types): 'Unrecognized credentials type "{}" - supported types are ({})' .format(typename, ', '.join('"{}"'.format(t) for t in supported_types)) ) + + +def raise_not_implemented(msg): + raise NotImplementedException(msg) diff --git a/core/dbt/include/global_project/__init__.py b/core/dbt/include/global_project/__init__.py index d0415108083..91e1f3b0bca 100644 --- a/core/dbt/include/global_project/__init__.py +++ b/core/dbt/include/global_project/__init__.py @@ -1,10 +1,10 @@ import os -PACKAGE_PATH = os.path.dirname(os.path.dirname(__file__)) +PACKAGE_PATH = os.path.dirname(__file__) PROJECT_NAME = 'dbt' DOCS_INDEX_FILE_PATH = os.path.normpath( - os.path.join(PACKAGE_PATH, "index.html")) + os.path.join(PACKAGE_PATH, '..', "index.html")) # Adapter registration will add to this diff --git a/core/dbt/include/global_project/macros/adapters/common.sql b/core/dbt/include/global_project/macros/adapters/common.sql index 298e6931783..0a1dfce2739 100644 --- a/core/dbt/include/global_project/macros/adapters/common.sql +++ b/core/dbt/include/global_project/macros/adapters/common.sql @@ -28,8 +28,24 @@ {%- endif -%} {%- endmacro %} -{% macro create_schema(database_name, schema_name) %} - {{ adapter.create_schema(database_name, schema_name) }} +{% macro create_schema(database_name, schema_name) -%} + {{ adapter_macro('create_schema', database_name, schema_name) }} +{% endmacro %} + +{% macro default__create_schema(database_name, schema_name) -%} + {%- call statement('create_schema') -%} + create schema if not exists {{database_name}}.{{schema_name}} + {% endcall %} +{% endmacro %} + +{% macro drop_schema(database_name, schema_name) -%} + {{ adapter_macro('drop_schema', database_name, schema_name) }} +{% endmacro %} + +{% macro default__drop_schema(database_name, schema_name) -%} + {%- call statement('drop_schema') -%} + drop schema if exists {{database_name}}.{{schema_name}} cascade + {% endcall %} {% endmacro %} {% macro create_table_as(temporary, relation, sql) -%} @@ -81,16 +97,130 @@ {{ exceptions.raise_compiler_error(msg) }} {% endmacro %} -{% macro get_relations() -%} - {{ return(adapter_macro('get_relations')) }} + +{% macro get_columns_in_relation(relation) -%} + {{ return(adapter_macro('get_columns_in_relation', relation)) }} {% endmacro %} +{% macro sql_convert_columns_in_relation(table) -%} + {% set columns = [] %} + {% for row in table %} + {% do columns.append(api.Column(*row)) %} + {% endfor %} + {{ return(columns) }} +{% endmacro %} -{% macro default__get_relations() -%} - {% set typename = adapter.type() %} - {% set msg -%} - get_relations not implemented 
for {{ typename }} - {%- endset %} +{% macro default__get_columns_in_relation(relation) -%} + {{ dbt.exceptions.raise_not_implemented( + 'get_columns_in_relation macro not implemented for adapter '+adapter.type()) }} +{% endmacro %} - {{ exceptions.raise_compiler_error(msg) }} +{% macro alter_column_type(relation, column_name, new_column_type) -%} + {{ return(adapter_macro('alter_column_type', relation, column_name, new_column_type)) }} +{% endmacro %} + +{% macro default__alter_column_type(relation, column_name, new_column_type) -%} + {# + 1. Create a new column (w/ temp name and correct type) + 2. Copy data over to it + 3. Drop the existing column (cascade!) + 4. Rename the new column to existing column + #} + {%- set tmp_column = column_name + "__dbt_alter" -%} + + {% call statement('alter_column_type') %} + alter table {{ relation }} add column {{ tmp_column }} {{ new_column_type }}; + update {{ relation }} set {{ tmp_column }} = {{ column_name }}; + alter table {{ relation }} drop column {{ column_name }} cascade; + alter table {{ relation }} rename column {{ tmp_column }} to {{ column_name }} + {% endcall %} + +{% endmacro %} + + +{% macro drop_relation(relation) -%} + {{ return(adapter_macro('drop_relation', relation)) }} +{% endmacro %} + + +{% macro default__drop_relation(relation) -%} + {% call statement('drop_relation', auto_begin=False) -%} + drop {{ relation.type }} if exists {{ relation }} cascade + {%- endcall %} +{% endmacro %} + +{% macro truncate_relation(relation) -%} + {{ return(adapter_macro('truncate_relation', relation)) }} {% endmacro %} + + +{% macro default__truncate_relation(relation) -%} + {% call statement('truncate_relation') -%} + truncate table {{ relation }} + {%- endcall %} +{% endmacro %} + +{% macro rename_relation(from_relation, to_relation) -%} + {{ return(adapter_macro('rename_relation', from_relation, to_relation)) }} +{% endmacro %} + +{% macro default__rename_relation(from_relation, to_relation) -%} + {% set target_name = adapter.quote_as_configured(to_relation.identifier, 'identifier') %} + {% call statement('rename_relation') -%} + alter table {{ from_relation }} rename to {{ target_name }} + {%- endcall %} +{% endmacro %} + + +{% macro information_schema_name(database) %} + {{ return(adapter_macro('information_schema_name', database)) }} +{% endmacro %} + +{% macro default__information_schema_name(database) -%} + {%- if database -%} + "{{ database }}".information_schema + {%- else -%} + information_schema + {%- endif -%} +{%- endmacro %} + + +{% macro list_schemas(database) -%} + {{ return(adapter_macro('list_schemas', database)) }} +{% endmacro %} + +{% macro default__list_schemas(database) -%} + {% call statement('list_schemas', fetch_result=True, auto_begin=False) %} + select distinct schema_name + from {{ information_schema_name(database) }}.schemata + where catalog_name='{{ database }}' + {% endcall %} + {{ return(load_result('list_schemas').table) }} +{% endmacro %} + + +{% macro check_schema_exists(database, schema) -%} + {{ return(adapter_macro('check_schema_exists', database)) }} +{% endmacro %} + +{% macro default__check_schema_exists(database, schema) -%} + {% call statement('check_schema_exists', fetch_result=True, auto_begin=False) -%} + select count(*) + from {{ information_schema_name(database) }}.schemata + where catalog_name='{{ database }}' + and schema_name='{{ schema }}' + {%- endcall %} + {{ return(load_result('check_schema_exists').table) }} +{% endmacro %} + + +{% macro list_relations_without_caching(database, schema) %} + 
{{ return(adapter_macro('list_relations_without_caching', database, schema)) }} +{% endmacro %} + + +{% macro default__list_relations_without_caching(database, schema) %} + {{ dbt.exceptions.raise_not_implemented( + 'list_relations_without_caching macro not implemented for adapter '+adapter.type()) }} +{% endmacro %} + diff --git a/core/dbt/loader.py b/core/dbt/loader.py index 3fc304d1fbb..5f696ff14e4 100644 --- a/core/dbt/loader.py +++ b/core/dbt/loader.py @@ -1,4 +1,9 @@ +import os +import itertools + +from dbt.include.global_project import PACKAGES import dbt.exceptions +import dbt.flags from dbt.node_types import NodeType from dbt.contracts.graph.manifest import Manifest @@ -8,6 +13,8 @@ DocumentationParser, DataTestParser, HookParser, ArchiveParser, \ SchemaParser, ParserUtils +from dbt.contracts.project import ProjectList + class GraphLoader(object): def __init__(self, root_project, all_projects): @@ -21,20 +28,6 @@ def __init__(self, root_project, all_projects): self.disabled = [] self.macro_manifest = None - def _load_macro_nodes(self, resource_type): - parser = MacroParser(self.root_project, self.all_projects) - for project_name, project in self.all_projects.items(): - self.macros.update(parser.load_and_parse( - package_name=project_name, - root_dir=project.project_root, - relative_dirs=project.macro_paths, - resource_type=resource_type, - )) - - # make a manifest with just the macros to get the context - self.macro_manifest = Manifest(macros=self.macros, nodes={}, docs={}, - generated_at=timestring(), disabled=[]) - def _load_sql_nodes(self, parser_type, resource_type, relative_dirs_attr, **kwargs): parser = parser_type(self.root_project, self.all_projects, @@ -51,9 +44,24 @@ def _load_sql_nodes(self, parser_type, resource_type, relative_dirs_attr, self.nodes.update(nodes) self.disabled.extend(disabled) - def _load_macros(self): - self._load_macro_nodes(NodeType.Macro) - self._load_macro_nodes(NodeType.Operation) + def _load_macros(self, internal_manifest=None): + # skip any projects in the internal manifest + all_projects = self.all_projects.copy() + if internal_manifest is not None: + for name in internal_project_names(): + all_projects.pop(name, None) + self.macros.update(internal_manifest.macros) + + # give the macroparser all projects but then only load what we haven't + # loaded already + parser = MacroParser(self.root_project, self.all_projects) + for project_name, project in all_projects.items(): + self.macros.update(parser.load_and_parse( + package_name=project_name, + root_dir=project.project_root, + relative_dirs=project.macro_paths, + resource_type=NodeType.Macro, + )) def _load_seeds(self): parser = SeedParser(self.root_project, self.all_projects, @@ -115,11 +123,16 @@ def _load_schema_tests(self): ) self.patches[name] = patch - def load(self): - self._load_macros() + def load(self, internal_manifest=None): + self._load_macros(internal_manifest=internal_manifest) + # make a manifest with just the macros to get the context + self.macro_manifest = Manifest(macros=self.macros, nodes={}, docs={}, + generated_at=timestring(), disabled=[]) self._load_nodes() self._load_docs() self._load_schema_tests() + + def create_manifest(self): manifest = Manifest( nodes=self.nodes, macros=self.macros, @@ -136,5 +149,71 @@ def load(self): return manifest @classmethod - def load_all(cls, project_config, all_projects): - return cls(project_config, all_projects).load() + def _load_from_projects(cls, root_config, projects, internal_manifest): + if dbt.flags.STRICT_MODE: + 
ProjectList(**projects) + + loader = cls(root_config, projects) + loader.load(internal_manifest=internal_manifest) + return loader.create_manifest() + + @classmethod + def load_all(cls, root_config, internal_manifest=None): + projects = load_all_projects(root_config) + return cls._load_from_projects(root_config, projects, + internal_manifest) + + @classmethod + def load_internal(cls, root_config): + projects = load_internal_projects(root_config) + return cls._load_from_projects(root_config, projects, None) + + +def internal_project_names(): + return iter(PACKAGES.values()) + + +def _load_projects(config, paths): + for path in paths: + try: + project = config.new_project(path) + except dbt.exceptions.DbtProjectError as e: + raise dbt.exceptions.DbtProjectError( + 'Failed to read package at {}: {}' + .format(path, e) + ) + else: + yield project.project_name, project + + +def _project_directories(config): + root = os.path.join(config.project_root, config.modules_path) + + dependencies = [] + if os.path.exists(root): + dependencies = os.listdir(root) + + for name in dependencies: + full_obj = os.path.join(root, name) + + if not os.path.isdir(full_obj) or name.startswith('__'): + # exclude non-dirs and dirs that start with __ + # the latter could be something like __pycache__ + # for the global dbt modules dir + continue + + yield full_obj + + +def load_all_projects(config): + all_projects = {config.project_name: config} + project_paths = itertools.chain( + internal_project_names(), + _project_directories(config) + ) + all_projects.update(_load_projects(config, project_paths)) + return all_projects + + +def load_internal_projects(config): + return dict(_load_projects(config, internal_project_names())) diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index 67245385fd9..4cdee6ee8c0 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -87,9 +87,9 @@ def get_schema_func(self): def get_schema(_): return self.default_schema else: - root_context = dbt.context.parser.generate( + root_context = dbt.context.parser.generate_macro( get_schema_macro, self.root_project_config, - self.macro_manifest, self.root_project_config + self.macro_manifest, 'generate_schema_name' ) get_schema = get_schema_macro.generator(root_context) diff --git a/core/dbt/runner.py b/core/dbt/runner.py index b459688775f..247d6f0acc9 100644 --- a/core/dbt/runner.py +++ b/core/dbt/runner.py @@ -6,6 +6,9 @@ from dbt.contracts.graph.parsed import ParsedNode from dbt.contracts.graph.manifest import CompileResultNode from dbt.contracts.results import ExecutionResult +from dbt import deprecations +from dbt.loader import GraphLoader +from dbt.node_types import NodeType import dbt.clients.jinja import dbt.compilation @@ -22,6 +25,7 @@ RESULT_FILE_NAME = 'run_results.json' +MANIFEST_FILE_NAME = 'manifest.json' class RunManager(object): @@ -34,13 +38,13 @@ def __init__(self, config, query, Runner): self.query = query self.Runner = Runner - manifest, linker = self.compile(self.config) - self.manifest = manifest - self.linker = linker + self.manifest = self.load_manifest() + self.linker = self.compile() - selector = dbt.graph.selector.NodeSelector(linker, manifest) + selector = dbt.graph.selector.NodeSelector(self.linker, self.manifest) selected_nodes = selector.select(query) - self.job_queue = self.linker.as_graph_queue(manifest, selected_nodes) + self.job_queue = self.linker.as_graph_queue(self.manifest, + selected_nodes) # we use this a couple times. order does not matter. 
self._flattened_nodes = [ @@ -195,10 +199,75 @@ def write_results(self, execution_result): filepath = os.path.join(self.config.target_path, RESULT_FILE_NAME) write_json(filepath, execution_result.serialize()) - def compile(self, config): - compiler = dbt.compilation.Compiler(config) + def write_manifest_file(self, manifest): + """Write the manifest file to disk. + + manifest should be a Manifest. + """ + manifest_path = os.path.join(self.config.target_path, + MANIFEST_FILE_NAME) + write_json(manifest_path, manifest.serialize()) + + @staticmethod + def _check_resource_uniqueness(manifest): + names_resources = {} + alias_resources = {} + + for resource, node in manifest.nodes.items(): + if node.resource_type not in NodeType.refable(): + continue + + name = node.name + alias = "{}.{}".format(node.schema, node.alias) + + existing_node = names_resources.get(name) + if existing_node is not None: + dbt.exceptions.raise_duplicate_resource_name( + existing_node, node) + + existing_alias = alias_resources.get(alias) + if existing_alias is not None: + dbt.exceptions.raise_ambiguous_alias( + existing_alias, node) + + names_resources[name] = node + alias_resources[alias] = node + + def _warn_for_unused_resource_config_paths(self, manifest): + resource_fqns = manifest.get_resource_fqns() + disabled_fqns = [n.fqn for n in manifest.disabled] + self.config.warn_for_unused_resource_config_paths(resource_fqns, + disabled_fqns) + + @staticmethod + def _warn_for_deprecated_configs(manifest): + for unique_id, node in manifest.nodes.items(): + is_model = node.resource_type == NodeType.Model + if is_model and 'sql_where' in node.config: + deprecations.warn('sql_where') + + def _check_manifest(self, manifest): + # TODO: maybe this belongs in the GraphLoader, too? + self._check_resource_uniqueness(manifest) + self._warn_for_unused_resource_config_paths(manifest) + self._warn_for_deprecated_configs(manifest) + + def load_manifest(self): + # performance trick: if the adapter has a manifest loaded, use that to + # avoid parsing internal macros twice. 
+ internal_manifest = get_adapter(self.config).check_internal_manifest() + manifest = GraphLoader.load_all(self.config, + internal_manifest=internal_manifest) + + self.write_manifest_file(manifest) + self._check_manifest(manifest) + return manifest + + def compile(self): + compiler = dbt.compilation.Compiler(self.config) compiler.initialize() - return compiler.compile() + + return compiler.compile(self.manifest) def run(self): """ diff --git a/core/dbt/schema.py b/core/dbt/schema.py index 87c6e4e972b..5a38f6d86a6 100644 --- a/core/dbt/schema.py +++ b/core/dbt/schema.py @@ -10,11 +10,13 @@ class Column(object): 'INTEGER': 'INT' } - def __init__(self, column, dtype, char_size=None, numeric_size=None): + def __init__(self, column, dtype, char_size=None, numeric_precision=None, + numeric_scale=None): self.column = column self.dtype = dtype self.char_size = char_size - self.numeric_size = numeric_size + self.numeric_precision = numeric_precision + self.numeric_scale = numeric_scale @classmethod def translate_type(cls, dtype): @@ -38,7 +40,8 @@ def data_type(self): if self.is_string(): return Column.string_type(self.string_size()) elif self.is_numeric(): - return Column.numeric_type(self.dtype, self.numeric_size) + return Column.numeric_type(self.dtype, self.numeric_precision, + self.numeric_scale) else: return self.dtype @@ -74,13 +77,13 @@ def string_type(cls, size): return "character varying({})".format(size) @classmethod - def numeric_type(cls, dtype, size): + def numeric_type(cls, dtype, precision, scale): # This could be decimal(...), numeric(...), number(...) # Just use whatever was fed in here -- don't try to get too clever - if size is None: + if precision is None or scale is None: return dtype else: - return "{}({})".format(dtype, size) + return "{}({},{})".format(dtype, precision, scale) def __repr__(self): return "".format(self.name, self.data_type) diff --git a/core/dbt/task/generate.py b/core/dbt/task/generate.py index 96d5789eb10..5be152ba90e 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -187,12 +187,7 @@ def incorporate_catalog_unique_ids(catalog, manifest): class GenerateTask(CompileTask): def _get_manifest(self): - compiler = dbt.compilation.Compiler(self.config) - compiler.initialize() - - all_projects = compiler.get_all_projects() - - manifest = dbt.loader.GraphLoader.load_all(self.config, all_projects) + manifest = dbt.loader.GraphLoader.load_all(self.config) return manifest def run(self): diff --git a/core/dbt/utils.py b/core/dbt/utils.py index 3bc9173f6e8..db6be8cb7c7 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -169,36 +169,6 @@ def get_docs_macro_name(docs_name, with_prefix=True): return docs_name -def dependencies_for_path(config, module_path): - """Given a module path, yield all dependencies in that path.""" - logger.debug("Loading dependency project from {}".format(module_path)) - for obj in os.listdir(module_path): - full_obj = os.path.join(module_path, obj) - - if not os.path.isdir(full_obj) or obj.startswith('__'): - # exclude non-dirs and dirs that start with __ - # the latter could be something like __pycache__ - # for the global dbt modules dir - continue - - try: - yield config.new_project(full_obj) - except dbt.exceptions.DbtProjectError as e: - raise dbt.exceptions.DbtProjectError( - 'Failed to read package at {}: {}' - .format(full_obj, e) - ) - - -def dependency_projects(config): - module_paths = list(PACKAGES.values()) - module_paths.append(os.path.join(config.project_root, config.modules_path)) - - for module_path 
in module_paths: - for entry in dependencies_for_path(config, module_path): - yield entry - - def split_path(path): return path.split(os.sep) diff --git a/plugins/bigquery/dbt/include/bigquery/__init__.py b/plugins/bigquery/dbt/include/bigquery/__init__.py index 87098354afd..564a3d1e80a 100644 --- a/plugins/bigquery/dbt/include/bigquery/__init__.py +++ b/plugins/bigquery/dbt/include/bigquery/__init__.py @@ -1,2 +1,2 @@ import os -PACKAGE_PATH = os.path.dirname(os.path.dirname(__file__)) +PACKAGE_PATH = os.path.dirname(__file__) diff --git a/plugins/bigquery/dbt/include/bigquery/macros/adapters.sql b/plugins/bigquery/dbt/include/bigquery/macros/adapters.sql index b223b6d4e18..83b97f70895 100644 --- a/plugins/bigquery/dbt/include/bigquery/macros/adapters.sql +++ b/plugins/bigquery/dbt/include/bigquery/macros/adapters.sql @@ -13,7 +13,7 @@ {% macro cluster_by(raw_cluster_by) %} {%- if raw_cluster_by is not none -%} - cluster by + cluster by {% if raw_cluster_by is string -%} {% set raw_cluster_by = [raw_cluster_by] %} {%- endif -%} @@ -45,3 +45,20 @@ {{ sql }} ); {% endmacro %} + +{% macro bigquery__create_schema(database_name, schema_name) -%} + {{ adapter.create_schema(database_name, schema_name) }} +{% endmacro %} + +{% macro bigquery__drop_schema(database_name, schema_name) -%} + {{ adapter.drop_schema(database_name, schema_name) }} +{% endmacro %} + +{% macro bigquery__get_columns_in_relation(relation) -%} + {{ return(adapter.get_columns_in_relation(relation)) }} +{% endmacro %} + + +{% macro bigquery__list_relations_without_caching(database, schema) -%} + {{ return(adapter.list_relations_without_caching(database, schema)) }} +{% endmacro %} diff --git a/plugins/postgres/dbt/adapters/postgres/impl.py b/plugins/postgres/dbt/adapters/postgres/impl.py index 78188e63571..de034aa380f 100644 --- a/plugins/postgres/dbt/adapters/postgres/impl.py +++ b/plugins/postgres/dbt/adapters/postgres/impl.py @@ -2,6 +2,7 @@ import time +from dbt.adapters.base.meta import available_raw from dbt.adapters.sql import SQLAdapter from dbt.adapters.postgres import PostgresConnectionManager import dbt.compat @@ -11,7 +12,8 @@ from dbt.logger import GLOBAL_LOGGER as logger -GET_RELATIONS_MACRO_NAME = 'get_relations' +# note that this isn't an adapter macro, so just a single underscore +GET_RELATIONS_MACRO_NAME = 'postgres_get_relations' class PostgresAdapter(SQLAdapter): @@ -21,15 +23,21 @@ class PostgresAdapter(SQLAdapter): def date_function(cls): return 'datenow()' - def _verify_database(self, database): - if database != self.config.credentials.database: + @available_raw + def verify_database(self, database): + database = database.strip('"') + expected = self.config.credentials.database + if database != expected: raise dbt.exceptions.NotImplementedException( - 'Cross-db references not allowed in {}'.format(self.type()) + 'Cross-db references not allowed in {} ({} vs {})' + .format(self.type(), database, expected) ) + # return an empty string on success so macros can call this + return '' - def _link_cached_database_relations(self, manifest, schemas): - table = self.execute_macro(manifest, GET_RELATIONS_MACRO_NAME) + def _link_cached_database_relations(self, schemas): database = self.config.credentials.database + table = self.execute_macro(GET_RELATIONS_MACRO_NAME) for (refed_schema, refed_name, dep_schema, dep_name) in table: referenced = self.Relation.create( @@ -51,113 +59,14 @@ def _link_cached_database_relations(self, manifest, schemas): def _link_cached_relations(self, manifest): schemas = set() for db, 
schema in manifest.get_used_schemas(): - self._verify_database(db) + self.verify_database(db) schemas.add(schema) try: - self._link_cached_database_relations(manifest, schemas) + self._link_cached_database_relations(schemas) finally: self.release_connection(GET_RELATIONS_MACRO_NAME) def _relations_cache_for_schemas(self, manifest): super(PostgresAdapter, self)._relations_cache_for_schemas(manifest) self._link_cached_relations(manifest) - - def list_relations_without_caching(self, database, schema, - model_name=None): - self._verify_database(database) - sql = """ - select - table_catalog as database, - table_name as name, - table_schema as schema, - case when table_type = 'BASE TABLE' then 'table' - when table_type = 'VIEW' then 'view' - else table_type - end as table_type - from information_schema.tables - where table_schema ilike '{schema}' - """.format(database=database, schema=schema).strip() # noqa - - connection, cursor = self.add_query(sql, model_name, auto_begin=False) - - results = cursor.fetchall() - - return [self.Relation.create( - database=_database, - schema=_schema, - identifier=name, - quote_policy={ - 'schema': True, - 'identifier': True - }, - type=_type) - for (_database, name, _schema, _type) in results] - - def list_schemas(self, database, model_name=None): - self._verify_database(database) - sql = """ - select distinct schema_name - from information_schema.schemata - """.format(database=database).strip() # noqa - - connection, cursor = self.add_query(sql, model_name, auto_begin=False) - results = cursor.fetchall() - - return [row[0] for row in results] - - def check_schema_exists(self, database, schema, model_name=None): - self._verify_database(database) - sql = """ - select count(*) - from information_schema.schemata - where schema_name='{schema}' - """.format(database=database, schema=schema).strip() # noqa - - connection, cursor = self.add_query(sql, model_name, - auto_begin=False) - results = cursor.fetchone() - - return results[0] > 0 - - @classmethod - def get_columns_in_relation_sql(cls, relation): - db_filter = '1=1' - if relation.database: - db_filter = "table_catalog ilike '{}'".format(relation.database) - - schema_filter = '1=1' - if relation.schema: - schema_filter = "table_schema = '{}'".format(relation.schema) - - sql = """ - select - column_name, - data_type, - character_maximum_length, - numeric_precision || ',' || numeric_scale as numeric_size - - from information_schema.columns - where table_name = '{table_name}' - and {schema_filter} - and {db_filter} - order by ordinal_position - """.format(table_name=relation.identifier, - schema_filter=schema_filter, - db_filter=db_filter).strip() - - return sql - - def _create_schema_sql(self, database, schema): - self._verify_database(database) - - schema = self.quote_as_configured(schema, 'schema') - - return 'create schema if not exists {schema}'.format(schema=schema) - - def _drop_schema_sql(self, database, schema): - self._verify_database(database) - - schema = self.quote_as_configured(schema, 'schema') - - return 'drop schema if exists {schema} cascade'.format(schema=schema) diff --git a/plugins/postgres/dbt/include/postgres/__init__.py b/plugins/postgres/dbt/include/postgres/__init__.py index 87098354afd..564a3d1e80a 100644 --- a/plugins/postgres/dbt/include/postgres/__init__.py +++ b/plugins/postgres/dbt/include/postgres/__init__.py @@ -1,2 +1,2 @@ import os -PACKAGE_PATH = os.path.dirname(os.path.dirname(__file__)) +PACKAGE_PATH = os.path.dirname(__file__) diff --git 
a/plugins/postgres/dbt/include/postgres/macros/adapters.sql b/plugins/postgres/dbt/include/postgres/macros/adapters.sql new file mode 100644 index 00000000000..b1111bbb1c3 --- /dev/null +++ b/plugins/postgres/dbt/include/postgres/macros/adapters.sql @@ -0,0 +1,64 @@ + +{% macro postgres__create_schema(database_name, schema_name) -%} + {% if database_name -%} + {{ adapter.verify_database(database_name) }} + {%- endif -%} + {%- call statement('create_schema') -%} + create schema if not exists {{ schema_name }} + {%- endcall -%} +{% endmacro %} + +{% macro postgres__drop_schema(database_name, schema_name) -%} + {% if database_name -%} + {{ adapter.verify_database(database_name) }} + {%- endif -%} + {%- call statement('drop_schema') -%} + drop schema if exists {{ schema_name }} cascade + {%- endcall -%} +{% endmacro %} + +{% macro postgres__get_columns_in_relation(relation) -%} + {% call statement('get_columns_in_relation', fetch_result=True) %} + select + column_name, + data_type, + character_maximum_length, + numeric_precision, + numeric_scale + + from {{ information_schema_name(relation.database) }}.columns + where table_name = '{{ relation.identifier }}' + {% if relation.schema %} + and table_schema = '{{ relation.schema }}' + {% endif %} + order by ordinal_position + + {% endcall %} + {% set table = load_result('get_columns_in_relation').table %} + {{ return(sql_convert_columns_in_relation(table)) }} +{% endmacro %} + + +{% macro postgres__list_relations_without_caching(database, schema) %} + {% call statement('list_relations_without_caching', fetch_result=True) -%} + select + table_catalog as database, + table_name as name, + table_schema as schema, + case when table_type = 'BASE TABLE' then 'table' + when table_type = 'VIEW' then 'view' + else table_type + end as table_type + from {{ information_schema_name(database) }}.tables + where table_schema ilike '{{ schema }}' + and table_catalog ilike '{{ database }}' + {% endcall %} + {{ return(load_result('list_relations_without_caching').table) }} +{% endmacro %} + +{% macro postgres__information_schema_name(database) -%} + {% if database_name -%} + {{ adapter.verify_database(database_name) }} + {%- endif -%} + information_schema +{%- endmacro %} diff --git a/plugins/postgres/dbt/include/postgres/macros/catalog.sql b/plugins/postgres/dbt/include/postgres/macros/catalog.sql index c704aa6f974..e04e521ea94 100644 --- a/plugins/postgres/dbt/include/postgres/macros/catalog.sql +++ b/plugins/postgres/dbt/include/postgres/macros/catalog.sql @@ -6,6 +6,7 @@ exceptions.raise_compiler_error('postgres get_catalog requires exactly one database') {% endif %} {% set database = databases[0] %} + {{ adapter.verify_database(database) }} with table_owners as ( diff --git a/plugins/postgres/dbt/include/postgres/macros/relations.sql b/plugins/postgres/dbt/include/postgres/macros/relations.sql index 18c685986eb..820af014eab 100644 --- a/plugins/postgres/dbt/include/postgres/macros/relations.sql +++ b/plugins/postgres/dbt/include/postgres/macros/relations.sql @@ -1,4 +1,4 @@ -{% macro postgres__get_relations () -%} +{% macro postgres_get_relations () -%} {%- call statement('relations', fetch_result=True) -%} -- {# -- in pg_depend, objid is the dependent, refobjid is the referenced object @@ -40,7 +40,7 @@ referenced_class.kind from relation join class as referenced_class on relation.class=referenced_class.id - where referenced_class.kind in ('r', 'v') + where referenced_class.kind in ('r', 'v') ), relationships as ( select diff --git 
a/plugins/redshift/dbt/adapters/redshift/impl.py b/plugins/redshift/dbt/adapters/redshift/impl.py
index 57d0cf00011..db81f748b21 100644
--- a/plugins/redshift/dbt/adapters/redshift/impl.py
+++ b/plugins/redshift/dbt/adapters/redshift/impl.py
@@ -11,75 +11,6 @@ class RedshiftAdapter(PostgresAdapter):
     def date_function(cls):
         return 'getdate()'
 
-    @classmethod
-    def get_columns_in_relation_sql(cls, relation):
-        # Redshift doesn't support cross-database queries, so we can ignore the
-        # relation's database
-        schema_filter = '1=1'
-        if relation.schema:
-            schema_filter = "table_schema = '{}'".format(relation.schema)
-
-        sql = """
-            with bound_views as (
-                select
-                    ordinal_position,
-                    table_schema,
-                    column_name,
-                    data_type,
-                    character_maximum_length,
-                    numeric_precision || ',' || numeric_scale as numeric_size
-
-                from information_schema.columns
-                where table_name = '{table_name}'
-            ),
-
-            unbound_views as (
-                select
-                    ordinal_position,
-                    view_schema,
-                    col_name,
-                    case
-                        when col_type ilike 'character varying%' then
-                            'character varying'
-                        when col_type ilike 'numeric%' then 'numeric'
-                        else col_type
-                    end as col_type,
-                    case
-                        when col_type like 'character%'
-                        then nullif(REGEXP_SUBSTR(col_type, '[0-9]+'), '')::int
-                        else null
-                    end as character_maximum_length,
-                    case
-                        when col_type like 'numeric%'
-                        then nullif(REGEXP_SUBSTR(col_type, '[0-9,]+'), '')
-                        else null
-                    end as numeric_size
-
-                from pg_get_late_binding_view_cols()
-                cols(view_schema name, view_name name, col_name name,
-                     col_type varchar, ordinal_position int)
-                where view_name = '{table_name}'
-            ),
-
-            unioned as (
-                select * from bound_views
-                union all
-                select * from unbound_views
-            )
-
-            select
-                column_name,
-                data_type,
-                character_maximum_length,
-                numeric_size
-
-            from unioned
-            where {schema_filter}
-            order by ordinal_position
-        """.format(table_name=relation.identifier,
-                   schema_filter=schema_filter).strip()
-        return sql
-
     def drop_relation(self, relation, model_name=None):
         """
         In Redshift, DROP TABLE ... CASCADE should not be used
diff --git a/plugins/redshift/dbt/include/redshift/__init__.py b/plugins/redshift/dbt/include/redshift/__init__.py
index ef10d7896ca..8b17c9fbfc5 100644
--- a/plugins/redshift/dbt/include/redshift/__init__.py
+++ b/plugins/redshift/dbt/include/redshift/__init__.py
@@ -1,3 +1,3 @@
 import os
 from dbt.include.postgres import PACKAGE_PATH as POSTGRES_PACKAGE_PATH
-PACKAGE_PATH = os.path.dirname(os.path.dirname(__file__))
+PACKAGE_PATH = os.path.dirname(__file__)
diff --git a/plugins/redshift/dbt/include/redshift/macros/adapters.sql b/plugins/redshift/dbt/include/redshift/macros/adapters.sql
index 7bd9565a534..f6216ef79f6 100644
--- a/plugins/redshift/dbt/include/redshift/macros/adapters.sql
+++ b/plugins/redshift/dbt/include/redshift/macros/adapters.sql
@@ -64,3 +64,99 @@
     {{ dist('dbt_updated_at') }}
     {{ sort('compound', ['scd_id']) }};
 {%- endmacro %}
+
+
+{% macro redshift__create_schema(database_name, schema_name) -%}
+  {{ postgres__create_schema(database_name, schema_name) }}
+{% endmacro %}
+
+{% macro redshift__drop_schema(database_name, schema_name) -%}
+  {{ postgres__drop_schema(database_name, schema_name) }}
+{% endmacro %}
+
+
+{% macro redshift__get_columns_in_relation(relation) -%}
+  {% call statement('get_columns_in_relation', fetch_result=True) %}
+      with bound_views as (
+        select
+          ordinal_position,
+          table_schema,
+          column_name,
+          data_type,
+          character_maximum_length,
+          numeric_precision,
+          numeric_scale
+
+        from information_schema.columns
+        where table_name = '{{ relation.identifier }}'
+      ),
+
+      unbound_views as (
+        select
+          ordinal_position,
+          view_schema,
+          col_name,
+          case
+            when col_type ilike 'character varying%' then
+              'character varying'
+            when col_type ilike 'numeric%' then 'numeric'
+            else col_type
+          end as col_type,
+          case
+            when col_type like 'character%'
+            then nullif(REGEXP_SUBSTR(col_type, '[0-9]+'), '')::int
+            else null
+          end as character_maximum_length,
+          case
+            when col_type like 'numeric%'
+            then nullif(
+              SPLIT_PART(REGEXP_SUBSTR(col_type, '[0-9,]+'), ',', 1),
+              '')::int
+            else null
+          end as numeric_precision,
+          case
+            when col_type like 'numeric%'
+            then nullif(
+              SPLIT_PART(REGEXP_SUBSTR(col_type, '[0-9,]+'), ',', 2),
+              '')::int
+            else null
+          end as numeric_scale
+
+        from pg_get_late_binding_view_cols()
+        cols(view_schema name, view_name name, col_name name,
+             col_type varchar, ordinal_position int)
+        where view_name = '{{ relation.identifier }}'
+      ),
+
+      unioned as (
+        select * from bound_views
+        union all
+        select * from unbound_views
+      )
+
+      select
+        column_name,
+        data_type,
+        character_maximum_length,
+        numeric_precision,
+        numeric_scale
+
+      from unioned
+      {% if relation.schema %}
+      where table_schema = '{{ relation.schema }}'
+      {% endif %}
+      order by ordinal_position
+  {% endcall %}
+  {% set table = load_result('get_columns_in_relation').table %}
+  {{ return(sql_convert_columns_in_relation(table)) }}
+{% endmacro %}
+
+
+{% macro redshift__list_relations_without_caching(database, schema) %}
+  {{ return(postgres__list_relations_without_caching(database, schema)) }}
+{% endmacro %}
+
+
+{% macro redshift__information_schema_name(database) -%}
+  {{ return(postgres__information_schema_name(database)) }}
+{%- endmacro %}
diff --git a/plugins/snowflake/dbt/adapters/snowflake/impl.py b/plugins/snowflake/dbt/adapters/snowflake/impl.py
index 9544d29f686..338b3d591d4 100644
--- a/plugins/snowflake/dbt/adapters/snowflake/impl.py
+++ b/plugins/snowflake/dbt/adapters/snowflake/impl.py
@@ -18,63 +18,6 @@ class SnowflakeAdapter(SQLAdapter):
     def date_function(cls):
         return 'CURRENT_TIMESTAMP()'
 
-    def list_relations_without_caching(self, database, schema,
-                                       model_name=None):
-        sql = """
-        select
-          table_catalog as database, table_name as name,
-          table_schema as schema, table_type as type
-        from information_schema.tables
-        where
-          table_schema ilike '{schema}' and
-          table_catalog ilike '{database}'
-        """.format(database=database, schema=schema).strip()  # noqa
-
-        _, cursor = self.add_query(sql, model_name, auto_begin=False)
-
-        results = cursor.fetchall()
-
-        relation_type_lookup = {
-            'BASE TABLE': 'table',
-            'VIEW': 'view'
-
-        }
-        return [self.Relation.create(
-            database=_database,
-            schema=_schema,
-            identifier=name,
-            quote_policy={
-                'identifier': True,
-                'schema': True,
-            },
-            type=relation_type_lookup.get(_type))
-            for (_database, name, _schema, _type) in results]
-
-    def list_schemas(self, database, model_name=None):
-        sql = """
-        select distinct schema_name
-        from "{database}".information_schema.schemata
-        where catalog_name ilike '{database}'
-        """.format(database=database).strip()  # noqa
-
-        connection, cursor = self.add_query(sql, model_name, auto_begin=False)
-        results = cursor.fetchall()
-
-        return [row[0] for row in results]
-
-    def check_schema_exists(self, database, schema, model_name=None):
-        sql = """
-        select count(*)
-        from information_schema.schemata
-        where upper(schema_name) = upper('{schema}')
-          and upper(catalog_name) = upper('{database}')
-        """.format(database=database, schema=schema).strip()  # noqa
-
-        connection, cursor = self.add_query(sql, model_name, auto_begin=False)
-        results = cursor.fetchone()
-
-        return results[0] > 0
-
     @classmethod
     def _catalog_filter_table(cls, table, manifest):
         # On snowflake, users can set QUOTED_IDENTIFIERS_IGNORE_CASE, so force
@@ -99,34 +42,3 @@ def _make_match_kwargs(self, database, schema, identifier):
         return filter_null_values({'identifier': identifier,
                                    'schema': schema,
                                    'database': database})
-
-    @classmethod
-    def get_columns_in_relation_sql(cls, relation):
-        source_name = 'information_schema.columns'
-        db_filter = '1=1'
-        if relation.database:
-            db_filter = "table_catalog ilike '{}'".format(relation.database)
-            source_name = '{}.{}'.format(relation.database, source_name)
-
-        schema_filter = '1=1'
-        if relation.schema:
-            schema_filter = "table_schema ilike '{}'".format(relation.schema)
-
-        sql = """
-        select
-          column_name,
-          data_type,
-          character_maximum_length,
-          numeric_precision || ',' || numeric_scale as numeric_size
-
-        from {source_name}
-        where table_name ilike '{table_name}'
-          and {schema_filter}
-          and {db_filter}
-        order by ordinal_position
-        """.format(source_name=source_name,
-                   table_name=relation.identifier,
-                   schema_filter=schema_filter,
-                   db_filter=db_filter).strip()
-
-        return sql
diff --git a/plugins/snowflake/dbt/include/snowflake/__init__.py b/plugins/snowflake/dbt/include/snowflake/__init__.py
index 87098354afd..564a3d1e80a 100644
--- a/plugins/snowflake/dbt/include/snowflake/__init__.py
+++ b/plugins/snowflake/dbt/include/snowflake/__init__.py
@@ -1,2 +1,2 @@
 import os
-PACKAGE_PATH = os.path.dirname(os.path.dirname(__file__))
+PACKAGE_PATH = os.path.dirname(__file__)
diff --git a/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql b/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql
index d63875b648b..255e0a9df12 100644
--- a/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql
+++ b/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql
@@ -11,3 +11,60 @@
     {{ sql }}
   );
 {% endmacro %}
+
+{% macro snowflake__get_columns_in_relation(relation) -%}
+  {% call statement('get_columns_in_relation', fetch_result=True) %}
+    select
+        column_name,
+        data_type,
+        character_maximum_length,
+        numeric_precision,
+        numeric_scale
+
+    from
+    {{ information_schema_name(relation.database) }}.columns
+
+    where table_name ilike '{{ relation.identifier }}'
+      {% if relation.schema %}
+      and table_schema ilike '{{ relation.schema }}'
+      {% endif %}
+      {% if relation.database %}
+      and table_catalog ilike '{{ relation.database }}'
+      {% endif %}
+    order by ordinal_position
+
+  {% endcall %}
+
+  {% set table = load_result('get_columns_in_relation').table %}
+  {{ return(sql_convert_columns_in_relation(table)) }}
+
+{% endmacro %}
+
+
+{% macro snowflake__list_relations_without_caching(database, schema) %}
+  {% call statement('list_relations_without_caching', fetch_result=True) -%}
+    select
+      table_catalog as database,
+      table_name as name,
+      table_schema as schema,
+      case when table_type = 'BASE TABLE' then 'table'
+           when table_type = 'VIEW' then 'view'
+           else table_type
+      end as table_type
+    from {{ information_schema_name(database) }}.tables
+    where table_schema ilike '{{ schema }}'
+      and table_catalog ilike '{{ database }}'
+  {% endcall %}
+  {{ return(load_result('list_relations_without_caching').table) }}
+{% endmacro %}
+
+
+{% macro snowflake__check_schema_exists(database, schema) -%}
+  {% call statement('check_schema_exists', fetch_result=True) -%}
+    select count(*)
+    from {{ information_schema_name(database) }}.schemata
+    where upper(schema_name) = upper('{{ schema }}')
+      and upper(catalog_name) = upper('{{ database }}')
+  {%- endcall %}
+  {{ return(load_result('check_schema_exists').table) }}
+{%- endmacro %}
diff --git a/test/integration/029_docs_generate_tests/test_docs_generate.py b/test/integration/029_docs_generate_tests/test_docs_generate.py
index acb7482404a..3e660695399 100644
--- a/test/integration/029_docs_generate_tests/test_docs_generate.py
+++ b/test/integration/029_docs_generate_tests/test_docs_generate.py
@@ -46,7 +46,7 @@ def _normalize(path):
 class TestDocsGenerate(DBTIntegrationTest):
     setup_alternate_db = True
     def setUp(self):
-        super(TestDocsGenerate,self).setUp()
+        super(TestDocsGenerate, self).setUp()
         self.maxDiff = None
 
     @property
diff --git a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py
index 43a728486c8..cfb9876a4c7 100644
--- a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py
+++ b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py
@@ -24,8 +24,9 @@ def project_config(self):
         }
 
     def run_select_and_check(self, rel, sql):
+        connection_name = '__test_{}'.format(id(threading.current_thread()))
         try:
-            res = self.run_sql(sql, fetch='one')
+            res = self.run_sql(sql, fetch='one', connection_name=connection_name)
 
             # The result is the output of f_sleep(), which is True
             if res[0] == True:
diff --git a/test/integration/038_caching_test/test_caching.py b/test/integration/038_caching_test/test_caching.py
index 4acbf5d55d9..4cdecf46adf 100644
--- a/test/integration/038_caching_test/test_caching.py
+++ b/test/integration/038_caching_test/test_caching.py
@@ -18,7 +18,7 @@ def project_config(self):
     def run_and_get_adapter(self):
         # we want to inspect the adapter that dbt used for the run, which is
         # not self.adapter. You can't do this until after you've run dbt once.
-        self.run_dbt(['run'], clear_adapters=False)
+        self.run_dbt(['run'])
         return factory._ADAPTERS[self.adapter_type]
 
     def cache_run(self):
@@ -28,7 +28,7 @@ def cache_run(self):
         self.assertEqual(relation.inner.schema, self.unique_schema())
         self.assertEqual(relation.schema, self.unique_schema().lower())
 
-        self.run_dbt(['run'], clear_adapters=False)
+        self.run_dbt(['run'])
         self.assertEqual(len(adapter.cache.relations), 1)
         second_relation = next(iter(adapter.cache.relations.values()))
         self.assertEqual(relation, second_relation)
diff --git a/test/integration/base.py b/test/integration/base.py
index cd397ee1fe7..b783194a720 100644
--- a/test/integration/base.py
+++ b/test/integration/base.py
@@ -320,8 +320,14 @@ def _clean_files(self):
 
     def tearDown(self):
         self._clean_files()
 
+        # get any current run adapter and clean up its connections before we
+        # reset them. It'll probably be different from ours because
+        # handle_and_check() calls reset_adapters().
+        adapter = get_adapter(self.config)
+        if adapter is not self.adapter:
+            adapter.cleanup_connections()
         if not hasattr(self, 'adapter'):
-            self.adapter = get_adapter(self.config)
+            self.adapter = adapter
 
         self._drop_schemas()
 
@@ -382,7 +388,6 @@ def _drop_schemas_sql(self):
 
         self._created_schemas.clear()
 
-
     def _drop_schemas(self):
         if self.adapter_type == 'bigquery':
             self._drop_schemas_bigquery()
@@ -397,10 +402,7 @@ def project_config(self):
     def profile_config(self):
         return {}
 
-    def run_dbt(self, args=None, expect_pass=True, strict=True, clear_adapters=True):
-        # clear the adapter cache
-        if clear_adapters:
-            reset_adapters()
+    def run_dbt(self, args=None, expect_pass=True, strict=True):
         if args is None:
             args = ["run"]
 
@@ -463,7 +465,10 @@ def run_sql_bigquery(self, sql, fetch):
         else:
             return list(res)
 
-    def run_sql(self, query, fetch='None', kwargs=None):
+    def run_sql(self, query, fetch='None', kwargs=None, connection_name=None):
+        if connection_name is None:
+            connection_name = '__test'
+
         if query.strip() == "":
             return
 
@@ -471,7 +476,7 @@ def run_sql(self, query, fetch='None', kwargs=None):
         if self.adapter_type == 'bigquery':
             return self.run_sql_bigquery(sql, fetch)
 
-        conn = self.adapter.acquire_connection('__test')
+        conn = self.adapter.acquire_connection(connection_name)
         with conn.handle.cursor() as cursor:
             try:
                 cursor.execute(sql)
diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py
index b56813d84ce..b667cf38a6f 100644
--- a/test/unit/test_bigquery_adapter.py
+++ b/test/unit/test_bigquery_adapter.py
@@ -9,7 +9,7 @@
 import dbt.exceptions
 from dbt.logger import GLOBAL_LOGGER as logger  # noqa
 
-from .utils import config_from_parts_or_dicts
+from .utils import config_from_parts_or_dicts, inject_adapter
 
 
 def _bq_conn():
@@ -68,7 +68,9 @@ def get_adapter(self, target):
             project=project,
             profile=profile,
         )
-        return BigQueryAdapter(config)
+        adapter = BigQueryAdapter(config)
+        inject_adapter('bigquery', adapter)
+        return adapter
 
     @patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn())
diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py
index f989d2a167e..88f3f15e694 100644
--- a/test/unit/test_graph.py
+++ b/test/unit/test_graph.py
@@ -1,4 +1,4 @@
-from mock import MagicMock
+from mock import MagicMock, patch
 import os
 import six
 import unittest
@@ -10,6 +10,7 @@
 import dbt.linker
 import dbt.config
 import dbt.utils
+import dbt.loader
 
 try:
     from queue import Empty
@@ -27,21 +28,24 @@ class GraphTest(unittest.TestCase):
 
     def tearDown(self):
-        nx.write_gpickle = self.real_write_gpickle
-        dbt.utils.dependency_projects = self.real_dependency_projects
-        dbt.clients.system.find_matching = self.real_find_matching
-        dbt.clients.system.load_file_contents = self.real_load_file_contents
+        self.write_gpickle_patcher.stop()
+        self.load_projects_patcher.stop()
+        self.find_matching_patcher.stop()
+        self.load_file_contents_patcher.stop()
 
     def setUp(self):
         dbt.flags.STRICT_MODE = True
+        self.graph_result = None
+
+        self.write_gpickle_patcher = patch('networkx.write_gpickle')
+        self.load_projects_patcher = patch('dbt.loader._load_projects')
+        self.find_matching_patcher = patch('dbt.clients.system.find_matching')
+        self.load_file_contents_patcher = patch('dbt.clients.system.load_file_contents')
 
         def mock_write_gpickle(graph, outfile):
             self.graph_result = graph
-
-        self.real_write_gpickle = nx.write_gpickle
-        nx.write_gpickle = mock_write_gpickle
-
-        self.graph_result = None
+        self.mock_write_gpickle = self.write_gpickle_patcher.start()
+        self.mock_write_gpickle.side_effect = mock_write_gpickle
 
         self.profile = {
             'outputs': {
@@ -59,14 +63,13 @@ def mock_write_gpickle(graph, outfile):
             'target': 'test'
         }
 
-        self.real_dependency_projects = dbt.utils.dependency_projects
-        dbt.utils.dependency_projects = MagicMock(return_value=[])
+        self.mock_load_projects = self.load_projects_patcher.start()
+        self.mock_load_projects.return_value = []
 
         self.mock_models = []
         self.mock_content = {}
 
-        def mock_find_matching(root_path, relative_paths_to_search,
-                               file_pattern):
+        def mock_find_matching(root_path, relative_paths_to_search, file_pattern):
             if 'sql' not in file_pattern:
                 return []
 
@@ -77,16 +80,14 @@ def mock_find_matching(root_path, relative_paths_to_search,
 
             return to_return
 
-        self.real_find_matching = dbt.clients.system.find_matching
-        dbt.clients.system.find_matching = MagicMock(
-            side_effect=mock_find_matching)
+        self.mock_find_matching = self.find_matching_patcher.start()
+        self.mock_find_matching.side_effect = mock_find_matching
 
         def mock_load_file_contents(path):
             return self.mock_content[path]
 
-        self.real_load_file_contents = dbt.clients.system.load_file_contents
-        dbt.clients.system.load_file_contents = MagicMock(
-            side_effect=mock_load_file_contents)
+        self.mock_load_file_contents = self.load_file_contents_patcher.start()
+        self.mock_load_file_contents.side_effect = mock_load_file_contents
 
     def get_config(self, extra_cfg=None):
         if extra_cfg is None:
@@ -119,8 +120,10 @@ def test__single_model(self):
             'model_one': 'select * from events',
         })
 
-        compiler = self.get_compiler(self.get_config())
-        graph, linker = compiler.compile()
+        config = self.get_config()
+        manifest = dbt.loader.GraphLoader.load_all(config)
+        compiler = self.get_compiler(config)
+        linker = compiler.compile(manifest)
 
         self.assertEquals(
             linker.nodes(),
@@ -136,8 +139,10 @@ def test__two_models_simple_ref(self):
             'model_two': "select * from {{ref('model_one')}}",
         })
 
-        compiler = self.get_compiler(self.get_config())
-        graph, linker = compiler.compile()
+        config = self.get_config()
+        manifest = dbt.loader.GraphLoader.load_all(config)
+        compiler = self.get_compiler(config)
+        linker = compiler.compile(manifest)
 
         six.assertCountEqual(self,
                              linker.nodes(),
@@ -170,8 +175,10 @@ def test__model_materializations(self):
             }
         }
 
-        compiler = self.get_compiler(self.get_config(cfg))
-        manifest, linker = compiler.compile()
+        config = self.get_config(cfg)
+        manifest = dbt.loader.GraphLoader.load_all(config)
+        compiler = self.get_compiler(config)
+        linker = compiler.compile(manifest)
 
         expected_materialization = {
             "model_one": "table",
@@ -205,9 +212,10 @@ def test__model_incremental(self):
             }
         }
 
-
-        compiler = self.get_compiler(self.get_config(cfg))
-        manifest, linker = compiler.compile()
+        config = self.get_config(cfg)
+        manifest = dbt.loader.GraphLoader.load_all(config)
+        compiler = self.get_compiler(config)
+        linker = compiler.compile(manifest)
 
         node = 'model.test_models_compile.model_one'
 
@@ -231,8 +239,10 @@ def test__dependency_list(self):
             'model_4': 'select * from {{ ref("model_3") }}'
         })
 
-        compiler = self.get_compiler(self.get_config({}))
-        graph, linker = compiler.compile()
+        config = self.get_config()
+        graph = dbt.loader.GraphLoader.load_all(config)
+        compiler = self.get_compiler(config)
+        linker = compiler.compile(graph)
 
         models = ('model_1', 'model_2', 'model_3', 'model_4')
         model_ids = ['model.test_models_compile.{}'.format(m) for m in models]
diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py
index f1e404ae5dd..1677d9e2140 100644
--- a/test/unit/test_postgres_adapter.py
+++ b/test/unit/test_postgres_adapter.py
@@ -10,7 +10,7 @@
 from psycopg2 import extensions as psycopg2_extensions
 import agate
 
-from .utils import config_from_parts_or_dicts
+from .utils import config_from_parts_or_dicts, inject_adapter
 
 
 class TestPostgresAdapter(unittest.TestCase):
@@ -45,6 +45,7 @@ def setUp(self):
     def adapter(self):
         if self._adapter is None:
             self._adapter = PostgresAdapter(self.config)
+            inject_adapter('postgres', self._adapter)
         return self._adapter
 
     def test_acquire_connection_validations(self):
@@ -196,6 +197,7 @@ def setUp(self):
         self.psycopg2.connect.return_value = self.handle
 
         self.adapter = PostgresAdapter(self.config)
+        inject_adapter('postgres', self.adapter)
 
     def tearDown(self):
         # we want a unique self.handle every time.
diff --git a/test/unit/test_schema.py b/test/unit/test_schema.py
index 6c065a422d7..ec411568331 100644
--- a/test/unit/test_schema.py
+++ b/test/unit/test_schema.py
@@ -1,5 +1,6 @@
 import unittest
 
+import decimal
 import dbt.schema
 
 class TestStringType(unittest.TestCase):
@@ -20,7 +21,8 @@ def test__numeric_type(self):
         col = dbt.schema.Column(
             'fieldname',
             'numeric',
-            numeric_size='12,2')
+            numeric_precision=decimal.Decimal('12'),
+            numeric_scale=decimal.Decimal('2'))
 
         self.assertEqual(col.data_type, 'numeric(12,2)')
 
@@ -29,6 +31,6 @@ def test__numeric_type_with_no_precision(self):
         col = dbt.schema.Column(
             'fieldname',
             'numeric',
-            numeric_size=None)
+            numeric_precision=None)
 
         self.assertEqual(col.data_type, 'numeric')
diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py
index 1df8f11eb1d..ad13dc742e5 100644
--- a/test/unit/test_snowflake_adapter.py
+++ b/test/unit/test_snowflake_adapter.py
@@ -9,7 +9,7 @@
 from dbt.logger import GLOBAL_LOGGER as logger  # noqa
 from snowflake import connector as snowflake_connector
 
-from .utils import config_from_parts_or_dicts
+from .utils import config_from_parts_or_dicts, inject_adapter
 
 
 class TestSnowflakeAdapter(unittest.TestCase):
     def setUp(self):
@@ -50,6 +50,8 @@ def setUp(self):
         self.snowflake.return_value = self.handle
 
         self.adapter = SnowflakeAdapter(self.config)
+        # patch our new adapter into the factory so macros behave
+        inject_adapter('snowflake', self.adapter)
 
     def tearDown(self):
         # we want a unique self.handle every time.
@@ -75,6 +77,7 @@ def test_quoting_on_drop(self):
             quote_policy=self.adapter.config.quoting,
         )
         self.adapter.drop_relation(relation)
+
         self.mock_execute.assert_has_calls([
             mock.call('drop table if exists "test_database"."test_schema".test_table cascade', None)
         ])
@@ -88,6 +91,7 @@ def test_quoting_on_truncate(self):
             quote_policy=self.adapter.config.quoting,
         )
         self.adapter.truncate_relation(relation)
+
        self.mock_execute.assert_has_calls([
            mock.call('truncate table "test_database"."test_schema".test_table', None)
        ])
diff --git a/test/unit/utils.py b/test/unit/utils.py
index 21281d97824..7e977763dbf 100644
--- a/test/unit/utils.py
+++ b/test/unit/utils.py
@@ -27,3 +27,12 @@ def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'):
         profile=profile,
         args=args
     )
+
+
+def inject_adapter(key, value):
+    """Inject the given adapter into the adapter factory, so your hand-crafted
+    artisanal adapter will be available from get_adapter() as if dbt loaded it.
+    """
+    from dbt.adapters import factory
+    factory._ADAPTERS[key] = value
+    factory.ADAPTER_TYPES[key] = type(value)
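
Note for reviewers: the inject_adapter() helper added above is what lets a unit test hand a manually constructed adapter to anything that resolves adapters through dbt.adapters.factory.get_adapter(), such as macro execution against the internal manifest. The sketch below shows how a test might use it; the test class, config dicts, and assertion are illustrative only and are not part of this patch, while config_from_parts_or_dicts, inject_adapter, and get_adapter are the helpers shown in the diffs above.

    # Hypothetical usage sketch of the inject_adapter() test helper.
    import unittest

    from dbt.adapters import factory
    from dbt.adapters.postgres import PostgresAdapter

    from .utils import config_from_parts_or_dicts, inject_adapter


    class TestInjectedAdapter(unittest.TestCase):
        def setUp(self):
            # Profile/project dicts modelled on the existing unit tests.
            profile = {
                'outputs': {
                    'test': {
                        'type': 'postgres',
                        'dbname': 'dbt',
                        'user': 'root',
                        'host': 'localhost',
                        'pass': 'password',
                        'port': 5432,
                        'schema': 'public',
                    },
                },
                'target': 'test',
            }
            project = {
                'name': 'X',
                'version': '0.1',
                'profile': 'test',
                'project-root': '/tmp/dbt/does-not-exist',
            }
            self.config = config_from_parts_or_dicts(project, profile)
            self.adapter = PostgresAdapter(self.config)
            # Register the hand-built adapter with the factory so code that
            # calls get_adapter(config) sees this same instance.
            inject_adapter('postgres', self.adapter)

        def test_factory_returns_injected_adapter(self):
            self.assertIs(factory.get_adapter(self.config), self.adapter)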