From 49f7cf8ecab4bd74b9b67958df4ace56353b1bed Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 24 Jun 2019 09:39:06 -0400 Subject: [PATCH] Convert dbt to use dataclasses and hologram for representing things Most of the things that previously used manually created jsonschemas Split tests into their own node type Change tests to reflect that tables require a freshness block add a lot more debug-logging on exceptions Make things that get passed to Var() tell it about their vars finally make .empty a property documentation resource type is now a property, not serialized added a Mergeable helper mixin to perform simple merges Convert some oneOf checks into if-else chains to get better errors Add more tests Use "Any" as value in type defs - accept the warning from hologram for now, PR out to suppress it set default values for enabled/materialized Clean up the Parsed/Compiled type hierarchy Allow generic snapshot definitions remove the "graph" entry in the context - This improves performance on large projects significantly Update changelog to reflect removing graph --- CHANGELOG.md | 5 + core/dbt/adapters/base/connections.py | 58 +- core/dbt/adapters/base/relation.py | 8 +- core/dbt/adapters/sql/connections.py | 2 +- core/dbt/clients/jinja.py | 18 +- core/dbt/compilation.py | 31 +- core/dbt/config/__init__.py | 2 +- core/dbt/config/profile.py | 89 +- core/dbt/config/project.py | 33 +- core/dbt/config/runtime.py | 10 +- core/dbt/context/common.py | 55 +- core/dbt/context/parser.py | 17 +- core/dbt/context/runtime.py | 22 +- core/dbt/contracts/connection.py | 104 +- core/dbt/contracts/graph/compiled.py | 272 +-- core/dbt/contracts/graph/manifest.py | 277 ++- core/dbt/contracts/graph/parsed.py | 1057 ++++------- core/dbt/contracts/graph/unparsed.py | 545 ++---- core/dbt/contracts/project.py | 507 ++--- core/dbt/contracts/results.py | 613 ++---- core/dbt/contracts/util.py | 29 + core/dbt/exceptions.py | 71 +- core/dbt/graph/selector.py | 8 +- core/dbt/hooks.py | 49 +- core/dbt/linker.py | 2 +- core/dbt/loader.py | 11 +- core/dbt/main.py | 8 +- core/dbt/node_runners.py | 79 +- core/dbt/node_types.py | 49 +- core/dbt/parser/base.py | 46 +- core/dbt/parser/base_sql.py | 30 +- core/dbt/parser/docs.py | 6 +- core/dbt/parser/hooks.py | 11 +- core/dbt/parser/macros.py | 10 +- core/dbt/parser/schemas.py | 103 +- core/dbt/parser/seeds.py | 2 - core/dbt/parser/snapshots.py | 43 +- core/dbt/parser/util.py | 58 +- core/dbt/semver.py | 57 +- core/dbt/task/compile.py | 3 +- core/dbt/task/deps.py | 748 ++++---- core/dbt/task/list.py | 4 +- core/dbt/task/run.py | 13 +- core/dbt/task/runnable.py | 8 +- core/dbt/types.py | 17 + core/dbt/ui/printer.py | 30 +- core/dbt/utils.py | 26 +- core/dbt/writer.py | 9 +- core/setup.py | 2 + .../dbt/adapters/bigquery/connections.py | 58 +- .../dbt/adapters/postgres/connections.py | 61 +- .../dbt/adapters/redshift/connections.py | 91 +- .../dbt/adapters/snowflake/connections.py | 128 +- .../dbt/adapters/snowflake/relation.py | 8 - .../test-check-col-snapshots-bq/snapshot.sql | 4 +- .../test_graph_selection.py | 4 +- .../test_schema_test_graph_selection.py | 2 +- .../test_tag_selection.py | 8 +- .../test_schema_v2_tests.py | 14 +- .../009_data_tests_test/test_data_tests.py | 4 +- .../test_invalid_models.py | 20 +- .../test_cli_invocation.py | 2 +- .../test_simple_bigquery_view.py | 2 +- .../test_docs_generate.py | 164 +- .../test_changing_relation_type.py | 26 +- .../test_view_binding_dependency.py | 8 +- .../test_simple_presto_view.py | 2 +- 
.../042_sources_test/models/schema.yml | 3 +- .../042_sources_test/test_sources.py | 14 +- test/unit/test_bigquery_adapter.py | 4 +- test/unit/test_compiler.py | 176 +- test/unit/test_config.py | 19 +- test/unit/test_context.py | 12 +- test/unit/test_contracts_graph_compiled.py | 395 ++++ test/unit/test_contracts_graph_parsed.py | 1678 +++++++++++++++++ test/unit/test_contracts_graph_unparsed.py | 447 +++++ test/unit/test_deps.py | 403 ++-- test/unit/test_docs_blocks.py | 1 - test/unit/test_graph.py | 14 +- test/unit/test_manifest.py | 192 +- test/unit/test_parser.py | 640 +++---- test/unit/test_postgres_adapter.py | 16 +- test/unit/test_redshift_adapter.py | 38 +- test/unit/test_snowflake_adapter.py | 27 +- test/unit/utils.py | 30 + 85 files changed, 5556 insertions(+), 4416 deletions(-) create mode 100644 core/dbt/contracts/util.py create mode 100644 core/dbt/types.py create mode 100644 test/unit/test_contracts_graph_compiled.py create mode 100644 test/unit/test_contracts_graph_parsed.py create mode 100644 test/unit/test_contracts_graph_unparsed.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 08dcbe0e13d..1c502edacb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## dbt ? - Louisa May Alcott + +### Breaking changes + - the undocumented "graph" variable was removed from the parsing context ([#1589](https://github.com/fishtown-analytics/dbt/pull/1589)) + ## dbt 0.14.0 - Wilt Chamberlain (July 10, 2019) ### Overview diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index 0c5522a5f55..b8ed37c3048 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -5,23 +5,26 @@ import dbt.exceptions import dbt.flags -from dbt.api import APIObject from dbt.contracts.connection import Connection +from dbt.contracts.util import Replaceable from dbt.logger import GLOBAL_LOGGER as logger from dbt.utils import translate_aliases +from hologram.helpers import ExtensibleJsonSchemaMixin -class Credentials(APIObject): - """Common base class for credentials. This is not valid to instantiate""" - SCHEMA = NotImplemented - # map credential aliases to their canonical names. - ALIASES = {} +from dataclasses import dataclass, field +from typing import Any, ClassVar, Dict, Tuple - def __init__(self, **kwargs): - renamed = self.translate_aliases(kwargs) - super().__init__(**renamed) - @property +@dataclass +class Credentials( + ExtensibleJsonSchemaMixin, + Replaceable, + metaclass=abc.ABCMeta +): + _ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False) + + @abc.abstractproperty def type(self): raise NotImplementedError( 'type not implemented for base credentials class' @@ -30,37 +33,34 @@ def type(self): def connection_info(self): """Return an ordered iterator of key/value pairs for pretty-printing. """ + as_dict = self.to_dict() for key in self._connection_keys(): - if key in self._contents: - yield key, self._contents[key] + if key in as_dict: + yield key, as_dict[key] - def _connection_keys(self): - """The credential object keys that should be printed to users in - 'dbt debug' output. This is specific to each adapter. - """ + @abc.abstractmethod + def _connection_keys(self) -> Tuple[str, ...]: raise NotImplementedError - def incorporate(self, **kwargs): - # implementation note: we have to do this here, or - # incorporate(alias_name=...) will result in duplicate keys in the - # merged dict that APIObject.incorporate() creates. 
- renamed = self.translate_aliases(kwargs) - return super().incorporate(**renamed) + @classmethod + def from_dict(cls, data): + data = cls.translate_aliases(data) + return super().from_dict(data) + + @classmethod + def translate_aliases(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: + return translate_aliases(kwargs, cls._ALIASES) - def serialize(self, with_aliases=False): - serialized = super().serialize() + def to_dict(self, omit_none=True, validate=False, with_aliases=False): + serialized = super().to_dict(omit_none=omit_none, validate=validate) if with_aliases: serialized.update({ new_name: serialized[canonical_name] - for new_name, canonical_name in self.ALIASES.items() + for new_name, canonical_name in self._ALIASES.items() if canonical_name in serialized }) return serialized - @classmethod - def translate_aliases(cls, kwargs): - return translate_aliases(kwargs, cls.ALIASES) - class BaseConnectionManager(metaclass=abc.ABCMeta): """Methods to implement: diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index 5ee8012dfb9..84cd6bc7ecc 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -220,7 +220,7 @@ def quoted(self, identifier): def create_from_source(cls, source, **kwargs): quote_policy = dbt.utils.deep_merge( cls.DEFAULTS['quote_policy'], - source.quoting, + source.quoting.to_dict(), kwargs.get('quote_policy', {}) ) return cls.create( @@ -240,9 +240,9 @@ def create_from_node(cls, config, node, table_name=None, quote_policy=None, quote_policy = dbt.utils.merge(config.quoting, quote_policy) return cls.create( - database=node.get('database'), - schema=node.get('schema'), - identifier=node.get('alias'), + database=node.database, + schema=node.schema, + identifier=node.alias, table_name=table_name, quote_policy=quote_policy, **kwargs) diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py index 042239ad504..0927530ee90 100644 --- a/core/dbt/adapters/sql/connections.py +++ b/core/dbt/adapters/sql/connections.py @@ -112,7 +112,7 @@ def begin(self): if connection.transaction_open is True: raise dbt.exceptions.InternalException( 'Tried to begin a new transaction on connection "{}", but ' - 'it already had one open!'.format(connection.get('name'))) + 'it already had one open!'.format(connection.name)) self.add_begin_query() diff --git a/core/dbt/clients/jinja.py b/core/dbt/clients/jinja.py index a64d5486a06..4433726c6d6 100644 --- a/core/dbt/clients/jinja.py +++ b/core/dbt/clients/jinja.py @@ -68,15 +68,15 @@ def __init__(self): self.file_cache = {} def get_node_template(self, node): - key = (node['package_name'], node['original_file_path']) + key = (node.package_name, node.original_file_path) if key in self.file_cache: return self.file_cache[key] template = get_template( - string=node.get('raw_sql'), + string=node.raw_sql, ctx={}, - node=node + node=node, ) self.file_cache[key] = template @@ -92,7 +92,7 @@ def clear(self): def macro_generator(node): def apply_context(context): def call(*args, **kwargs): - name = node.get('name') + name = node.name template = template_cache.get_node_template(node) module = template.make_module(context, False, context) @@ -178,21 +178,21 @@ def __init__(self, hint=None, obj=None, name=None, exc=None): super().__init__(hint=hint, name=name) self.node = node self.name = name - self.package_name = node.get('package_name') + self.package_name = node.package_name # jinja uses these for safety, so we have to override them. 
# see https://github.com/pallets/jinja/blob/master/jinja2/sandbox.py#L332-L339 # noqa self.unsafe_callable = False self.alters_data = False def __deepcopy__(self, memo): - path = os.path.join(self.node.get('root_path'), - self.node.get('original_file_path')) + path = os.path.join(self.node.root_path, + self.node.original_file_path) logger.debug( 'dbt encountered an undefined variable, "{}" in node {}.{} ' '(source path: {})' - .format(self.name, self.node.get('package_name'), - self.node.get('name'), path)) + .format(self.name, self.node.package_name, + self.node.name, path)) # match jinja's message dbt.exceptions.raise_compiler_error( diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 62111891d5f..3ee95f94b92 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -15,13 +15,21 @@ import dbt.flags import dbt.loader import dbt.config -from dbt.contracts.graph.compiled import CompiledNode +from dbt.contracts.graph.compiled import InjectedCTE, CompiledNode, \ + CompiledTestNode from dbt.logger import GLOBAL_LOGGER as logger graph_file_name = 'graph.gpickle' +def _compiled_type_for(model): + if model.resource_type == NodeType.Test: + return CompiledTestNode + else: + return CompiledNode + + def print_compile_stats(stats): names = { NodeType.Model: 'models', @@ -44,9 +52,9 @@ def print_compile_stats(stats): def _add_prepended_cte(prepended_ctes, new_cte): - for dct in prepended_ctes: - if dct['id'] == new_cte['id']: - dct['sql'] = new_cte['sql'] + for cte in prepended_ctes: + if cte.id == new_cte.id: + cte.sql = new_cte.sql return prepended_ctes.append(new_cte) @@ -67,20 +75,19 @@ def recursively_prepend_ctes(model, manifest): return (model, model.extra_ctes, manifest) if dbt.flags.STRICT_MODE: - # ensure that the cte we're adding to is compiled - CompiledNode(**model.serialize()) + assert isinstance(model, (CompiledNode, CompiledTestNode)) prepended_ctes = [] for cte in model.extra_ctes: - cte_id = cte['id'] + cte_id = cte.id cte_to_add = manifest.nodes.get(cte_id) cte_to_add, new_prepended_ctes, manifest = recursively_prepend_ctes( cte_to_add, manifest) _extend_prepended_ctes(prepended_ctes, new_prepended_ctes) - new_cte_name = '__dbt__CTE__{}'.format(cte_to_add.get('name')) + new_cte_name = '__dbt__CTE__{}'.format(cte_to_add.name) sql = ' {} as (\n{}\n)'.format(new_cte_name, cte_to_add.compiled_sql) - _add_prepended_cte(prepended_ctes, {'id': cte_id, 'sql': sql}) + _add_prepended_cte(prepended_ctes, InjectedCTE(id=cte_id, sql=sql)) model.prepend_ctes(prepended_ctes) @@ -101,7 +108,7 @@ def compile_node(self, node, manifest, extra_context=None): if extra_context is None: extra_context = {} - logger.debug("Compiling {}".format(node.get('unique_id'))) + logger.debug("Compiling {}".format(node.unique_id)) data = node.to_dict() data.update({ @@ -111,14 +118,14 @@ def compile_node(self, node, manifest, extra_context=None): 'extra_ctes': [], 'injected_sql': None, }) - compiled_node = CompiledNode(**data) + compiled_node = _compiled_type_for(node).from_dict(data) context = dbt.context.runtime.generate( compiled_node, self.config, manifest) context.update(extra_context) compiled_node.compiled_sql = dbt.clients.jinja.get_rendered( - node.get('raw_sql'), + node.raw_sql, context, node) diff --git a/core/dbt/config/__init__.py b/core/dbt/config/__init__.py index d20916525ee..09b5523dae1 100644 --- a/core/dbt/config/__init__.py +++ b/core/dbt/config/__init__.py @@ -1,5 +1,5 @@ # all these are just exports, they need "noqa" so flake8 will not complain. 
from .renderer import ConfigRenderer # noqa -from .profile import Profile, UserConfig, PROFILES_DIR # noqa +from .profile import Profile, PROFILES_DIR, read_user_config # noqa from .project import Project # noqa from .runtime import RuntimeConfig # noqa diff --git a/core/dbt/config/profile.py b/core/dbt/config/profile.py index 79975fc859f..2a33c2304cc 100644 --- a/core/dbt/config/profile.py +++ b/core/dbt/config/profile.py @@ -1,24 +1,23 @@ import os import pprint +from hologram import ValidationError + from dbt.adapters.factory import load_plugin from dbt.clients.system import load_file_contents from dbt.clients.yaml_helper import load_yaml_text -from dbt.contracts.project import ProfileConfig +from dbt.contracts.project import ProfileConfig, UserConfig from dbt.exceptions import DbtProfileError from dbt.exceptions import DbtProjectError from dbt.exceptions import ValidationException from dbt.exceptions import RuntimeException +from dbt.exceptions import validator_error_message from dbt.logger import GLOBAL_LOGGER as logger from dbt.utils import parse_cli_vars -from dbt import tracking -from dbt.ui import printer from .renderer import ConfigRenderer DEFAULT_THREADS = 1 -DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True -DEFAULT_USE_COLORS = True DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser('~'), '.dbt') PROFILES_DIR = os.path.expanduser( os.environ.get('DBT_PROFILES_DIR', DEFAULT_PROFILES_DIR) @@ -55,59 +54,20 @@ def read_profile(profiles_dir): return load_yaml_text(contents) except ValidationException as e: msg = INVALID_PROFILE_MESSAGE.format(error_string=e) - raise ValidationException(msg) + raise ValidationException(msg) from e return {} -class UserConfig: - def __init__(self, send_anonymous_usage_stats, use_colors, printer_width): - self.send_anonymous_usage_stats = send_anonymous_usage_stats - self.use_colors = use_colors - self.printer_width = printer_width - - @classmethod - def from_dict(cls, cfg=None): - if cfg is None: - cfg = {} - send_anonymous_usage_stats = cfg.get( - 'send_anonymous_usage_stats', - DEFAULT_SEND_ANONYMOUS_USAGE_STATS - ) - use_colors = cfg.get( - 'use_colors', - DEFAULT_USE_COLORS - ) - printer_width = cfg.get( - 'printer_width' - ) - return cls(send_anonymous_usage_stats, use_colors, printer_width) - - def to_dict(self): - return { - 'send_anonymous_usage_stats': self.send_anonymous_usage_stats, - 'use_colors': self.use_colors, - } - - @classmethod - def from_directory(cls, directory): +def read_user_config(directory): + try: user_cfg = None profile = read_profile(directory) if profile: user_cfg = profile.get('config', {}) - return cls.from_dict(user_cfg) - - def set_values(self, cookie_dir): - if self.send_anonymous_usage_stats: - tracking.initialize_tracking(cookie_dir) - else: - tracking.do_not_track() - - if self.use_colors: - printer.use_colors() - - if self.printer_width: - printer.printer_width(self.printer_width) + return UserConfig.from_dict(user_cfg) + except (RuntimeException, ValidationError): + return UserConfig() class Profile: @@ -135,10 +95,10 @@ def to_profile_info(self, serialize_credentials=False): 'target_name': self.target_name, 'config': self.config.to_dict(), 'threads': self.threads, - 'credentials': self.credentials.incorporate(), + 'credentials': self.credentials, } if serialize_credentials: - result['credentials'] = result['credentials'].serialize() + result['credentials'] = result['credentials'].to_dict() return result def __str__(self): @@ -152,12 +112,14 @@ def __eq__(self, other): return self.to_profile_info() == 
other.to_profile_info() def validate(self): - if self.credentials: - self.credentials.validate() try: - ProfileConfig(**self.to_profile_info(serialize_credentials=True)) - except ValidationException as exc: - raise DbtProfileError(str(exc)) + if self.credentials: + self.credentials.to_dict(validate=True) + ProfileConfig.from_dict( + self.to_profile_info(serialize_credentials=True) + ) + except ValidationError as exc: + raise DbtProfileError(validator_error_message(exc)) from exc @staticmethod def _credentials_from_profile(profile, profile_name, target_name): @@ -172,12 +134,14 @@ def _credentials_from_profile(profile, profile_name, target_name): try: cls = load_plugin(typename) - credentials = cls(**profile) - except RuntimeException as e: + credentials = cls.from_dict(profile) + except (RuntimeException, ValidationError) as e: + msg = str(e) if isinstance(e, RuntimeException) else e.message raise DbtProfileError( 'Credentials in profile "{}", target "{}" invalid: {}' - .format(profile_name, target_name, str(e)) - ) + .format(profile_name, target_name, msg) + ) from e + return credentials @staticmethod @@ -222,7 +186,10 @@ def from_credentials(cls, credentials, threads, profile_name, target_name, :raises DbtProfileError: If the profile is invalid. :returns Profile: The new Profile object. """ + if user_cfg is None: + user_cfg = {} config = UserConfig.from_dict(user_cfg) + profile = cls( profile_name=profile_name, target_name=target_name, diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 20feeb6b198..0c911c2aaec 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -10,7 +10,7 @@ from dbt.exceptions import DbtProjectError from dbt.exceptions import RecursionException from dbt.exceptions import SemverException -from dbt.exceptions import ValidationException +from dbt.exceptions import validator_error_message from dbt.exceptions import warn_or_error from dbt.semver import VersionSpecifier from dbt.semver import versions_compatible @@ -23,6 +23,8 @@ from dbt.contracts.project import Project as ProjectContract from dbt.contracts.project import PackageConfig +from hologram import ValidationError + from .renderer import ConfigRenderer @@ -118,9 +120,11 @@ def package_config_from_data(packages_data): packages_data = {'packages': []} try: - packages = PackageConfig(**packages_data) - except ValidationException as e: - raise DbtProjectError('Invalid package config: {}'.format(str(e))) + packages = PackageConfig.from_dict(packages_data) + except ValidationError as e: + raise DbtProjectError( + 'Invalid package config: {}'.format(validator_error_message(e)) + ) from e return packages @@ -215,9 +219,9 @@ def from_project_config(cls, project_dict, packages_dict=None): ) # just for validation. 
try: - ProjectContract(**project_dict) - except ValidationException as e: - raise DbtProjectError(str(e)) + ProjectContract.from_dict(project_dict) + except ValidationError as e: + raise DbtProjectError(validator_error_message(e)) from e # name/version are required in the Project definition, so we can assume # they are present @@ -254,9 +258,12 @@ def from_project_config(cls, project_dict, packages_dict=None): try: dbt_version = _parse_versions(dbt_raw_version) except SemverException as e: - raise DbtProjectError(str(e)) + raise DbtProjectError(str(e)) from e - packages = package_config_from_data(packages_dict) + try: + packages = package_config_from_data(packages_dict) + except ValidationError as e: + raise DbtProjectError(validator_error_message(e)) from e project = cls( project_name=name, @@ -330,14 +337,14 @@ def to_project_config(self, with_packages=False): ], }) if with_packages: - result.update(self.packages.serialize()) + result.update(self.packages.to_dict()) return result def validate(self): try: - ProjectContract(**self.to_project_config()) - except ValidationException as exc: - raise DbtProjectError(str(exc)) + ProjectContract.from_dict(self.to_project_config()) + except ValidationError as e: + raise DbtProjectError(validator_error_message(e)) from e @classmethod def from_project_root(cls, project_root, cli_vars): diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index 2cfce1993d3..2ca8c31db05 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -5,12 +5,14 @@ from dbt.utils import parse_cli_vars from dbt.contracts.project import Configuration from dbt.exceptions import DbtProjectError -from dbt.exceptions import ValidationException +from dbt.exceptions import validator_error_message from dbt.adapters.factory import get_relation_class_by_name from .profile import Profile from .project import Project +from hologram import ValidationError + class RuntimeConfig(Project, Profile): """The runtime configuration, as constructed from its components. There's a @@ -154,9 +156,9 @@ def validate(self): :raises DbtProjectError: If the configuration fails validation. 
""" try: - Configuration(**self.serialize()) - except ValidationException as e: - raise DbtProjectError(str(e)) + Configuration.from_dict(self.serialize()) + except ValidationError as e: + raise DbtProjectError(validator_error_message(e)) from e if getattr(self.args, 'version_check', False): self.validate_version() diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py index b09245108b5..c4062be27d8 100644 --- a/core/dbt/context/common.py +++ b/core/dbt/context/common.py @@ -3,7 +3,6 @@ from dbt.adapters.factory import get_adapter from dbt.node_types import NodeType -from dbt.contracts.graph.parsed import ParsedMacro, ParsedNode from dbt.include.global_project import PACKAGES from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME @@ -225,23 +224,13 @@ def __init__(self, model, context, overrides): # precedence over context-based var definitions self.overrides = overrides - if isinstance(model, dict) and model.get('unique_id'): - local_vars = model.get('config', {}).get('vars', {}) - self.model_name = model.get('name') - elif isinstance(model, ParsedMacro): - local_vars = {} # macros have no config - self.model_name = model.name - elif isinstance(model, ParsedNode): - local_vars = model.config.get('vars', {}) - self.model_name = model.name - elif model is None: + if model is None: # during config parsing we have no model and no local vars self.model_name = '' local_vars = {} else: - # still used for wrapping - self.model_name = model.nice_name - local_vars = model.config.get('vars', {}) + self.model_name = model.name + local_vars = model.local_vars() self.local_vars = dbt.utils.merge(local_vars, overrides) @@ -278,7 +267,7 @@ def __call__(self, var_name, default=_VAR_NOTSET): def write(node, target_path, subdirectory): def fn(payload): - node['build_path'] = dbt.writer.write_node( + node.build_path = dbt.writer.write_node( node, target_path, subdirectory, payload) return '' @@ -383,9 +372,10 @@ def generate_base(model, model_dict, config, manifest, source_config, target_name = config.target_name target = config.to_profile_info() del target['credentials'] - target.update(config.credentials.serialize(with_aliases=True)) + target.update(config.credentials.to_dict(with_aliases=True)) target['type'] = config.credentials.type target.pop('pass', None) + target.pop('password', None) target['name'] = target_name adapter = get_adapter(config) @@ -404,14 +394,12 @@ def generate_base(model, model_dict, config, manifest, source_config, "Column": adapter.Column, }, "column": adapter.Column, - "config": provider.Config(model_dict, source_config), + "config": provider.Config(model, source_config), "database": config.credentials.database, "env_var": env_var, "exceptions": dbt.exceptions.wrapped_exports(model), "execute": provider.execute, "flags": dbt.flags, - # TODO: Do we have to leave this in? 
-        "graph": manifest.to_flat_graph(),
         "load_agate_table": _build_load_agate_table(model),
         "log": log,
         "model": model_dict,
@@ -433,8 +421,7 @@ def generate_base(model, model_dict, config, manifest, source_config,
     return context
 
 
-def modify_generated_context(context, model, model_dict, config, manifest,
-                             provider):
+def modify_generated_context(context, model, config, manifest, provider):
     cli_var_overrides = config.cli_vars
 
     context = _add_tracking(context)
@@ -445,8 +432,8 @@ def modify_generated_context(context, model, model_dict, config, manifest,
 
     context = _add_macros(context, model, manifest)
 
-    context["write"] = write(model_dict, config.target_path, 'run')
-    context["render"] = render(context, model_dict)
+    context["write"] = write(model, config.target_path, 'run')
+    context["render"] = render(context, model)
     context["var"] = provider.Var(model, context=context,
                                   overrides=cli_var_overrides)
     context['context'] = context
@@ -462,12 +449,11 @@ def generate_execute_macro(model, config, manifest, provider):
     - 'schema' does not use any 'model' information
     - they can't be configured with config() directives
     """
-    model_dict = model.serialize()
+    model_dict = model.to_dict()
     context = generate_base(model, model_dict, config, manifest, None,
                             provider)
 
-    return modify_generated_context(context, model, model_dict, config,
-                                    manifest, provider)
+    return modify_generated_context(context, model, config, manifest, provider)
 
 
 def generate_model(model, config, manifest, source_config, provider):
@@ -476,19 +462,20 @@ def generate_model(model, config, manifest, source_config, provider):
                             source_config, provider)
     # operations (hooks) don't get a 'this'
     if model.resource_type != NodeType.Operation:
-        this = get_this_relation(context['adapter'], config, model_dict)
+        this = get_this_relation(context['adapter'], config, model)
         context['this'] = this
     # overwrite schema/database if we have them, and hooks + sql
+    # the hooks should come in as dicts, at least for the `run_hooks` macro
+    # TODO: do we have to preserve this as a backwards compatibility thing?
context.update({ - 'schema': model.get('schema', context['schema']), - 'database': model.get('database', context['database']), - 'pre_hooks': model.config.get('pre-hook'), - 'post_hooks': model.config.get('post-hook'), - 'sql': model.get('injected_sql'), + 'schema': getattr(model, 'schema', context['schema']), + 'database': getattr(model, 'database', context['database']), + 'pre_hooks': [h.to_dict() for h in model.config.pre_hook], + 'post_hooks': [h.to_dict() for h in model.config.post_hook], + 'sql': getattr(model, 'injected_sql', None), }) - return modify_generated_context(context, model, model_dict, config, - manifest, provider) + return modify_generated_context(context, model, config, manifest, provider) def generate(model, config, manifest, source_config=None, provider=None): diff --git a/core/dbt/context/parser.py b/core/dbt/context/parser.py index 82508dac9d2..39b860ee39a 100644 --- a/core/dbt/context/parser.py +++ b/core/dbt/context/parser.py @@ -2,6 +2,7 @@ import dbt.context.common from dbt.adapters.factory import get_adapter +from dbt.contracts.graph.parsed import Docref def docs(unparsed, docrefs, column_name=None): @@ -14,17 +15,13 @@ def do_docs(*args): if len(args) == 2: doc_package_name = args[1] - docref = { - 'documentation_package': doc_package_name, - 'documentation_name': doc_name, - } - if column_name is not None: - docref['column_name'] = column_name - + docref = Docref(documentation_package=doc_package_name, + documentation_name=doc_name, + column_name=column_name) docrefs.append(docref) - # IDK - return True + # At parse time, nothing should care about what doc() returns + return '' return do_docs @@ -131,7 +128,7 @@ def generate(model, runtime_config, manifest, source_config): # have to acquire it. # In the future, it would be nice to lazily open the connection, as in some # projects it would be possible to parse without connecting to the db - with get_adapter(runtime_config).connection_named(model.get('name')): + with get_adapter(runtime_config).connection_named(model.name): return dbt.context.common.generate( model, runtime_config, manifest, source_config, Provider() ) diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py index 0c5f1765311..065a99a1c58 100644 --- a/core/dbt/context/runtime.py +++ b/core/dbt/context/runtime.py @@ -50,7 +50,7 @@ def create_relation(self, target_model, name): class RefResolver(BaseRefResolver): def validate(self, resolved, args): - if resolved.unique_id not in self.model.depends_on.get('nodes'): + if resolved.unique_id not in self.model.depends_on.nodes: dbt.exceptions.ref_bad_context(self.model, args) def __call__(self, *args): @@ -83,6 +83,9 @@ def __call__(self, source_name, table_name): return self.Relation.create_from_source(target_source) +_MISSING = object() + + class Config: def __init__(self, model, source_config=None): self.model = model @@ -97,11 +100,20 @@ def set(self, name, value): def _validate(self, validator, value): validator(value) - def require(self, name, validator=None): - if name not in self.model['config']: + def _lookup(self, name, default=_MISSING): + config = self.model.config + + if hasattr(config, name): + return getattr(config, name) + elif name in config.extra: + return config.extra[name] + elif default is not _MISSING: + return default + else: dbt.exceptions.missing_config(self.model, name) - to_return = self.model['config'][name] + def require(self, name, validator=None): + to_return = self._lookup(name) if validator is not None: self._validate(validator, to_return) @@ -109,7 
+121,7 @@ def require(self, name, validator=None): return to_return def get(self, name, validator=None, default=None): - to_return = self.model['config'].get(name, default) + to_return = self._lookup(name, default) if validator is not None and default is not None: self._validate(validator, to_return) diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index 866e6bf5dbb..b1ca4ddbc49 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -1,59 +1,54 @@ -from dbt.api.object import APIObject -from dbt.contracts.common import named_property - - -CONNECTION_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'type': { - 'type': 'string', - # valid python identifiers only - 'pattern': r'^[A-Za-z_][A-Za-z0-9_]+$', - }, - 'name': { - 'type': ['null', 'string'], - }, - 'state': { - 'enum': ['init', 'open', 'closed', 'fail'], - }, - 'transaction_open': { - 'type': 'boolean', - }, - # we can't serialize this so we can't require it as part of the - # contract. - # 'handle': { - # 'type': ['null', 'object'], - # }, - # credentials are validated separately by the adapter packages - 'credentials': { - 'description': ( - 'The credentials object here should match the connection type.' - ), - 'type': 'object', - 'additionalProperties': True, - } - }, - 'required': [ - 'type', 'name', 'state', 'transaction_open', 'credentials' - ], -} - - -class Connection(APIObject): - SCHEMA = CONNECTION_CONTRACT - - def __init__(self, credentials, *args, **kwargs): - # we can't serialize handles - self._handle = kwargs.pop('handle') - super().__init__(credentials=credentials.serialize(), *args, **kwargs) - # this will validate itself in its own __init__. - self._credentials = credentials +from hologram.helpers import StrEnum, NewPatternType, ExtensibleJsonSchemaMixin +from hologram import JsonSchemaMixin +from dbt.contracts.util import Replaceable + +from dataclasses import dataclass +from typing import Any, Optional + + +Identifier = NewPatternType('Identifier', r'^[A-Za-z_][A-Za-z0-9_]+$') + + +class ConnectionState(StrEnum): + INIT = 'init' + OPEN = 'open' + CLOSED = 'closed' + FAIL = 'fail' + + +@dataclass(init=False) +class Connection(ExtensibleJsonSchemaMixin, Replaceable): + type: Identifier + name: Optional[str] + _credentials: JsonSchemaMixin = None # underscore to prevent serialization + state: ConnectionState = ConnectionState.INIT + transaction_open: bool = False + _handle: Optional[Any] = None # underscore to prevent serialization + + def __init__( + self, + type: Identifier, + name: Optional[str], + credentials: JsonSchemaMixin, + state: ConnectionState = ConnectionState.INIT, + transaction_open: bool = False, + handle: Optional[Any] = None, + ) -> None: + self.type = type + self.name = name + self.credentials = credentials + self.state = state + self.transaction_open = transaction_open + self.handle = handle @property def credentials(self): return self._credentials + @credentials.setter + def credentials(self, value): + self._credentials = value + @property def handle(self): return self._handle @@ -61,10 +56,3 @@ def handle(self): @handle.setter def handle(self, value): self._handle = value - - name = named_property('name', 'The name of this connection') - state = named_property('state', 'The state of the connection') - transaction_open = named_property( - 'transaction_open', - 'True if there is an open transaction, False otherwise.' 
- ) diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index 104bcc9ca6c..a0db5b00d07 100644 --- a/core/dbt/contracts/graph/compiled.py +++ b/core/dbt/contracts/graph/compiled.py @@ -1,127 +1,89 @@ -from dbt.api import APIObject -from dbt.utils import deep_merge -from dbt.contracts.graph.parsed import PARSED_NODE_CONTRACT, \ - PARSED_MACRO_CONTRACT, ParsedNode +from dbt.contracts.graph.parsed import ( + ParsedNodeMixins, ParsedNode, ParsedSourceDefinition, + ParsedNodeDefaults, TestType, ParsedTestNode, TestConfig +) + +from dbt.contracts.util import Replaceable +from hologram import JsonSchemaMixin +from dataclasses import dataclass, field import sqlparse +from typing import Optional, List, Union -INJECTED_CTE_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': 'A single entry in the CTEs list', - 'properties': { - 'id': { - 'type': 'string', - 'description': 'The id of the CTE', - }, - 'sql': { - 'type': ['string', 'null'], - 'description': 'The compiled SQL of the CTE', - 'additionalProperties': True, - }, - }, - 'required': ['id', 'sql'], -} - - -COMPILED_NODE_CONTRACT = deep_merge( - PARSED_NODE_CONTRACT, - { - # TODO: when we add 'extra_ctes' back in, flip this back to False - 'additionalProperties': True, - 'properties': { - 'compiled': { - 'description': ( - 'This is true after the node has been compiled, but ctes ' - 'have not necessarily been injected into the node.' - ), - 'type': 'boolean' - }, - 'compiled_sql': { - 'type': ['string', 'null'], - }, - 'extra_ctes_injected': { - 'description': ( - 'This is true after extra ctes have been injected into ' - 'the compiled node.' - ), - 'type': 'boolean', - }, - 'extra_ctes': { - 'type': 'array', - 'description': 'The injected CTEs for a model', - 'items': INJECTED_CTE_CONTRACT, - }, - 'injected_sql': { - 'type': ['string', 'null'], - 'description': 'The SQL after CTEs have been injected', - }, - 'wrapped_sql': { - 'type': ['string', 'null'], - 'description': ( - 'The SQL after it has been wrapped (for tests, ' - 'operations, and analysis)' - ), - }, - }, - 'required': PARSED_NODE_CONTRACT['required'] + [ - 'compiled', 'compiled_sql', 'extra_ctes_injected', - 'injected_sql', 'extra_ctes' - ] - } -) +@dataclass +class InjectedCTE(JsonSchemaMixin, Replaceable): + id: str + sql: Optional[str] = None + +# for some frustrating reason, we can't subclass from ParsedNode directly, +# or typing.Union will flatten CompiledNode+ParsedNode into just ParsedNode. 
+# TODO: understand that issue and come up with some way for these two to share +# logic + + +@dataclass +class CompiledNodeDefaults(ParsedNodeDefaults, ParsedNodeMixins): + compiled: bool = False + compiled_sql: Optional[str] = None + extra_ctes_injected: bool = False + extra_ctes: List[InjectedCTE] = field(default_factory=list) + injected_sql: Optional[str] = None + wrapped_sql: Optional[str] = None + + def prepend_ctes(self, prepended_ctes: List[InjectedCTE]): + self.extra_ctes_injected = True + self.extra_ctes = prepended_ctes + self.injected_sql = _inject_ctes_into_sql( + self.compiled_sql, + prepended_ctes, + ) + self.validate(self.to_dict()) + + def set_cte(self, cte_id: str, sql: str): + """This is the equivalent of what self.extra_ctes[cte_id] = sql would + do if extra_ctes were an OrderedDict + """ + for cte in self.extra_ctes: + if cte.id == cte_id: + cte.sql = sql + break + else: + self.extra_ctes.append(InjectedCTE(id=cte_id, sql=sql)) + + +@dataclass +class CompiledNode(CompiledNodeDefaults): + index: Optional[int] = None + + @classmethod + def from_parsed_node(cls, parsed, **kwargs): + dct = parsed.to_dict() + dct.update(kwargs) + return cls.from_dict(dct) + + +@dataclass +class CompiledTestNode(CompiledNodeDefaults): + resource_type: TestType + column_name: Optional[str] = None + config: TestConfig = field(default_factory=TestConfig) -COMPILED_NODES_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'A collection of the compiled nodes, stored by their unique IDs.' - ), - 'patternProperties': { - '.*': COMPILED_NODE_CONTRACT - }, -} - - -COMPILED_MACRO_CONTRACT = PARSED_MACRO_CONTRACT - - -COMPILED_MACROS_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'A collection of the compiled macros, stored by their unique IDs.' - ), - 'patternProperties': { - '.*': COMPILED_MACRO_CONTRACT - }, -} - - -COMPILED_GRAPH_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'The full compiled graph, with both the required nodes and required ' - 'macros.' 
- ), - 'properties': { - 'nodes': COMPILED_NODES_CONTRACT, - 'macros': COMPILED_MACROS_CONTRACT, - }, - 'required': ['nodes', 'macros'], -} - - -def _inject_ctes_into_sql(sql, ctes): - """ - `ctes` is a dict of CTEs in the form: - { - "cte_id_1": "__dbt__CTE__ephemeral as (select * from table)", - "cte_id_2": "__dbt__CTE__events as (select id, type from events)" - } +def _inject_ctes_into_sql(sql: str, ctes: List[InjectedCTE]) -> str: + """ + `ctes` is a list of InjectedCTEs like: + + [ + InjectedCTE( + id="cte_id_1", + sql="__dbt__CTE__ephemeral as (select * from table)", + ), + InjectedCTE( + id="cte_id_2", + sql="__dbt__CTE__events as (select id, type from events)", + ), + ] Given `sql` like: @@ -161,74 +123,18 @@ def _inject_ctes_into_sql(sql, ctes): token = sqlparse.sql.Token( sqlparse.tokens.Keyword, - ", ".join(c['sql'] for c in ctes) + ", ".join(c.sql for c in ctes) ) parsed.insert_after(with_stmt, token) return str(parsed) -class CompiledNode(ParsedNode): - SCHEMA = COMPILED_NODE_CONTRACT - - def prepend_ctes(self, prepended_ctes): - self._contents['extra_ctes_injected'] = True - self._contents['extra_ctes'] = prepended_ctes - self._contents['injected_sql'] = _inject_ctes_into_sql( - self.compiled_sql, - prepended_ctes - ) - self.validate() - - @property - def extra_ctes_injected(self): - return self._contents.get('extra_ctes_injected') - - @property - def extra_ctes(self): - return self._contents.get('extra_ctes') - - @property - def compiled(self): - return self._contents.get('compiled') - - @compiled.setter - def compiled(self, value): - self._contents['compiled'] = value - - @property - def injected_sql(self): - return self._contents.get('injected_sql') - - @property - def compiled_sql(self): - return self._contents.get('compiled_sql') - - @compiled_sql.setter - def compiled_sql(self, value): - self._contents['compiled_sql'] = value - - @property - def wrapped_sql(self): - return self._contents.get('wrapped_sql') - - @wrapped_sql.setter - def wrapped_sql(self, value): - self._contents['wrapped_sql'] = value - - def set_cte(self, cte_id, sql): - """This is the equivalent of what self.extra_ctes[cte_id] = sql would - do if extra_ctes were an OrderedDict - """ - for cte in self.extra_ctes: - if cte['id'] == cte_id: - cte['sql'] = sql - break - else: - self.extra_ctes.append( - {'id': cte_id, 'sql': sql} - ) - - -class CompiledGraph(APIObject): - SCHEMA = COMPILED_GRAPH_CONTRACT +# We allow either parsed or compiled nodes, or parsed sources, as some +# 'compile()' calls in the runner actually just return the original parsed +# node they were given. 
+CompileResultNode = Union[ + CompiledNode, ParsedNode, + CompiledTestNode, ParsedTestNode, + ParsedSourceDefinition, +] diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 32c19df7bd7..568759645e1 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -1,144 +1,30 @@ -from dbt.api import APIObject -from dbt.contracts.graph.parsed import PARSED_NODE_CONTRACT, \ - PARSED_MACRO_CONTRACT, PARSED_DOCUMENTATION_CONTRACT, \ - PARSED_SOURCE_DEFINITION_CONTRACT -from dbt.contracts.graph.compiled import COMPILED_NODE_CONTRACT, CompiledNode +from dbt.contracts.graph.parsed import ParsedNode, ParsedMacro, \ + ParsedDocumentation +from dbt.contracts.graph.compiled import CompileResultNode +from dbt.contracts.util import Writable, Replaceable +from dbt.config import Project from dbt.exceptions import raise_duplicate_resource_name from dbt.node_types import NodeType from dbt.logger import GLOBAL_LOGGER as logger from dbt import tracking import dbt.utils -# We allow either parsed or compiled nodes, or parsed sources, as some -# 'compile()' calls in the runner actually just return the original parsed -# node they were given. -COMPILE_RESULT_NODE_CONTRACT = { - 'anyOf': [ - PARSED_NODE_CONTRACT, - COMPILED_NODE_CONTRACT, - PARSED_SOURCE_DEFINITION_CONTRACT, - ] -} - - -COMPILE_RESULT_NODES_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'A collection of the parsed nodes, stored by their unique IDs.' - ), - 'patternProperties': { - '.*': COMPILE_RESULT_NODE_CONTRACT - }, -} - - -PARSED_MACROS_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'A collection of the parsed macros, stored by their unique IDs.' - ), - 'patternProperties': { - '.*': PARSED_MACRO_CONTRACT - }, -} - - -PARSED_DOCUMENTATIONS_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'A collection of the parsed docs, stored by their uniqe IDs.' - ), - 'patternProperties': { - '.*': PARSED_DOCUMENTATION_CONTRACT, - }, -} - - -NODE_EDGE_MAP = { - 'type': 'object', - 'additionalProperties': False, - 'description': 'A map of node relationships', - 'patternProperties': { - '.*': { - 'type': 'array', - 'items': { - 'type': 'string', - 'description': 'A node name', - } - } - } -} - - -PARSED_MANIFEST_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'The full parsed manifest of the graph, with both the required nodes' - ' and required macros.' - ), - 'properties': { - 'nodes': COMPILE_RESULT_NODES_CONTRACT, - 'macros': PARSED_MACROS_CONTRACT, - 'docs': PARSED_DOCUMENTATIONS_CONTRACT, - 'disabled': { - 'type': 'array', - 'items': PARSED_NODE_CONTRACT, - 'description': 'An array of disabled nodes', - }, - 'generated_at': { - 'type': 'string', - 'format': 'date-time', - 'description': ( - 'The time at which the manifest was generated' - ), - }, - 'parent_map': NODE_EDGE_MAP, - 'child_map': NODE_EDGE_MAP, - 'metadata': { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'project_id': { - 'type': ('string', 'null'), - 'description': ( - 'The anonymized ID of the project. Persists as long ' - 'as the project name stays the same.' - ), - 'pattern': '[0-9a-f]{32}', - }, - 'user_id': { - 'type': ('string', 'null'), - 'description': ( - 'The user ID assigned by dbt. Persists per-user as ' - 'long as the user cookie file remains in place.' 
- ), - 'pattern': ( - '[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-' - '[0-9a-f]{12}' - ), - }, - 'send_anonymous_usage_stats': { - 'type': ('boolean', 'null'), - 'description': ( - 'Whether or not to send anonymized usage statistics.' - ), - }, - }, - 'required': [ - 'project_id', 'user_id', 'send_anonymous_usage_stats', - ], - }, - }, - 'required': ['nodes', 'macros', 'docs', 'generated_at', 'metadata'], -} +from hologram import JsonSchemaMixin + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, List, Optional +from uuid import UUID -class CompileResultNode(CompiledNode): - SCHEMA = COMPILE_RESULT_NODE_CONTRACT +NodeEdgeMap = Dict[str, List[str]] + + +@dataclass +class ManifestMetadata(JsonSchemaMixin, Replaceable): + project_id: Optional[str] + user_id: Optional[UUID] + send_anonymous_usage_stats: Optional[bool] def _sort_values(dct): @@ -163,32 +49,39 @@ def build_edges(nodes): return _sort_values(forward_edges), _sort_values(backward_edges) -class Manifest(APIObject): - SCHEMA = PARSED_MANIFEST_CONTRACT +def _deepcopy(value): + return value.from_dict(value.to_dict()) + + +@dataclass(init=False) +class Manifest: """The manifest for the full graph, after parsing and during compilation. - Nodes may be either ParsedNodes or CompiledNodes or a mix, depending upon - the current state of the compiler. Macros will always be ParsedMacros and - docs will always be ParsedDocumentations. """ - def __init__(self, nodes, macros, docs, generated_at, disabled, - config=None): - """The constructor. nodes and macros are dictionaries mapping unique - IDs to ParsedNode/CompiledNode and ParsedMacro objects, respectively. - docs is a dictionary mapping unique IDs to ParsedDocumentation objects. - generated_at is a text timestamp in RFC 3339 format. - disabled is a list of disabled FQNs (as strings). 
- """ - metadata = self.get_metadata(config) + nodes: Dict[str, CompileResultNode] + macros: Dict[str, ParsedMacro] + docs: Dict[str, ParsedDocumentation] + generated_at: datetime + disabled: List[ParsedNode] + metadata: ManifestMetadata = field(init=False) + + def __init__( + self, + nodes: Dict[str, CompileResultNode], + macros: Dict[str, ParsedMacro], + docs: Dict[str, ParsedDocumentation], + generated_at: datetime, + disabled: List[ParsedNode], + config: Optional[Project] = None + ) -> None: + self.metadata = self.get_metadata(config) self.nodes = nodes self.macros = macros self.docs = docs self.generated_at = generated_at - self.metadata = metadata self.disabled = disabled - super().__init__() @staticmethod - def get_metadata(config): + def get_metadata(config: Optional[Project]) -> ManifestMetadata: project_id = None user_id = None send_anonymous_usage_stats = None @@ -200,11 +93,11 @@ def get_metadata(config): user_id = tracking.active_user.id send_anonymous_usage_stats = not tracking.active_user.do_not_track - return { - 'project_id': project_id, - 'user_id': user_id, - 'send_anonymous_usage_stats': send_anonymous_usage_stats, - } + return ManifestMetadata( + project_id=project_id, + user_id=user_id, + send_anonymous_usage_stats=send_anonymous_usage_stats, + ) def serialize(self): """Convert the parsed manifest to a nested dict structure that we can @@ -213,14 +106,14 @@ def serialize(self): forward_edges, backward_edges = build_edges(self.nodes.values()) return { - 'nodes': {k: v.serialize() for k, v in self.nodes.items()}, - 'macros': {k: v.serialize() for k, v in self.macros.items()}, - 'docs': {k: v.serialize() for k, v in self.docs.items()}, + 'nodes': {k: v.to_dict() for k, v in self.nodes.items()}, + 'macros': {k: v.to_dict() for k, v in self.macros.items()}, + 'docs': {k: v.to_dict() for k, v in self.docs.items()}, 'parent_map': backward_edges, 'child_map': forward_edges, 'generated_at': self.generated_at, 'metadata': self.metadata, - 'disabled': [v.serialize() for v in self.disabled], + 'disabled': [v.to_dict() for v in self.disabled], } def find_disabled_by_name(self, name, package=None): @@ -344,7 +237,7 @@ def predicate(model): return self._model_matches_schema_and_table(schema, table, model) matching = list(self._filter_subgraph(self.nodes, predicate)) - return [match.get('unique_id') for match in matching] + return [match.unique_id for match in matching] def add_nodes(self, new_nodes): """Add the given dict of new nodes to the manifest.""" @@ -391,7 +284,10 @@ def to_flat_graph(self): Ideally in the future we won't need to have this method. 
""" return { - 'nodes': {k: v.to_shallow_dict() for k, v in self.nodes.items()}, + 'nodes': { + k: v.to_dict(omit_none=False) + for k, v in self.nodes.items() + }, 'macros': self.macros, } @@ -412,10 +308,61 @@ def get_used_databases(self): def deepcopy(self, config=None): return Manifest( - nodes={k: v.incorporate() for k, v in self.nodes.items()}, - macros={k: v.incorporate() for k, v in self.macros.items()}, - docs={k: v.incorporate() for k, v in self.docs.items()}, + nodes={k: _deepcopy(v) for k, v in self.nodes.items()}, + macros={k: _deepcopy(v) for k, v in self.macros.items()}, + docs={k: _deepcopy(v) for k, v in self.docs.items()}, generated_at=self.generated_at, - disabled=[n.incorporate() for n in self.disabled], + disabled=[_deepcopy(n) for n in self.disabled], config=config ) + + def writable_manifest(self): + forward_edges, backward_edges = build_edges(self.nodes.values()) + return WritableManifest( + nodes=self.nodes, + macros=self.macros, + docs=self.docs, + generated_at=self.generated_at, + metadata=self.metadata, + disabled=self.disabled, + child_map=forward_edges, + parent_map=backward_edges + ) + + @classmethod + def from_writable_manifest(cls, writable): + self = cls( + nodes=writable.nodes, + macros=writable.macros, + docs=writable.docs, + generated_at=writable.generated_at, + metadata=writable.metadata, + disabled=writable.disabled, + ) + self.metadata = writable.metadata + return self + + @classmethod + def from_dict(cls, data, validate=True): + writable = WritableManifest.from_dict(data=data, validate=validate) + return cls.from_writable_manifest(writable) + + def to_dict(self, omit_none=True, validate=False): + return self.writable_manifest().to_dict( + omit_none=omit_none, validate=validate + ) + + def write(self, path): + self.writable_manifest().write(path) + + +@dataclass +class WritableManifest(JsonSchemaMixin, Writable): + nodes: Dict[str, CompileResultNode] + macros: Dict[str, ParsedMacro] + docs: Dict[str, ParsedDocumentation] + disabled: Optional[List[ParsedNode]] + generated_at: datetime + parent_map: Optional[NodeEdgeMap] + child_map: Optional[NodeEdgeMap] + metadata: ManifestMetadata diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index 66a833fc647..2bd63569556 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -1,366 +1,144 @@ -from dbt.api import APIObject -from dbt.utils import deep_merge -from dbt.node_types import NodeType +from dataclasses import dataclass, field +from typing import Optional, Union, List, Dict, Any, Type, Tuple + +from hologram import JsonSchemaMixin +from hologram.helpers import StrEnum, NewPatternType, ExtensibleJsonSchemaMixin import dbt.clients.jinja +from dbt.contracts.graph.unparsed import ( + UnparsedNode, UnparsedMacro, UnparsedDocumentationFile, Quoting, + UnparsedBaseNode, FreshnessThreshold +) +from dbt.contracts.util import Replaceable +from dbt.logger import GLOBAL_LOGGER as logger # noqa +from dbt.node_types import ( + NodeType, SourceType, SnapshotType, MacroType, TestType +) -from dbt.contracts.graph.unparsed import UNPARSED_NODE_CONTRACT, \ - UNPARSED_MACRO_CONTRACT, UNPARSED_DOCUMENTATION_FILE_CONTRACT, \ - UNPARSED_BASE_CONTRACT, TIME_CONTRACT -from dbt.logger import GLOBAL_LOGGER as logger # noqa +class TimestampStrategy(StrEnum): + Timestamp = 'timestamp' -# TODO: which of these do we _really_ support? or is it both? 
-HOOK_CONTRACT = { - 'anyOf': [ - { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'sql': { - 'type': 'string', - }, - 'transaction': { - 'type': 'boolean', - }, - 'index': { - 'type': 'integer', - } - }, - 'required': ['sql', 'transaction'], - }, - { - 'type': 'string', - }, - ], -} - - -CONFIG_CONTRACT = { - 'type': 'object', - 'additionalProperties': True, - 'properties': { - 'enabled': { - 'type': 'boolean', - }, - 'materialized': { - 'type': 'string', - }, - 'persist_docs': { - 'type': 'object', - 'additionalProperties': True, - }, - 'post-hook': { - 'type': 'array', - 'items': HOOK_CONTRACT, - }, - 'pre-hook': { - 'type': 'array', - 'items': HOOK_CONTRACT, - }, - 'vars': { - 'type': 'object', - 'additionalProperties': True, - }, - 'quoting': { - 'type': 'object', - 'additionalProperties': True, - }, - 'column_types': { - 'type': 'object', - 'additionalProperties': True, - }, - 'tags': { - 'anyOf': [ - { - 'type': 'array', - 'items': { - 'type': 'string' - }, - }, - { - 'type': 'string' - } - ] - }, - 'severity': { - 'type': 'string', - 'pattern': '([eE][rR][rR][oO][rR]|[wW][aA][rR][nN])', - }, - }, - 'required': [ - 'enabled', 'materialized', 'post-hook', 'pre-hook', 'vars', - 'quoting', 'column_types', 'tags', 'persist_docs' - ] -} - - -# Note that description must be present, but may be empty. -COLUMN_INFO_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': 'Information about a single column in a model', - 'properties': { - 'name': { - 'type': 'string', - 'description': 'The column name', - }, - 'description': { - 'type': 'string', - 'description': 'A description of the column', - }, - }, - 'required': ['name', 'description'], -} +class CheckStrategy(StrEnum): + Check = 'check' + + +class All(StrEnum): + All = 'all' + + +@dataclass +class Hook(JsonSchemaMixin, Replaceable): + sql: str + transaction: bool = True + index: Optional[int] = None + + +def insensitive_patterns(*patterns: str): + lowercased = [] + for pattern in patterns: + lowercased.append( + ''.join('[{}{}]'.format(s.upper(), s.lower()) for s in pattern) + ) + return '^({})$'.format('|'.join(lowercased)) + + +Severity = NewPatternType('Severity', insensitive_patterns('warn', 'error')) + + +@dataclass +class NodeConfig(ExtensibleJsonSchemaMixin, Replaceable): + enabled: bool = True + materialized: str = 'view' + persist_docs: Dict[str, Any] = field(default_factory=dict) + post_hook: List[Hook] = field(default_factory=list) + pre_hook: List[Hook] = field(default_factory=list) + vars: Dict[str, Any] = field(default_factory=dict) + quoting: Dict[str, Any] = field(default_factory=dict) + column_types: Dict[str, Any] = field(default_factory=dict) + tags: Union[List[str], str] = field(default_factory=list) + _extra: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if isinstance(self.tags, str): + self.tags = [self.tags] + + @property + def extra(self): + return self._extra + + @classmethod + def from_dict(cls, data, validate=True): + self = super().from_dict(data=data, validate=validate) + keys = self.to_dict(validate=False, omit_none=False) + for key, value in data.items(): + if key not in keys: + self._extra[key] = value + return self + + def to_dict(self, omit_none=True, validate=False): + data = super().to_dict(omit_none=omit_none, validate=validate) + data.update(self._extra) + return data + + def replace(self, **kwargs): + dct = self.to_dict(omit_none=False, validate=False) + dct.update(kwargs) + return self.from_dict(dct) + + 
@classmethod + def field_mapping(cls): + return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'} + + +@dataclass +class ColumnInfo(JsonSchemaMixin, Replaceable): + name: str + description: str = '' # Docrefs are not quite like regular references, as they indicate what they # apply to as well as what they are referring to (so the doc package + doc # name, but also the column name if relevant). This is because column # descriptions are rendered separately from their models. -DOCREF_CONTRACT = { - 'type': 'object', - 'properties': { - 'documentation_name': { - 'type': 'string', - 'description': 'The name of the documentation block referred to', - }, - 'documentation_package': { - 'type': 'string', - 'description': ( - 'If provided, the documentation package name referred to' - ), - }, - 'column_name': { - 'type': 'string', - 'description': ( - 'If the documentation refers to a column instead of the ' - 'model, the column name should be set' - ), - }, - }, - 'required': ['documentation_name', 'documentation_package'] -} - - -HAS_FQN_CONTRACT = { - 'properties': { - 'fqn': { - 'type': 'array', - 'items': { - 'type': 'string', - } - }, - }, - 'required': ['fqn'], -} - - -HAS_UNIQUE_ID_CONTRACT = { - 'properties': { - 'unique_id': { - 'type': 'string', - 'minLength': 1, - }, - }, - 'required': ['unique_id'], -} - -CAN_REF_CONTRACT = { - 'properties': { - 'refs': { - 'type': 'array', - 'items': { - 'type': 'array', - 'description': ( - 'The list of arguments passed to a single ref call.' - ), - }, - 'description': ( - 'The list of call arguments, one list of arguments per ' - 'call.' - ) - }, - 'sources': { - 'type': 'array', - 'items': { - 'type': 'array', - 'description': ( - 'The list of arguments passed to a single source call.' - ), - }, - 'description': ( - 'The list of call arguments, one list of arguments per ' - 'call.' - ) - }, - 'depends_on': { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'nodes': { - 'type': 'array', - 'items': { - 'type': 'string', - 'minLength': 1, - 'description': ( - 'A node unique ID that this depends on.' - ) - } - }, - 'macros': { - 'type': 'array', - 'items': { - 'type': 'string', - 'minLength': 1, - 'description': ( - 'A macro unique ID that this depends on.' - ) - } - }, - }, - 'description': ( - 'A list of unique IDs for nodes and macros that this ' - 'node depends upon.' - ), - 'required': ['nodes', 'macros'], - }, - }, - 'required': ['refs', 'sources', 'depends_on'], -} - - -HAS_DOCREFS_CONTRACT = { - 'properties': { - 'docrefs': { - 'type': 'array', - 'items': DOCREF_CONTRACT, - }, - }, -} - - -HAS_DESCRIPTION_CONTRACT = { - 'properties': { - 'description': { - 'type': 'string', - 'description': 'A user-supplied description of the model', - }, - 'columns': { - 'type': 'object', - 'properties': { - '.*': COLUMN_INFO_CONTRACT, - }, - }, - }, - 'required': ['description', 'columns'], -} - -# does this belong inside another contract? -HAS_CONFIG_CONTRACT = { - 'properties': { - 'config': CONFIG_CONTRACT, - }, - 'required': ['config'], -} - - -COLUMN_TEST_CONTRACT = { - 'properties': { - 'column_name': { - 'type': 'string', - 'description': ( - 'In tests parsed from a v2 schema, the column the test is ' - 'associated with (if there is one)' - ) - }, - } -} - - -HAS_RELATION_METADATA_CONTRACT = { - 'properties': { - 'database': { - 'type': 'string', - 'description': ( - 'The actual database string that this will build into.' 
- ) - }, - 'schema': { - 'type': 'string', - 'description': ( - 'The actual schema string that this will build into.' - ) - }, - }, - 'required': ['database', 'schema'], -} - - -PARSED_NODE_CONTRACT = deep_merge( - UNPARSED_NODE_CONTRACT, - HAS_UNIQUE_ID_CONTRACT, - HAS_FQN_CONTRACT, - CAN_REF_CONTRACT, - HAS_DOCREFS_CONTRACT, - HAS_DESCRIPTION_CONTRACT, - HAS_CONFIG_CONTRACT, - COLUMN_TEST_CONTRACT, - HAS_RELATION_METADATA_CONTRACT, - { - 'properties': { - 'alias': { - 'type': 'string', - 'description': ( - 'The name of the relation that this will build into' - ) - }, - # TODO: move this into a class property. - 'empty': { - 'type': 'boolean', - 'description': 'True if the SQL is empty', - }, - 'tags': { - 'type': 'array', - 'items': { - 'type': 'string', - } - }, - # this is really nodes-only - 'patch_path': { - 'type': 'string', - 'description': ( - 'The path to the patch source if the node was patched' - ), - }, - 'build_path': { - 'type': 'string', - 'description': ( - 'In seeds, the path to the source file used during build.' - ), - }, - }, - 'required': ['empty', 'tags', 'alias'], - } -) +@dataclass +class Docref(JsonSchemaMixin, Replaceable): + documentation_name: str + documentation_package: str + column_name: Optional[str] = None -class ParsedNode(APIObject): - SCHEMA = PARSED_NODE_CONTRACT +@dataclass +class HasFqn(JsonSchemaMixin, Replaceable): + fqn: List[str] + + +@dataclass +class HasUniqueID(JsonSchemaMixin, Replaceable): + unique_id: str + + +@dataclass +class DependsOn(JsonSchemaMixin, Replaceable): + nodes: List[str] = field(default_factory=list) + macros: List[str] = field(default_factory=list) + + +@dataclass +class HasRelationMetadata(JsonSchemaMixin, Replaceable): + database: str + schema: str - def __init__(self, **kwargs): - kwargs.setdefault('columns', {}) - kwargs.setdefault('description', '') - super().__init__(**kwargs) +class ParsedNodeMixins: @property def is_refable(self): return self.resource_type in NodeType.refable() @property def is_ephemeral(self): - return self.get('config', {}).get('materialized') == 'ephemeral' + return self.config.materialized == 'ephemeral' @property def is_ephemeral_model(self): @@ -368,419 +146,256 @@ def is_ephemeral_model(self): @property def depends_on_nodes(self): - """Return the list of node IDs that this node depends on.""" - return self.depends_on['nodes'] - - def to_dict(self): - ret = self.serialize() - return ret - - def to_shallow_dict(self): - ret = self._contents.copy() - return ret + return self.depends_on.nodes def patch(self, patch): """Given a ParsedNodePatch, add the new information to the node.""" # explicitly pick out the parts to update so we don't inadvertently # step on the model name or anything - self._contents.update({ - 'patch_path': patch.original_file_path, - 'description': patch.description, - 'columns': patch.columns, - 'docrefs': patch.docrefs, - }) - # patches always trigger re-validation - self.validate() + self.patch_path = patch.original_file_path + self.description = patch.description + self.columns = patch.columns + self.docrefs = patch.docrefs + # patches should always trigger re-validation + self.to_dict(validate=True) def get_materialization(self): - return self.config.get('materialized') - - @property - def build_path(self): - return self._contents.get('build_path') + return self.config.materialized + + def local_vars(self): + return self.config.vars + + +@dataclass +class ParsedNodeMandatory( + UnparsedNode, + HasUniqueID, + HasFqn, + HasRelationMetadata, +): + alias: str + + +@dataclass 
+class ParsedNodeDefaults(ParsedNodeMandatory): + config: NodeConfig = field(default_factory=NodeConfig) + tags: List[str] = field(default_factory=list) + refs: List[List[Any]] = field(default_factory=list) + sources: List[List[Any]] = field(default_factory=list) + depends_on: DependsOn = field(default_factory=DependsOn) + docrefs: List[Docref] = field(default_factory=list) + description: str = field(default='') + columns: Dict[str, ColumnInfo] = field(default_factory=dict) + patch_path: Optional[str] = None + build_path: Optional[str] = None + + +# TODO(jeb): hooks should get their own parsed type instead of including +# index everywhere! +@dataclass +class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins): + index: Optional[int] = None + + +@dataclass +class TestConfig(NodeConfig): + severity: Severity = 'error' + + +@dataclass +class ParsedTestNode(ParsedNodeDefaults, ParsedNodeMixins): + resource_type: TestType + column_name: Optional[str] = None + config: TestConfig = field(default_factory=TestConfig) + + +@dataclass(init=False) +class _SnapshotConfig(NodeConfig): + unique_key: str + target_schema: str + target_database: str + + def __init__( + self, + unique_key: str, + target_database: str, + target_schema: str, + **kwargs + ) -> None: + self.target_database = target_database + self.target_schema = target_schema + self.unique_key = unique_key + super().__init__(**kwargs) - @build_path.setter - def build_path(self, value): - self._contents['build_path'] = value - @property - def database(self): - return self._contents['database'] +@dataclass(init=False) +class GenericSnapshotConfig(_SnapshotConfig): + strategy: str - @database.setter - def database(self, value): - self._contents['database'] = value + def __init__(self, strategy: str, **kwargs) -> None: + self.strategy = strategy + super().__init__(**kwargs) - @property - def schema(self): - return self._contents['schema'] - @schema.setter - def schema(self, value): - self._contents['schema'] = value +@dataclass(init=False) +class TimestampSnapshotConfig(_SnapshotConfig): + strategy: TimestampStrategy + updated_at: str - @property - def alias(self): - return self._contents['alias'] + def __init__( + self, strategy: TimestampStrategy, updated_at: str, **kwargs + ) -> None: + self.strategy = strategy + self.updated_at = updated_at + super().__init__(**kwargs) - @alias.setter - def alias(self, value): - self._contents['alias'] = value - @property - def config(self): - return self._contents['config'] - - @config.setter - def config(self, value): - self._contents['config'] = value - - -SNAPSHOT_CONFIG_CONTRACT = { - 'properties': { - 'target_database': { - 'type': 'string', - }, - 'target_schema': { - 'type': 'string', - }, - 'unique_key': { - 'type': 'string', - }, - 'anyOf': [ - { - 'properties': { - 'strategy': { - 'enum': ['timestamp'], - }, - 'updated_at': { - 'type': 'string', - 'description': ( - 'The column name with the timestamp to compare' - ), - }, - }, - 'required': ['updated_at'], - }, - { - 'properties': { - 'strategy': { - 'enum': ['check'], - }, - 'check_cols': { - 'oneOf': [ - { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'The columns to check', - 'minLength': 1, - }, - { - 'enum': ['all'], - 'description': 'Check all columns', - }, - ], - }, - }, - 'required': ['check_cols'], - } - ] - }, - 'required': [ - 'target_schema', 'unique_key', 'strategy', - ], -} - - -PARSED_SNAPSHOT_NODE_CONTRACT = deep_merge( - PARSED_NODE_CONTRACT, - { - 'properties': { - 'config': SNAPSHOT_CONFIG_CONTRACT, - 
'resource_type': { - 'enum': [NodeType.Snapshot], - }, - }, - } -) +@dataclass(init=False) +class CheckSnapshotConfig(_SnapshotConfig): + strategy: CheckStrategy + # TODO: is there a way to get this to accept tuples of strings? Adding + # `Tuple[str, ...]` to the list of types results in this: + # ['email'] is valid under each of {'type': 'array', 'items': + # {'type': 'string'}}, {'type': 'array', 'items': {'type': 'string'}} + # but without it, parsing gets upset about values like `('email',)` + # maybe hologram itself should support this behavior? It's not like tuples + # are meaningful in json + check_cols: Union[All, List[str]] + + def __init__( + self, strategy: CheckStrategy, check_cols: Union[All, List[str]], + **kwargs + ) -> None: + self.strategy = strategy + self.check_cols = check_cols + super().__init__(**kwargs) +@dataclass +class IntermediateSnapshotNode(ParsedNode): + # at an intermediate stage in parsing, where we've built something better + # than an unparsed node for rendering in parse mode, it's pretty possible + # that we won't have critical snapshot-related information that is only + # defined in config blocks. To fix that, we have an intermediate type that + # uses a regular node config, which the snapshot parser will then convert + # into a full ParsedSnapshotNode after rendering. + resource_type: SnapshotType + + +def _create_if_else_chain( + key: str, + criteria: List[Tuple[str, Type[JsonSchemaMixin]]], + default: Type[JsonSchemaMixin] +) -> dict: + """Mutate a given schema key that contains a 'oneOf' to instead be an + 'if-then-else' chain. This results is much better/more consistent errors + from jsonschema. + """ + result = schema = {} + criteria = criteria[:] + while criteria: + if_clause, then_clause = criteria.pop() + schema['if'] = {'properties': { + key: {'enum': [if_clause]} + }} + schema['then'] = then_clause.json_schema() + schema['else'] = {} + schema = schema['else'] + schema.update(default.json_schema()) + return result + + +@dataclass class ParsedSnapshotNode(ParsedNode): - SCHEMA = PARSED_SNAPSHOT_NODE_CONTRACT + resource_type: SnapshotType + config: Union[ + CheckSnapshotConfig, + TimestampSnapshotConfig, + GenericSnapshotConfig, + ] + + @classmethod + def json_schema(cls): + schema = super().json_schema() + + # mess with config + configs = [ + (str(CheckStrategy.Check), CheckSnapshotConfig), + (str(TimestampStrategy.Timestamp), TimestampSnapshotConfig), + ] + + schema['properties']['config'] = _create_if_else_chain( + 'strategy', configs, GenericSnapshotConfig + ) + return schema # The parsed node update is only the 'patch', not the test. The test became a # regular parsed node. Note that description and columns must be present, but # may be empty. 
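To make the shape produced by `_create_if_else_chain` above concrete, here is a hedged, standalone sketch that builds the same nested if/then/else structure from plain schema dicts. The name `if_else_chain` and the toy schemas are illustrative only; the real helper takes `JsonSchemaMixin` subclasses and calls `.json_schema()` on each of them.

    from typing import List, Tuple


    def if_else_chain(key: str, criteria: List[Tuple[str, dict]],
                      default: dict) -> dict:
        # same traversal as dbt's helper, but with pre-built schema dicts
        result = schema = {}
        remaining = list(criteria)
        while remaining:
            # criteria are consumed from the end, so the last entry becomes
            # the outermost 'if'
            if_clause, then_schema = remaining.pop()
            schema['if'] = {'properties': {key: {'enum': [if_clause]}}}
            schema['then'] = then_schema
            schema['else'] = {}
            schema = schema['else']
        # the default schema lives in the innermost 'else'
        schema.update(default)
        return result


    chain = if_else_chain(
        'strategy',
        [('check', {'title': 'check'}), ('timestamp', {'title': 'timestamp'})],
        {'title': 'generic'},
    )
    assert chain['if']['properties']['strategy'] == {'enum': ['timestamp']}
    assert chain['else']['then'] == {'title': 'check'}
    assert chain['else']['else'] == {'title': 'generic'}

Because each `strategy` value is validated against exactly one branch, jsonschema reports errors against that single variant instead of listing every alternative of a `oneOf`, which is the point of converting these checks.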
-PARSED_NODE_PATCH_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': 'A collection of values that can be set on a node', - 'properties': { - 'name': { - 'type': 'string', - 'description': 'The name of the node this modifies', - }, - 'description': { - 'type': 'string', - 'description': 'The description of the node to add', - }, - 'original_file_path': { - 'type': 'string', - 'description': ( - 'Relative path to the originating file path for the patch ' - 'from the project root' - ), - }, - 'columns': { - 'type': 'object', - 'properties': { - '.*': COLUMN_INFO_CONTRACT, - } - }, - 'docrefs': { - 'type': 'array', - 'items': DOCREF_CONTRACT, - } - }, - 'required': [ - 'name', 'original_file_path', 'description', 'columns', 'docrefs' - ], -} - - -class ParsedNodePatch(APIObject): - SCHEMA = PARSED_NODE_PATCH_CONTRACT - - -PARSED_MACRO_CONTRACT = deep_merge( - UNPARSED_MACRO_CONTRACT, - { - # This is required for the 'generator' field to work. - # TODO: fix before release - 'additionalProperties': True, - 'properties': { - 'name': { - 'type': 'string', - 'description': ( - 'Name of this node. For models, this is used as the ' - 'identifier in the database.'), - 'minLength': 1, - 'maxLength': 127, - }, - 'resource_type': { - 'enum': [ - NodeType.Macro, - ], - }, - 'unique_id': { - 'type': 'string', - 'minLength': 1, - 'maxLength': 255, - }, - 'tags': { - 'description': ( - 'An array of arbitrary strings to use as tags.' - ), - 'type': 'array', - 'items': { - 'type': 'string', - }, - }, - 'depends_on': { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'macros': { - 'type': 'array', - 'items': { - 'type': 'string', - 'minLength': 1, - 'maxLength': 255, - 'description': 'A single macro unique ID.' - } - } - }, - 'description': 'A list of all macros this macro depends on.', - 'required': ['macros'], - }, - }, - 'required': [ - 'resource_type', 'unique_id', 'tags', 'depends_on', 'name', - ] - } -) +@dataclass +class ParsedNodePatch(JsonSchemaMixin, Replaceable): + name: str + description: str + original_file_path: str + columns: Dict[str, ColumnInfo] + docrefs: List[Docref] + + +@dataclass +class MacroDependsOn(JsonSchemaMixin, Replaceable): + macros: List[str] = field(default_factory=list) + +@dataclass +class ParsedMacro(UnparsedMacro): + name: str + resource_type: MacroType + unique_id: str + tags: List[str] + depends_on: MacroDependsOn -class ParsedMacro(APIObject): - SCHEMA = PARSED_MACRO_CONTRACT + def local_vars(self): + return {} @property def generator(self): """ Returns a function that can be called to render the macro results. """ - # TODO: we can generate self.template from the other properties - # available in this class. should we just generate this here? 
- return dbt.clients.jinja.macro_generator(self._contents) - - -# This is just the file + its ID -PARSED_DOCUMENTATION_CONTRACT = deep_merge( - UNPARSED_DOCUMENTATION_FILE_CONTRACT, - { - 'properties': { - 'name': { - 'type': 'string', - 'description': ( - 'Name of this node, as referred to by doc() references' - ), - }, - 'unique_id': { - 'type': 'string', - 'minLength': 1, - 'maxLength': 255, - 'description': ( - 'The unique ID of this node as stored in the manifest' - ), - }, - 'block_contents': { - 'type': 'string', - 'description': 'The contents of just the docs block', - }, - }, - 'required': ['name', 'unique_id', 'block_contents'], - } -) - - -NODE_EDGE_MAP = { - 'type': 'object', - 'additionalProperties': False, - 'description': 'A map of node relationships', - 'patternProperties': { - '.*': { - 'type': 'array', - 'items': { - 'type': 'string', - 'description': 'A node name', - } - } - } -} - - -class ParsedDocumentation(APIObject): - SCHEMA = PARSED_DOCUMENTATION_CONTRACT - - -class Hook(APIObject): - SCHEMA = HOOK_CONTRACT - - -FRESHNESS_CONTRACT = { - 'properties': { - 'loaded_at_field': { - 'type': ['null', 'string'], - 'description': 'The field to use as the "loaded at" timestamp', - }, - 'freshness': { - 'anyOf': [ - {'type': 'null'}, - { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'warn_after': TIME_CONTRACT, - 'error_after': TIME_CONTRACT, - }, - }, - ], - }, - }, -} - - -QUOTING_CONTRACT = { - 'properties': { - 'quoting': { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'database': {'type': 'boolean'}, - 'schema': {'type': 'boolean'}, - 'identifier': {'type': 'boolean'}, - }, - }, - }, - 'required': ['quoting'], -} - - -PARSED_SOURCE_DEFINITION_CONTRACT = deep_merge( - UNPARSED_BASE_CONTRACT, - FRESHNESS_CONTRACT, - QUOTING_CONTRACT, - HAS_DESCRIPTION_CONTRACT, - HAS_UNIQUE_ID_CONTRACT, - HAS_DOCREFS_CONTRACT, - HAS_RELATION_METADATA_CONTRACT, - HAS_FQN_CONTRACT, - { - 'description': ( - 'A source table definition, as parsed from the one provided in the' - '"tables" subsection of the "sources" section of schema.yml' - ), - 'properties': { - 'name': { - 'type': 'string', - 'description': ( - 'The name of this node, which is the name of the model it' - 'refers to' - ), - 'minLength': 1, - }, - 'source_name': { - 'type': 'string', - 'description': 'The reference name of the source definition', - 'minLength': 1, - }, - 'source_description': { - 'type': 'string', - 'description': 'The user-supplied description of the source', - }, - 'loader': { - 'type': 'string', - 'description': 'The user-defined loader for this source', - }, - 'identifier': { - 'type': 'string', - 'description': 'The identifier for the source table', - 'minLength': 1, - }, - # the manifest search stuff really requires this, sadly - 'resource_type': { - 'enum': [NodeType.Source], - }, - }, - # note that while required, loaded_at_field and freshness may be null - 'required': [ - 'source_name', 'source_description', 'loaded_at_field', 'loader', - 'freshness', 'description', 'columns', 'docrefs', 'identifier', - ], - } -) + return dbt.clients.jinja.macro_generator(self) + + +@dataclass +class ParsedDocumentation(UnparsedDocumentationFile): + name: str + unique_id: str + block_contents: str + + +@dataclass +class ParsedSourceDefinition( + UnparsedBaseNode, + HasUniqueID, + HasRelationMetadata, + HasFqn): + name: str + source_name: str + source_description: str + loader: str + identifier: str + resource_type: SourceType + quoting: Quoting = 
field(default_factory=Quoting) + loaded_at_field: Optional[str] = None + freshness: FreshnessThreshold = field(default_factory=FreshnessThreshold) + docrefs: List[Docref] = field(default_factory=list) + description: str = '' + columns: Dict[str, ColumnInfo] = field(default_factory=dict) + @property + def is_ephemeral_model(self): + return False -class ParsedSourceDefinition(APIObject): - SCHEMA = PARSED_SOURCE_DEFINITION_CONTRACT - is_ephemeral_model = False - - def to_shallow_dict(self): - return self._contents.copy() - - # provide some emtpy/meaningless properties so these look more like - # ParsedNodes @property def depends_on_nodes(self): return [] diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index 4b16f7e6146..796b7d5f160 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -1,389 +1,160 @@ -from dbt.api import APIObject - -from dbt.node_types import NodeType -from dbt.utils import deep_merge - - -UNPARSED_BASE_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'package_name': { - 'type': 'string', - }, - # filesystem - 'root_path': { - 'type': 'string', - 'description': 'The absolute path to the project root', - }, - 'path': { - 'type': 'string', - 'description': ( - 'Relative path to the source file from the project root. ' - 'Usually the same as original_file_path, but in some cases ' - 'dbt will generate a path.'), - }, - 'original_file_path': { - 'type': 'string', - 'description': ( - 'Relative path to the originating file from the project root.' - ), - } - }, - 'required': ['package_name', 'root_path', 'path', 'original_file_path'] -} - -UNPARSED_HAS_SQL_CONTRACT = { - 'properties': { - 'raw_sql': { - 'type': 'string', - 'description': ( - 'For nodes defined in SQL files, this is just the contents ' - 'of that file. For schema tests, snapshots, etc. this is ' - 'generated by dbt.'), - }, - 'index': { - 'type': 'integer', - } - }, - 'required': ['raw_sql'] -} - -UNPARSED_MACRO_CONTRACT = deep_merge( - UNPARSED_BASE_CONTRACT, - UNPARSED_HAS_SQL_CONTRACT -) - -UNPARSED_NODE_CONTRACT = deep_merge( - UNPARSED_BASE_CONTRACT, - UNPARSED_HAS_SQL_CONTRACT, - { - 'properties': { - 'name': { - 'type': 'string', - 'description': ( - 'Name of this node. For models, this is used as the ' - 'identifier in the database.'), - 'minLength': 1, - }, - 'resource_type': { - 'enum': [ - NodeType.Model, - NodeType.Test, - NodeType.Analysis, - NodeType.Operation, - NodeType.Seed, - # we need this if parse_node is going to handle snapshots. - NodeType.Snapshot, - NodeType.RPCCall, - ] - }, - }, - 'required': ['resource_type', 'name'] - } -) - - -class UnparsedMacro(APIObject): - SCHEMA = UNPARSED_MACRO_CONTRACT - - -class UnparsedNode(APIObject): - SCHEMA = UNPARSED_NODE_CONTRACT - - -COLUMN_TEST_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'name': { - 'type': 'string', - 'description': 'The name of the column this test is for', - }, - 'description': { - 'type': 'string', - 'description': 'The description of this test', - }, - 'tests': { - 'type': 'array', - 'items': { - "anyOf": [ - # 'not_null', 'unique', ... 
- {'type': 'string'}, - # 'relationships: {...}', 'accepted_values: {...}' - {'type': 'object', 'additionalProperties': True} - ], - }, - 'description': 'The list of tests to perform', - }, - }, - 'required': ['name'], -} - - -UNPARSED_COLUMN_DESCRIPTION_CONTRACT = { - 'properties': { - 'columns': { - 'type': 'array', - 'items': COLUMN_TEST_CONTRACT, - }, - }, -} - - -UNPARSED_NODE_DESCRIPTION_CONTRACT = { - 'properties': { - 'name': { - 'type': 'string', - 'description': ( - 'The name of this node, which is the name of the model it' - 'refers to' - ), - 'minLength': 1, - }, - 'description': { - 'type': 'string', - 'description': ( - 'The raw string description of the node after parsing the yaml' - ), - }, - 'tests': { - 'type': 'array', - 'items': { - "anyOf": [ - {'type': 'string'}, - {'type': 'object', 'additionalProperties': True} - ], - }, - }, - }, - 'required': ['name'], -} - - -UNPARSED_NODE_UPDATE_CONTRACT = deep_merge( - UNPARSED_NODE_DESCRIPTION_CONTRACT, - UNPARSED_COLUMN_DESCRIPTION_CONTRACT, - { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'A collection of the unparsed node updates, as provided in the ' - '"models" section of schema.yml' - ), - } -) - - -class UnparsedNodeUpdate(APIObject): - """An unparsed node update is the blueprint for tests to be added and nodes - to be updated, referencing a certain node (specifically, a Model or - Source). - """ - SCHEMA = UNPARSED_NODE_UPDATE_CONTRACT - - -TIME_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'count': { - 'type': 'integer', - }, - 'period': { - 'enum': ['minute', 'hour', 'day'], - }, - }, - 'required': ['count', 'period'], -} - - -_FRESHNESS_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'warn_after': { - 'anyOf': [ - {'type': 'null'}, - TIME_CONTRACT, - ] - }, - 'error_after': { - 'anyOf': [ - {'type': 'null'}, - TIME_CONTRACT, - - ] - }, - }, -} - - -_QUOTING_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'database': {'type': 'boolean'}, - 'schema': {'type': 'boolean'}, - 'identifier': {'type': 'boolean'}, - }, -} - - -QUOTING_CONTRACT = { - 'properties': { - 'quoting': { - 'anyOf': [ - {'type': 'null'}, - _QUOTING_CONTRACT, - ], - }, - }, -} - - -FRESHNESS_CONTRACT = { - 'properties': { - 'loaded_at_field': { - 'type': ['null', 'string'], - 'description': 'The field to use as the "loaded at" timestamp', - }, - 'freshness': { - 'anyOf': [ - {'type': 'null'}, - _FRESHNESS_CONTRACT, - ], - }, - }, -} - - -UNPARSED_SOURCE_TABLE_DEFINITION_CONTRACT = deep_merge( - UNPARSED_NODE_DESCRIPTION_CONTRACT, - UNPARSED_COLUMN_DESCRIPTION_CONTRACT, - FRESHNESS_CONTRACT, - QUOTING_CONTRACT, - { - 'description': ( - 'A source table definition, as provided in the "tables" ' - 'subsection of the "sources" section of schema.yml' - ), - 'properties': { - 'identifier': { - 'type': 'string', - 'description': 'The identifier for the source table', - 'minLength': 1, - }, - }, - } -) - - -UNPARSED_SOURCE_DEFINITION_CONTRACT = deep_merge( - FRESHNESS_CONTRACT, - QUOTING_CONTRACT, - { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'A collection of the unparsed sources, as provided in the ' - '"sources" section of schema.yml' - ), - 'properties': { - 'name': { - 'type': 'string', - 'description': 'The reference name of the source definition', - 'minLength': 1, - }, - 'loader': { - 'type': 'string', - 'description': 'The user-defined loader for this source', - 'minLength': 1, 
- }, - 'description': { - 'type': 'string', - 'description': 'The user-supplied description of the source', - }, - 'database': { - 'type': 'string', - 'description': 'The database name for the source table', - 'minLength': 1, - }, - 'schema': { - 'type': 'string', - 'description': 'The schema name for the source table', - 'minLength': 1, - }, - 'tables': { - 'type': 'array', - 'items': UNPARSED_SOURCE_TABLE_DEFINITION_CONTRACT, - 'description': 'The tables for this source', - 'minLength': 1, - }, - }, - 'required': ['name'], - } -) - - -class UnparsedTableDefinition(APIObject): - SCHEMA = UNPARSED_SOURCE_TABLE_DEFINITION_CONTRACT - - -class UnparsedSourceDefinition(APIObject): - SCHEMA = UNPARSED_SOURCE_DEFINITION_CONTRACT +from dbt.node_types import UnparsedNodeType, NodeType, OperationType, MacroType +from dbt.contracts.util import Replaceable, Mergeable + +from hologram import JsonSchemaMixin +from hologram.helpers import StrEnum + +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Optional, List, Union, Dict, Any + + +@dataclass +class UnparsedBaseNode(JsonSchemaMixin, Replaceable): + package_name: str + root_path: str + path: str + original_file_path: str + + +@dataclass +class HasSQL: + raw_sql: str + + @property + def empty(self): + return not self.raw_sql.strip() + + +@dataclass +class UnparsedMacro(UnparsedBaseNode, HasSQL): + resource_type: MacroType + + +@dataclass +class UnparsedNode(UnparsedBaseNode, HasSQL): + name: str + resource_type: UnparsedNodeType + + +@dataclass +class UnparsedRunHook(UnparsedNode): + resource_type: OperationType + index: Optional[int] = None + + +@dataclass +class NamedTested(JsonSchemaMixin, Replaceable): + name: str + description: str = '' + tests: Optional[List[Union[Dict[str, Any], str]]] = None + + def __post_init__(self): + if self.tests is None: + self.tests = [] + + +@dataclass +class ColumnDescription(JsonSchemaMixin, Replaceable): + columns: Optional[List[NamedTested]] = field(default_factory=list) + + def __post_init__(self): + if self.columns is None: + self.columns = [] + + +@dataclass +class NodeDescription(NamedTested): + pass + + +@dataclass +class UnparsedNodeUpdate(ColumnDescription, NodeDescription): + def __post_init__(self): + NodeDescription.__post_init__(self) + ColumnDescription.__post_init__(self) + + +class TimePeriod(StrEnum): + minute = 'minute' + hour = 'hour' + day = 'day' + + def plural(self) -> str: + return str(self) + 's' + + +@dataclass +class Time(JsonSchemaMixin, Replaceable): + count: int + period: TimePeriod + + def exceeded(self, actual_age: float) -> bool: + kwargs = {self.period.plural(): self.count} + difference = timedelta(**kwargs).total_seconds() + return actual_age > difference + + +class FreshnessStatus(StrEnum): + Pass = 'pass' + Warn = 'warn' + Error = 'error' + + +@dataclass +class FreshnessThreshold(JsonSchemaMixin, Mergeable): + warn_after: Optional[Time] = None + error_after: Optional[Time] = None + + def status(self, age: float) -> FreshnessStatus: + if self.error_after and self.error_after.exceeded(age): + return FreshnessStatus.Error + elif self.warn_after and self.warn_after.exceeded(age): + return FreshnessStatus.Warn + else: + return FreshnessStatus.Pass + + +@dataclass +class Quoting(JsonSchemaMixin, Mergeable): + database: Optional[bool] = None + schema: Optional[bool] = None + identifier: Optional[bool] = None + + +@dataclass +class UnparsedSourceTableDefinition(ColumnDescription, NodeDescription): + loaded_at_field: Optional[str] = None + 
identifier: Optional[str] = None + quoting: Quoting = field(default_factory=Quoting) + freshness: FreshnessThreshold = field(default_factory=FreshnessThreshold) + + def __post_init__(self): + NodeDescription.__post_init__(self) + ColumnDescription.__post_init__(self) + + +@dataclass +class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable): + name: str + description: str = '' + database: Optional[str] = None + schema: Optional[str] = None + loader: str = '' + quoting: Quoting = field(default_factory=Quoting) + freshness: FreshnessThreshold = field(default_factory=FreshnessThreshold) + loaded_at_field: Optional[str] = None + tables: List[UnparsedSourceTableDefinition] = field(default_factory=list) + + +@dataclass +class UnparsedDocumentationFile(JsonSchemaMixin, Replaceable): + package_name: str + root_path: str + path: str + original_file_path: str + file_contents: str @property - def tables(self): - return [UnparsedTableDefinition(**t) for t in self.get('tables', [])] - - -UNPARSED_DOCUMENTATION_FILE_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'package_name': { - 'type': 'string', - }, - # filesystem - 'root_path': { - 'type': 'string', - 'description': 'The absolute path to the project root', - }, - 'path': { - 'type': 'string', - 'description': ( - 'Relative path to the source file from the project root. ' - 'Usually the same as original_file_path, but in some cases ' - 'dbt will generate a path.'), - }, - 'original_file_path': { - 'type': 'string', - 'description': ( - 'Relative path to the originating file from the project root.' - ), - }, - 'file_contents': { - 'type': 'string', - 'description': ( - 'The raw text provided in the documentation block, presumably ' - 'markdown.' - ), - }, - # TODO: I would like to remove this, but some graph error handling - # cares about it. - 'resource_type': { - 'enum': [ - NodeType.Documentation, - ] - }, - }, - 'required': [ - 'package_name', 'root_path', 'path', 'original_file_path', - 'file_contents', 'resource_type' - ], -} - - -class UnparsedDocumentationFile(APIObject): - SCHEMA = UNPARSED_DOCUMENTATION_FILE_CONTRACT + def resource_type(self): + return NodeType.Documentation diff --git a/core/dbt/contracts/project.py b/core/dbt/contracts/project.py index ecc9c779e0b..cc9bda84bd1 100644 --- a/core/dbt/contracts/project.py +++ b/core/dbt/contracts/project.py @@ -1,366 +1,167 @@ -from dbt.api.object import APIObject +from dbt.contracts.util import Replaceable, Mergeable from dbt.logger import GLOBAL_LOGGER as logger # noqa -from dbt.utils import deep_merge - - -PROJECT_CONTRACT = { - 'type': 'object', - 'description': 'The project configuration.', - 'additionalProperties': False, - 'properties': { - 'name': { - 'type': 'string', - 'pattern': r'^[^\d\W]\w*\Z', - }, - 'version': { - 'anyOf': [ - { - 'type': 'string', - 'pattern': ( - # this does not support the full semver (does not - # allow a trailing -fooXYZ) and is not restrictive - # enough for full semver, (allows '1.0'). But it's like - # 'semver lite'. - r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$' - ), - }, - { - # the internal global_project/dbt_project.yml is actually - # 1.0. 
Heaven only knows how many users have done the same - 'type': 'number', - }, - ], - }, - 'project-root': { - 'type': 'string', - }, - 'source-paths': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'macro-paths': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'data-paths': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'test-paths': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'analysis-paths': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'docs-paths': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'target-path': { - 'type': 'string', - }, - 'snapshot-paths': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'clean-targets': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'profile': { - 'type': ['null', 'string'], - }, - 'log-path': { - 'type': 'string', - }, - 'modules-path': { - 'type': 'string', - }, - 'quoting': { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'identifier': { - 'type': 'boolean', - }, - 'schema': { - 'type': 'boolean', - }, - 'database': { - 'type': 'boolean', - }, - 'project': { - 'type': 'boolean', - } - }, - }, - 'models': { - 'type': 'object', - 'additionalProperties': True, - }, - 'on-run-start': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'on-run-end': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'seeds': { - 'type': 'object', - 'additionalProperties': True, - }, - # we validate the regex separately, using the pattern in dbt.semver - 'require-dbt-version': { - 'type': ['string', 'array'], - 'items': {'type': 'string'}, - }, - }, - 'required': ['name', 'version'], -} - - -class Project(APIObject): - SCHEMA = PROJECT_CONTRACT - - -LOCAL_PACKAGE_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'local': { - 'type': 'string', - 'description': 'The absolute path to the local package.', - }, - 'required': ['local'], - }, -} - - -GIT_PACKAGE_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'git': { - 'type': 'string', - 'description': ( - 'The URL to the git repository that stores the pacakge' - ), - }, - 'revision': { - 'type': ['string', 'array'], - 'items': {'type': 'string'}, - 'description': 'The git revision to use, if it is not tip', - }, - 'warn-unpinned': { - 'type': 'boolean', - } - }, - 'required': ['git'], -} - - -VERSION_SPECIFICATION_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'major': { - 'type': ['string', 'null'], - }, - 'minor': { - 'type': ['string', 'null'], - }, - 'patch': { - 'type': ['string', 'null'], - }, - 'prerelease': { - 'type': ['string', 'null'], - }, - 'build': { - 'type': ['string', 'null'], - }, - 'matcher': { - 'type': 'string', - 'enum': ['=', '>=', '<=', '>', '<'], - }, - }, - 'required': ['major', 'minor', 'patch', 'prerelease', 'build', 'matcher'], -} - - -REGISTRY_PACKAGE_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'package': { - 'type': 'string', - 'description': 'The name of the package', - }, - 'version': { - 'type': ['string', 'array'], - 'items': { - 'anyOf': [ - VERSION_SPECIFICATION_CONTRACT, - {'type': 'string'} - ], - }, - 'description': 'The version of the package', - }, - }, - 'required': ['package', 'version'], -} - - -PACKAGE_FILE_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'packages': { - 'type': 'array', - 'items': { - 'anyOf': [ - LOCAL_PACKAGE_CONTRACT, - 
GIT_PACKAGE_CONTRACT, - REGISTRY_PACKAGE_CONTRACT, - ], - }, - }, - }, - 'required': ['packages'], -} - - -# the metadata from the registry has extra things that we don't care about. -REGISTRY_PACKAGE_METADATA_CONTRACT = deep_merge( - PACKAGE_FILE_CONTRACT, - { - 'additionalProperties': True, - 'properties': { - 'name': { - 'type': 'string', - }, - 'downloads': { - 'type': 'object', - 'additionalProperties': True, - 'properties': { - 'tarball': { - 'type': 'string', - }, - }, - 'required': ['tarball'] - }, - }, - 'required': PACKAGE_FILE_CONTRACT['required'][:] + ['downloads'] - } -) +from dbt import tracking +from dbt.ui import printer +# from dbt.utils import JSONEncoder + +from hologram import JsonSchemaMixin +from hologram.helpers import HyphenatedJsonSchemaMixin, NewPatternType, \ + ExtensibleJsonSchemaMixin + +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Union, Any + +PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa +DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True +DEFAULT_USE_COLORS = True -class PackageConfig(APIObject): - SCHEMA = PACKAGE_FILE_CONTRACT - - -USER_CONFIG_CONTRACT = { - 'type': 'object', - 'additionalProperties': True, - 'properties': { - 'send_anonymous_usage_stats': { - 'type': 'boolean', - }, - 'use_colors': { - 'type': 'boolean', - }, - }, -} - - -PROFILE_INFO_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'profile_name': { - 'type': 'string', - }, - 'target_name': { - 'type': 'string', - }, - 'config': USER_CONFIG_CONTRACT, - 'threads': { - 'type': 'number', - }, - 'credentials': { - 'type': 'object', - 'additionalProperties': True, - }, - }, - 'required': [ - 'profile_name', 'target_name', 'config', 'threads', 'credentials' - ], -} - - -class ProfileConfig(APIObject): - SCHEMA = PROFILE_INFO_CONTRACT - - -def _merge_requirements(base, *args): - required = base[:] - for arg in args: - required.extend(arg['required']) - return required - - -CONFIG_CONTRACT = deep_merge( - PROJECT_CONTRACT, - PACKAGE_FILE_CONTRACT, - PROFILE_INFO_CONTRACT, - { - 'properties': { - 'cli_vars': { - 'type': 'object', - 'additionalProperties': True, - }, - # override quoting: both 'identifier' and 'schema' must be - # populated - 'quoting': { - 'required': ['identifier', 'schema'], - }, - }, - 'required': _merge_requirements( - ['cli_vars'], - PROJECT_CONTRACT, - PACKAGE_FILE_CONTRACT, - PROFILE_INFO_CONTRACT - ), - }, +Name = NewPatternType('Name', r'^[^\d\W]\w*\Z') + +# this does not support the full semver (does not allow a trailing -fooXYZ) and +# is not restrictive enough for full semver, (allows '1.0'). But it's like +# 'semver lite'. 
+SemverString = NewPatternType( + 'SemverString', + r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$', ) -def update_config_contract(typename, connection): - PROFILE_INFO_CONTRACT['properties']['credentials']['anyOf'].append( - connection.SCHEMA - ) - CONFIG_CONTRACT['properties']['credentials']['anyOf'].append( - connection.SCHEMA - ) +@dataclass +class Quoting(JsonSchemaMixin, Mergeable): + identifier: Optional[bool] + schema: Optional[bool] + database: Optional[bool] + project: Optional[bool] + + +@dataclass +class Package(Replaceable, HyphenatedJsonSchemaMixin): + pass + + +@dataclass +class LocalPackage(Package): + local: str + + +@dataclass +class GitPackage(Package): + git: str + revision: Optional[str] + warn_unpinned: Optional[bool] = None + + +@dataclass +class RegistryPackage(Package): + package: str + version: Union[str, List[str]] + +PackageSpec = Union[LocalPackage, GitPackage, RegistryPackage] -class Configuration(APIObject): - SCHEMA = CONFIG_CONTRACT +@dataclass +class PackageConfig(JsonSchemaMixin, Replaceable): + packages: List[PackageSpec] -PROJECTS_LIST_PROJECT = { - 'type': 'object', - 'additionalProperties': False, - 'patternProperties': { - '.*': CONFIG_CONTRACT, - }, -} +@dataclass +class ProjectPackageMetadata: + name: str + packages: List[PackageSpec] + + @classmethod + def from_project(cls, project): + return cls(name=project.project_name, + packages=project.packages.packages) + + +@dataclass +class Downloads(ExtensibleJsonSchemaMixin, Replaceable): + tarball: str + + +@dataclass +class RegistryPackageMetadata( + ExtensibleJsonSchemaMixin, + ProjectPackageMetadata, +): + downloads: Downloads + + +@dataclass +class Project(HyphenatedJsonSchemaMixin, Replaceable): + name: Name + version: Union[SemverString, float] + project_root: Optional[str] + source_paths: Optional[List[str]] + macro_paths: Optional[List[str]] + data_paths: Optional[List[str]] + test_paths: Optional[List[str]] + analysis_paths: Optional[List[str]] + docs_paths: Optional[List[str]] + target_path: Optional[str] + snapshot_paths: Optional[List[str]] + clean_targets: Optional[List[str]] + profile: Optional[str] + log_path: Optional[str] + modules_path: Optional[str] + quoting: Optional[Quoting] + on_run_start: Optional[List[str]] = field(default_factory=list) + on_run_end: Optional[List[str]] = field(default_factory=list) + require_dbt_version: Optional[Union[List[str], str]] = None + models: Dict[str, Any] = field(default_factory=dict) + seeds: Dict[str, Any] = field(default_factory=dict) + packages: List[PackageSpec] = field(default_factory=list) + + +@dataclass +class UserConfig(ExtensibleJsonSchemaMixin, Replaceable): + send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS + use_colors: bool = DEFAULT_USE_COLORS + printer_width: Optional[int] = None + + def set_values(self, cookie_dir): + if self.send_anonymous_usage_stats: + tracking.initialize_tracking(cookie_dir) + else: + tracking.do_not_track() + + if self.use_colors: + printer.use_colors() + + if self.printer_width: + printer.printer_width(self.printer_width) + + +@dataclass +class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable): + profile_name: str = field(metadata={'preserve_underscore': True}) + target_name: str = field(metadata={'preserve_underscore': True}) + config: UserConfig + threads: int + # TODO: make this a dynamic union of some kind? 
+ credentials: Optional[Any] + + +@dataclass +class ConfiguredQuoting(JsonSchemaMixin, Replaceable): + identifier: bool + schema: bool + database: Optional[bool] + project: Optional[bool] + + +@dataclass +class Configuration(Project, ProfileConfig): + cli_vars: Dict[str, Any] = field( + default_factory=dict, + metadata={'preserve_underscore': True}, + ) + quoting: Optional[ConfiguredQuoting] = None -class ProjectList(APIObject): - SCHEMA = PROJECTS_LIST_PROJECT - def serialize(self): - return {k: v.serialize() for k, v in self._contents.items()} +@dataclass +class ProjectList(JsonSchemaMixin): + projects: Dict[str, Project] diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py index 771806b379c..da98c019854 100644 --- a/core/dbt/contracts/results.py +++ b/core/dbt/contracts/results.py @@ -1,47 +1,34 @@ -from dbt.api.object import APIObject -from dbt.utils import deep_merge, timestring -from dbt.contracts.common import named_property -from dbt.contracts.graph.manifest import COMPILE_RESULT_NODE_CONTRACT -from dbt.contracts.graph.unparsed import TIME_CONTRACT -from dbt.contracts.graph.parsed import PARSED_SOURCE_DEFINITION_CONTRACT - - -TIMING_INFO_CONTRACT = { - 'type': 'object', - 'properties': { - 'name': { - 'type': 'string', - }, - 'started_at': { - 'type': 'string', - 'format': 'date-time', - }, - 'completed_at': { - 'type': 'string', - 'format': 'date-time', - }, - } -} - - -class TimingInfo(APIObject): - - SCHEMA = TIMING_INFO_CONTRACT - - @classmethod - def create(cls, name): - return cls(name=name) +from dbt.contracts.graph.manifest import CompileResultNode +from dbt.contracts.graph.unparsed import Time, FreshnessStatus +from dbt.contracts.graph.parsed import ParsedSourceDefinition +from dbt.contracts.util import Writable +from hologram.helpers import StrEnum +from hologram import JsonSchemaMixin + +import agate + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Union, Dict, List, Optional, Any +from numbers import Real + + +@dataclass +class TimingInfo(JsonSchemaMixin): + name: str + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None def begin(self): - self.set('started_at', timestring()) + self.started_at = datetime.utcnow() def end(self): - self.set('completed_at', timestring()) + self.completed_at = datetime.utcnow() class collect_timing_info: def __init__(self, name): - self.timing_info = TimingInfo.create(name) + self.timing_info = TimingInfo(name=name) def __enter__(self): self.timing_info.begin() @@ -51,65 +38,15 @@ def __exit__(self, exc_type, exc_value, traceback): self.timing_info.end() -class NodeSerializable(APIObject): - def serialize(self): - result = super().serialize() - result['node'] = self.node.serialize() - return result - - -PARTIAL_RESULT_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': 'The partial result of a single node being run', - 'properties': { - 'error': { - 'type': ['string', 'null'], - 'description': 'The error string, or None if there was no error', - }, - 'status': { - 'type': ['string', 'null', 'number', 'boolean'], - 'description': 'The status result of the node execution', - }, - 'execution_time': { - 'type': 'number', - 'description': 'The execution time, in seconds', - }, - 'thread_id': { - 'type': ['string', 'null'], - 'description': 'ID of the executing thread, e.g. 
Thread-3', - }, - 'timing': { - 'type': 'array', - 'items': TIMING_INFO_CONTRACT, - }, - 'node': COMPILE_RESULT_NODE_CONTRACT, - }, - 'required': ['node', 'status', 'error', 'execution_time', 'thread_id', - 'timing'], -} - - -class PartialResult(NodeSerializable): - """Represent a "partial" execution result, i.e. one that has not (fully) - been executed. - - This may be an ephemeral node (they are not compiled) or any error. - """ - SCHEMA = PARTIAL_RESULT_CONTRACT - - def __init__(self, node, error=None, status=None, execution_time=0, - thread_id=None, timing=None): - if timing is None: - timing = [] - super().__init__( - node=node, - error=error, - status=status, - execution_time=execution_time, - thread_id=thread_id, - timing=timing, - ) +@dataclass +class PartialResult(JsonSchemaMixin, Writable): + node: CompileResultNode + error: Optional[str] = None + status: Union[None, str, int, bool] = None + execution_time: Union[str, int] = 0 + thread_id: Optional[int] = 0 + timing: List[TimingInfo] = field(default_factory=list) + fail: Optional[bool] = None # if the result got to the point where it could be skipped/failed, we would # be returning a real result, not a partial. @@ -117,417 +54,147 @@ def __init__(self, node, error=None, status=None, execution_time=0, def skipped(self): return False - @property - def failed(self): - return None +@dataclass +class WritableRunModelResult(PartialResult): + skip: bool = False -RUN_MODEL_RESULT_CONTRACT = deep_merge(PARTIAL_RESULT_CONTRACT, { - 'description': 'The result of a single node being run', - 'properties': { - 'skip': { - 'type': 'boolean', - 'description': 'True if this node was skipped', - }, - # This is assigned by dbt.ui.printer.print_test_result_line, if a test - # has no error and a non-zero status - 'fail': { - 'type': ['boolean', 'null'], - 'description': 'On tests, true if the test failed', - }, - }, - 'required': ['skip', 'fail'] -}) - - -class RunModelResult(NodeSerializable): - SCHEMA = RUN_MODEL_RESULT_CONTRACT - - def __init__(self, node, error=None, skip=False, status=None, failed=None, - thread_id=None, timing=None, execution_time=0, - agate_table=None): - if timing is None: - timing = [] - self.agate_table = agate_table - super().__init__( - node=node, - error=error, - skip=skip, - status=status, - fail=failed, - execution_time=execution_time, - thread_id=thread_id, - timing=timing, - ) - # these all get set after the fact, generally - error = named_property('error', - 'If there was an error, the text of that error') - skip = named_property('skip', 'True if the model was skipped') - fail = named_property('fail', 'True if this was a test and it failed') - status = named_property('status', 'The status of the model execution') - execution_time = named_property('execution_time', - 'The time in seconds to execute the model') - thread_id = named_property( - 'thread_id', - 'ID of the executing thread, e.g. 
Thread-3' - ) - timing = named_property( - 'timing', - 'List of TimingInfo objects' - ) +@dataclass +class RunModelResult(WritableRunModelResult): + agate_table: Optional[agate.Table] = None - @property - def failed(self): - return self.fail + def to_dict(self, *args, **kwargs): + dct = super().to_dict(*args, **kwargs) + dct.pop('agate_table', None) + return dct - @property - def skipped(self): - return self.skip - - -EXECUTION_RESULT_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': 'The result of a single dbt invocation', - 'properties': { - 'results': { - 'type': 'array', - 'items': { - 'anyOf': [ - RUN_MODEL_RESULT_CONTRACT, - PARTIAL_RESULT_CONTRACT, - ] - }, - 'description': 'An array of results, one per model', - }, - 'generated_at': { - 'type': 'string', - 'format': 'date-time', - 'description': ( - 'The time at which the execution result was generated' - ), - }, - 'elapsed_time': { - 'type': 'number', - 'description': ( - 'The time elapsed from before_run to after_run (hooks are not ' - 'included)' - ), - } - }, - 'required': ['results', 'generated_at', 'elapsed_time'], -} - - -class ExecutionResult(APIObject): - SCHEMA = EXECUTION_RESULT_CONTRACT - - def serialize(self): - return { - 'results': [r.serialize() for r in self.results], - 'generated_at': self.generated_at, - 'elapsed_time': self.elapsed_time, - } - - -SOURCE_FRESHNESS_RESULT_CONTRACT = deep_merge(PARTIAL_RESULT_CONTRACT, { - 'properties': { - 'max_loaded_at': { - 'type': 'string', - 'format': 'date-time', - }, - 'snapshotted_at': { - 'type': 'string', - 'format': 'date-time', - }, - 'age': { - 'type': 'number', - }, - 'status': { - 'enum': ['pass', 'warn', 'error'] - }, - 'node': PARSED_SOURCE_DEFINITION_CONTRACT, - }, - 'required': ['max_loaded_at', 'snapshotted_at', 'age'] -}) - - -class SourceFreshnessResult(NodeSerializable): - SCHEMA = SOURCE_FRESHNESS_RESULT_CONTRACT - - def __init__(self, node, max_loaded_at, snapshotted_at, - age, status, thread_id, error=None, - timing=None, execution_time=0): - max_loaded_at = max_loaded_at.isoformat() - snapshotted_at = snapshotted_at.isoformat() - if timing is None: - timing = [] - super().__init__( - node=node, - max_loaded_at=max_loaded_at, - snapshotted_at=snapshotted_at, - age=age, - status=status, - thread_id=thread_id, - error=error, - timing=timing, - execution_time=execution_time - ) - @property - def failed(self): - return self.status == 'error' +@dataclass +class ExecutionResult(JsonSchemaMixin, Writable): + results: List[Union[WritableRunModelResult, PartialResult]] + generated_at: datetime + elapsed_time: Real + + +# due to issues with typing.Union collapsing subclasses, this can't subclass +# PartialResult +@dataclass +class SourceFreshnessResult(JsonSchemaMixin, Writable): + node: ParsedSourceDefinition + max_loaded_at: datetime + snapshotted_at: datetime + age: Real + status: FreshnessStatus + error: Optional[str] = None + status: Union[None, str, int, bool] = None + execution_time: Union[str, int] = 0 + thread_id: Optional[int] = 0 + timing: List[TimingInfo] = field(default_factory=list) + fail: Optional[bool] = None + + def __post_init__(self): + self.fail = self.status == 'error' @property def skipped(self): return False -FRESHNESS_METADATA_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'generated_at': { - 'type': 'string', - 'format': 'date-time', - 'description': ( - 'The time at which the execution result was generated' - ), - }, - 'elapsed_time': { - 'type': 'number', - 
'description': ( - 'The time elapsed from before_run to after_run (hooks ' - 'are not included)' - ), - }, - }, - 'required': ['generated_at', 'elapsed_time'] -} - - -FRESHNESS_RESULTS_CONTRACT = deep_merge(FRESHNESS_METADATA_CONTRACT, { - 'description': 'The result of a single dbt source freshness invocation', - 'properties': { - 'results': { - 'type': 'array', - 'items': { - 'anyOf': [ - PARTIAL_RESULT_CONTRACT, - SOURCE_FRESHNESS_RESULT_CONTRACT, - ], - }, - }, - }, - 'required': ['results'], -}) - - -class FreshnessExecutionResult(APIObject): - SCHEMA = FRESHNESS_RESULTS_CONTRACT - - def __init__(self, elapsed_time, generated_at, results): - super().__init__( - elapsed_time=elapsed_time, - generated_at=generated_at, - results=results - ) +@dataclass +class FreshnessMetadata(JsonSchemaMixin): + generated_at: datetime + elapsed_time: Real - def serialize(self): - return { - 'generated_at': self.generated_at, - 'elapsed_time': self.elapsed_time, - 'results': [s.serialize() for s in self.results] - } - def write(self, path): +@dataclass +class FreshnessExecutionResult(FreshnessMetadata): + results: List[Union[PartialResult, SourceFreshnessResult]] + + def write(self, path, omit_none=True): """Create a new object with the desired output schema and write it.""" - meta = { - 'generated_at': self.generated_at, - 'elapsed_time': self.elapsed_time, - } + meta = FreshnessMetadata( + generated_at=self.generated_at, + elapsed_time=self.elapsed_time, + ) sources = {} for result in self.results: unique_id = result.node.unique_id if result.error is not None: - result_dict = { - 'error': result.error, - 'state': 'runtime error' - } + result_value = SourceFreshnessRuntimeError( + error=result.error, + state=FreshnessErrorEnum.runtime_error, + ) else: - result_dict = { - 'max_loaded_at': result.max_loaded_at, - 'snapshotted_at': result.snapshotted_at, - 'max_loaded_at_time_ago_in_s': result.age, - 'state': result.status, - 'criteria': result.node.freshness, - } - sources[unique_id] = result_dict + result_value = SourceFreshnessOutput( + max_loaded_at=result.max_loaded_at, + snapshotted_at=result.snapshotted_at, + max_loaded_at_time_ago_in_s=result.age, + state=result.status, + criteria=result.node.freshness, + ) + sources[unique_id] = result_value output = FreshnessRunOutput(meta=meta, sources=sources) - output.write(path) + output.write(path, omit_none=omit_none) def _copykeys(src, keys, **updates): return {k: getattr(src, k) for k in keys} -SOURCE_FRESHNESS_OUTPUT_ERROR_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'The source freshness output for a single source table', - ), - 'properties': { - 'error': { - 'type': 'string', - 'description': 'The error string', - }, - 'state': { - 'enum': ['runtime error'], - }, - } -} - - -SOURCE_FRESHNESS_OUTPUT_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'The source freshness output for a single source table', - ), - 'properties': { - 'max_loaded_at': { - 'type': 'string', - 'format': 'date-time', - }, - 'snapshotted_at': { - 'type': 'string', - 'format': 'date-time', - }, - 'max_loaded_at_time_ago_in_s': { - 'type': 'number', - }, - 'state': { - 'enum': ['pass', 'warn', 'error'] - }, - 'criteria': { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'warn_after': TIME_CONTRACT, - 'error_after': TIME_CONTRACT, - }, - }, - 'required': ['state', 'criteria', 'max_loaded_at', 'snapshotted_at', - 'max_loaded_at_time_ago_in_s'] - } -} - - 
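The `state` values in the freshness output above are `FreshnessStatus` members, computed by the new `FreshnessThreshold`/`Time` pair in unparsed.py: each `Time` converts its count and period into seconds via `timedelta`, and the threshold checks `error_after` before `warn_after`. The sketch below approximates that logic with simplified stand-in classes (`Age`, `Threshold`), not dbt's own types.

    from dataclasses import dataclass
    from datetime import timedelta
    from typing import Optional


    @dataclass
    class Age:
        count: int
        period: str  # 'minute', 'hour', or 'day'

        def exceeded(self, actual_age_s: float) -> bool:
            # mirror Time.exceeded: pluralize the period for timedelta kwargs
            limit = timedelta(**{self.period + 's': self.count}).total_seconds()
            return actual_age_s > limit


    @dataclass
    class Threshold:
        warn_after: Optional[Age] = None
        error_after: Optional[Age] = None

        def status(self, age_s: float) -> str:
            # error wins over warn; anything under both limits passes
            if self.error_after and self.error_after.exceeded(age_s):
                return 'error'
            if self.warn_after and self.warn_after.exceeded(age_s):
                return 'warn'
            return 'pass'


    t = Threshold(warn_after=Age(1, 'hour'), error_after=Age(1, 'day'))
    assert t.status(30 * 60) == 'pass'      # 30 minutes old
    assert t.status(2 * 60 * 60) == 'warn'  # 2 hours old
    assert t.status(2 * 86400) == 'error'   # 2 days old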
-FRESHNESS_RUN_OUTPUT_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'description': 'The output contract for dbt source freshness invocations', - 'properties': { - 'meta': FRESHNESS_METADATA_CONTRACT, - 'sources': { - 'type': 'object', - 'additionalProperties': False, - 'description': ( - 'A collection of the source results, stored by their unique ' - 'IDs.' - ), - 'patternProperties': { - '.*': { - 'anyOf': [ - SOURCE_FRESHNESS_OUTPUT_ERROR_CONTRACT, - SOURCE_FRESHNESS_OUTPUT_CONTRACT - ], - }, - }, - } - } -} - - -class FreshnessRunOutput(APIObject): - SCHEMA = FRESHNESS_RUN_OUTPUT_CONTRACT - - def __init__(self, meta, sources): - super().__init__(meta=meta, sources=sources) - - -REMOTE_COMPILE_RESULT_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'raw_sql': { - 'type': 'string', - }, - 'compiled_sql': { - 'type': 'string', - }, - 'timing': { - 'type': 'array', - 'items': TIMING_INFO_CONTRACT, - }, - }, - 'required': ['raw_sql', 'compiled_sql', 'timing'] -} - - -class RemoteCompileResult(APIObject): - SCHEMA = REMOTE_COMPILE_RESULT_CONTRACT - - def __init__(self, raw_sql, compiled_sql, node, timing=None, **kwargs): - if timing is None: - timing = [] - # this should not show up in the serialized output. - self.node = node - super().__init__( - raw_sql=raw_sql, - compiled_sql=compiled_sql, - timing=timing, - **kwargs - ) +@dataclass +class FreshnessCriteria(JsonSchemaMixin): + warn_after: Time + error_after: Time + + +class FreshnessErrorEnum(StrEnum): + runtime_error = 'runtime error' + + +@dataclass +class SourceFreshnessRuntimeError(JsonSchemaMixin): + error: str + state: FreshnessErrorEnum + + +@dataclass +class SourceFreshnessOutput(JsonSchemaMixin): + max_loaded_at: datetime + snapshotted_at: datetime + max_loaded_at_time_ago_in_s: Real + state: FreshnessStatus + criteria: FreshnessCriteria + + +SourceFreshnessRunResult = Union[SourceFreshnessOutput, + SourceFreshnessRuntimeError] + + +@dataclass +class FreshnessRunOutput(JsonSchemaMixin, Writable): + meta: FreshnessMetadata + sources: Dict[str, SourceFreshnessRunResult] + + +@dataclass +class RemoteCompileResult(JsonSchemaMixin): + raw_sql: str + compiled_sql: str + node: CompileResultNode + timing: List[TimingInfo] @property def error(self): return None -REMOTE_RUN_RESULT_CONTRACT = deep_merge(REMOTE_COMPILE_RESULT_CONTRACT, { - 'properties': { - 'table': { - 'type': 'object', - 'properties': { - 'column_names': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'rows': { - 'type': 'array', - # any item type is ok - }, - }, - 'required': ['rows', 'column_names'], - }, - }, - 'required': ['table'], -}) +@dataclass +class ResultTable(JsonSchemaMixin): + column_names: List[str] + rows: List[Any] +@dataclass class RemoteRunResult(RemoteCompileResult): - SCHEMA = REMOTE_RUN_RESULT_CONTRACT - - def __init__(self, raw_sql, compiled_sql, node, timing=None, table=None): - if table is None: - table = [] - super().__init__( - raw_sql=raw_sql, - compiled_sql=compiled_sql, - timing=timing, - table=table, - node=node - ) + table: ResultTable diff --git a/core/dbt/contracts/util.py b/core/dbt/contracts/util.py new file mode 100644 index 00000000000..23d770ad204 --- /dev/null +++ b/core/dbt/contracts/util.py @@ -0,0 +1,29 @@ +from dbt.clients.system import write_json + +import dataclasses + + +class Replaceable: + def replace(self, **kwargs): + return dataclasses.replace(self, **kwargs) + + +class Mergeable(Replaceable): + def merged(self, *args): + """Perform a shallow merge, 
where the last non-None write wins. This is + intended to merge dataclasses that are a collection of optional values. + """ + replacements = {} + cls = type(self) + for field in dataclasses.fields(cls): + for arg in args: + value = getattr(arg, field.name) + if value is not None: + replacements[field.name] = value + + return self.replace(**replacements) + + +class Writable: + def write(self, path: str, omit_none: bool = False): + write_json(path, self.to_dict(omit_none=omit_none)) diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index a733e3a2d77..9d3ab0e9c7e 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -4,6 +4,18 @@ from dbt.logger import GLOBAL_LOGGER as logger import dbt.flags +import hologram + + +def validator_error_message(exc): + """Given a hologram.ValidationError (which is basically a + jsonschema.ValidationError), return the relevant parts as a string + """ + if not isinstance(exc, hologram.ValidationError): + return str(exc) + path = "[%s]" % "][".join(map(repr, exc.relative_path)) + return 'at path {}: {}'.format(path, exc.message) + class Exception(builtins.Exception): CODE = -32000 @@ -46,11 +58,16 @@ def type(self): def node_to_string(self, node): if node is None: return "" - + if not hasattr(node, 'name'): + # we probably failed to parse a block, so we can't know the name + return '{} ({})'.format( + node.resource_type, + node.original_file_path + ) return "{} {} ({})".format( - node.get('resource_type'), - node.get('name', 'unknown'), - node.get('original_file_path')) + node.resource_type, + node.name, + node.original_file_path) def process_stack(self): lines = [] @@ -98,8 +115,9 @@ def data(self): return result result.update({ - 'raw_sql': self.node.get('raw_sql'), - 'compiled_sql': self.node.get('injected_sql'), + 'raw_sql': self.node.raw_sql, + # the node isn't always compiled, but if it is, include that! 
+ 'compiled_sql': getattr(self.node, 'injected_sql', None), }) return result @@ -149,9 +167,8 @@ class DatabaseException(RuntimeException): def process_stack(self): lines = [] - if self.node is not None and self.node.get('build_path'): - lines.append( - "compiled SQL at {}".format(self.node.get('build_path'))) + if hasattr(self.node, 'build_path') and self.node.build_path: + lines.append("compiled SQL at {}".format(self.node.build_path)) return lines + RuntimeException.process_stack(self) @@ -350,7 +367,7 @@ def doc_target_not_found(model, target_doc_name, target_doc_package): msg = ( "Documentation for '{}' depends on doc '{}' {} which was not found" ).format( - model.get('unique_id'), + model.unique_id, target_doc_name, target_package_string ) @@ -365,11 +382,11 @@ def _get_target_failure_msg(model, target_model_name, target_model_package, source_path_string = '' if include_path: - source_path_string = ' ({})'.format(model.get('original_file_path')) + source_path_string = ' ({})'.format(model.original_file_path) return ("{} '{}'{} depends on model '{}' {}which {}" - .format(model.get('resource_type').title(), - model.get('unique_id'), + .format(model.resource_type.title(), + model.unique_id, source_path_string, target_model_name, target_package_string, @@ -403,9 +420,9 @@ def ref_target_not_found(model, target_model_name, target_model_package): def source_disabled_message(model, target_name, target_table_name): return ("{} '{}' ({}) depends on source '{}.{}' which was not found" - .format(model.get('resource_type').title(), - model.get('unique_id'), - model.get('original_file_path'), + .format(model.resource_type.title(), + model.unique_id, + model.original_file_path, target_name, target_table_name)) @@ -418,15 +435,15 @@ def source_target_not_found(model, target_name, target_table_name): def ref_disabled_dependency(model, target_model): raise_compiler_error( "Model '{}' depends on model '{}' which is disabled in " - "the project config".format(model.get('unique_id'), - target_model.get('unique_id')), + "the project config".format(model.unique_id, + target_model.unique_id), model) def dependency_not_found(model, target_model_name): raise_compiler_error( "'{}' depends on '{}' which is not in the graph!" - .format(model.get('unique_id'), target_model_name), + .format(model.unique_id, target_model_name), model) @@ -434,7 +451,7 @@ def macro_not_found(model, target_macro_id): raise_compiler_error( model, "'{}' references macro '{}' which is not defined!" - .format(model.get('unique_id'), target_macro_id)) + .format(model.unique_id, target_macro_id)) def materialization_not_available(model, adapter_type): @@ -475,7 +492,7 @@ def raise_cache_inconsistent(message): def missing_config(model, name): raise_compiler_error( "Model '{}' does not define a required config parameter '{}'." - .format(model.get('unique_id'), name), + .format(model.unique_id, name), model) @@ -559,7 +576,7 @@ def approximate_relation_match(target, relation): def raise_duplicate_resource_name(node_1, node_2): - duped_name = node_1['name'] + duped_name = node_1.name raise_compiler_error( 'dbt found two resources with the name "{}". 
Since these resources ' @@ -568,12 +585,12 @@ def raise_duplicate_resource_name(node_1, node_2): 'these resources:\n- {} ({})\n- {} ({})'.format( duped_name, duped_name, - node_1['unique_id'], node_1['original_file_path'], - node_2['unique_id'], node_2['original_file_path'])) + node_1.unique_id, node_1.original_file_path, + node_2.unique_id, node_2.original_file_path)) def raise_ambiguous_alias(node_1, node_2): - duped_name = "{}.{}".format(node_1['schema'], node_1['alias']) + duped_name = "{}.{}".format(node_1.schema, node_1.alias) raise_compiler_error( 'dbt found two resources with the database representation "{}".\ndbt ' @@ -581,8 +598,8 @@ def raise_ambiguous_alias(node_1, node_2): 'To fix this,\nchange the "schema" or "alias" configuration of one of ' 'these resources:\n- {} ({})\n- {} ({})'.format( duped_name, - node_1['unique_id'], node_1['original_file_path'], - node_2['unique_id'], node_2['original_file_path'])) + node_1.unique_id, node_1.original_file_path, + node_2.unique_id, node_2.original_file_path)) def raise_ambiguous_catalog_match(unique_id, match_1, match_2): diff --git a/core/dbt/graph/selector.py b/core/dbt/graph/selector.py index b93a0b3a9d6..6b556a034bf 100644 --- a/core/dbt/graph/selector.py +++ b/core/dbt/graph/selector.py @@ -341,7 +341,7 @@ def _is_graph_member(self, node_name): node = self.manifest.nodes[node_name] if node.resource_type == NodeType.Source: return True - return not node.get('empty') and is_enabled(node) + return not node.empty and is_enabled(node) def _is_match(self, node_name, resource_types, tags, required): node = self.manifest.nodes[node_name] @@ -376,9 +376,9 @@ def get_selected(self, include, exclude, resource_types, tags, required): return filtered_nodes def is_ephemeral_model(self, node): - is_model = node.get('resource_type') == NodeType.Model - is_ephemeral = get_materialization(node) == 'ephemeral' - return is_model and is_ephemeral + # if it's not a model, `get_materialization` is probably an error! 
+ return node.resource_type == NodeType.Model and \ + get_materialization(node) == 'ephemeral' def get_ancestor_ephemeral_nodes(self, selected_nodes): node_names = {} diff --git a/core/dbt/hooks.py b/core/dbt/hooks.py index d531155ae20..a5d47f01f67 100644 --- a/core/dbt/hooks.py +++ b/core/dbt/hooks.py @@ -1,42 +1,29 @@ -from enum import Enum +from hologram.helpers import StrEnum import json +from dbt.contracts.graph.parsed import Hook -class ModelHookType(str, Enum): +from typing import Union, Dict, Any + + +class ModelHookType(StrEnum): PreHook = 'pre-hook' PostHook = 'post-hook' - def __str__(self): - return self._value_ - -def _parse_hook_to_dict(hook_string): +def get_hook_dict(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + """From a source string-or-dict, get a dictionary that can be passed to + Hook.from_dict + """ + if isinstance(source, dict): + return source try: - hook_dict = json.loads(hook_string) + return json.loads(source) except ValueError: - hook_dict = {"sql": hook_string} - - if 'transaction' not in hook_dict: - hook_dict['transaction'] = True - - return hook_dict - - -def get_hook_dict(hook, index): - if isinstance(hook, dict): - hook_dict = hook - else: - hook_dict = _parse_hook_to_dict(hook) - - hook_dict['index'] = index - return hook_dict - - -def get_hooks(model, hook_key): - hooks = model.config.get(hook_key, []) + return {'sql': source} - if not isinstance(hooks, (list, tuple)): - hooks = [hooks] - wrapped = [get_hook_dict(hook, i) for i, hook in enumerate(hooks)] - return wrapped +def get_hook(source, index): + hook_dict = get_hook_dict(source) + hook_dict.setdefault('index', index) + return Hook.from_dict(hook_dict) diff --git a/core/dbt/linker.py b/core/dbt/linker.py index f29d7bfd1d7..1e5b8d2d6f3 100644 --- a/core/dbt/linker.py +++ b/core/dbt/linker.py @@ -259,6 +259,6 @@ def read_graph(self, infile): def _updated_graph(graph, manifest): graph = graph.copy() for node_id in graph.nodes(): - data = manifest.nodes[node_id].serialize() + data = manifest.nodes[node_id].to_dict() graph.add_node(node_id, **data) return graph diff --git a/core/dbt/loader.py b/core/dbt/loader.py index 5daaecb5a75..48488c2d02d 100644 --- a/core/dbt/loader.py +++ b/core/dbt/loader.py @@ -7,13 +7,12 @@ from dbt.node_types import NodeType from dbt.contracts.graph.manifest import Manifest -from dbt.utils import timestring from dbt.parser import MacroParser, ModelParser, SeedParser, AnalysisParser, \ DocumentationParser, DataTestParser, HookParser, SchemaParser, \ ParserUtils, SnapshotParser -from dbt.contracts.project import ProjectList +from datetime import datetime class GraphLoader: @@ -132,7 +131,8 @@ def load(self, internal_manifest=None): self._load_macros(internal_manifest=internal_manifest) # make a manifest with just the macros to get the context self.macro_manifest = Manifest(macros=self.macros, nodes={}, docs={}, - generated_at=timestring(), disabled=[]) + generated_at=datetime.utcnow(), + disabled=[]) self._load_nodes() self._load_docs() self._load_schema_tests() @@ -142,7 +142,7 @@ def create_manifest(self): nodes=self.nodes, macros=self.macros, docs=self.docs, - generated_at=timestring(), + generated_at=datetime.utcnow(), config=self.root_project, disabled=self.disabled ) @@ -156,9 +156,6 @@ def create_manifest(self): @classmethod def _load_from_projects(cls, root_config, projects, internal_manifest): - if dbt.flags.STRICT_MODE: - ProjectList(**projects) - loader = cls(root_config, projects) loader.load(internal_manifest=internal_manifest) return 
loader.create_manifest() diff --git a/core/dbt/main.py b/core/dbt/main.py index 4491f6530de..92bc7be25a7 100644 --- a/core/dbt/main.py +++ b/core/dbt/main.py @@ -32,7 +32,7 @@ import dbt.profiler from dbt.utils import ExitCodes -from dbt.config import UserConfig, PROFILES_DIR +from dbt.config import PROFILES_DIR, read_user_config from dbt.exceptions import RuntimeException @@ -120,11 +120,7 @@ def initialize_config_values(parsed): twice, but dbt's intialization is not structured in a way that makes that easy. """ - try: - cfg = UserConfig.from_directory(parsed.profiles_dir) - except RuntimeException: - cfg = UserConfig.from_dict(None) - + cfg = read_user_config(parsed.profiles_dir) cfg.set_values(parsed.profiles_dir) diff --git a/core/dbt/node_runners.py b/core/dbt/node_runners.py index 3ba097b2f60..e5edcf76070 100644 --- a/core/dbt/node_runners.py +++ b/core/dbt/node_runners.py @@ -2,8 +2,10 @@ from dbt.exceptions import NotImplementedException, CompilationException, \ RuntimeException, InternalException, missing_materialization from dbt.node_types import NodeType -from dbt.contracts.results import RunModelResult, collect_timing_info, \ - SourceFreshnessResult, PartialResult, RemoteCompileResult, RemoteRunResult +from dbt.contracts.results import ( + RunModelResult, collect_timing_info, SourceFreshnessResult, PartialResult, + RemoteCompileResult, RemoteRunResult, ResultTable, +) from dbt.compilation import compile_node import dbt.context.runtime @@ -16,7 +18,6 @@ import threading import time import traceback -from datetime import timedelta INTERNAL_ERROR_STRING = """This is an error in dbt. Please try again. If \ @@ -77,19 +78,18 @@ def run_with_hooks(self, manifest): return result def _build_run_result(self, node, start_time, error, status, timing_info, - skip=False, failed=None, agate_table=None): + skip=False, fail=None, agate_table=None): execution_time = time.time() - start_time thread_id = threading.current_thread().name - timing = [t.serialize() for t in timing_info] return RunModelResult( node=node, error=error, skip=skip, status=status, - failed=failed, + fail=fail, execution_time=execution_time, thread_id=thread_id, - timing=timing, + timing=timing_info, agate_table=agate_table, ) @@ -118,14 +118,14 @@ def from_run_result(self, result, start_time, timing_info): error=result.error, skip=result.skip, status=result.status, - failed=result.failed, + fail=result.fail, timing_info=timing_info, agate_table=result.agate_table, ) def compile_and_execute(self, manifest, ctx): result = None - self.adapter.acquire_connection(self.node.get('name')) + self.adapter.acquire_connection(self.node.name) with collect_timing_info('compile') as timing_info: # if we fail here, we still have a compiled node to return # this has the benefit of showing a build path for the errant @@ -147,6 +147,7 @@ def _handle_catchable_exception(self, e, ctx): if e.node is None: e.node = ctx.node + logger.debug(str(e), exc_info=True) return str(e) def _handle_internal_exception(self, e, ctx): @@ -158,11 +159,11 @@ def _handle_internal_exception(self, e, ctx): error=str(e).strip(), note=INTERNAL_ERROR_STRING ) - logger.debug(error) + logger.debug(error, exc_info=True) return str(e) def _handle_generic_exception(self, e, ctx): - node_description = self.node.get('build_path') + node_description = self.node.build_path if node_description is None: node_description = self.node.unique_id prefix = "Unhandled error while executing {}".format(node_description) @@ -372,33 +373,10 @@ def after_execute(self, result): 
self.node_index, self.num_nodes) - def _calculate_status(self, target_freshness, freshness): - """Calculate the status of a run. - - :param dict target_freshness: The target freshness dictionary. It must - match the freshness spec. - :param timedelta freshness: The actual freshness of the data, as - calculated from the database's timestamps - """ - # if freshness > warn_after > error_after, you'll get an error, not a - # warning - for key in ('error', 'warn'): - fullkey = '{}_after'.format(key) - if fullkey not in target_freshness: - continue - - target = target_freshness[fullkey] - kwname = target['period'] + 's' - kwargs = {kwname: target['count']} - if freshness > timedelta(**kwargs).total_seconds(): - return key - return 'pass' - def _build_run_result(self, node, start_time, error, status, timing_info, skip=False, failed=None): execution_time = time.time() - start_time thread_id = threading.current_thread().name - timing = [t.serialize() for t in timing_info] if status is not None: status = status.lower() return PartialResult( @@ -407,12 +385,12 @@ def _build_run_result(self, node, start_time, error, status, timing_info, error=error, execution_time=execution_time, thread_id=thread_id, - timing=timing + timing=timing_info, ) def from_run_result(self, result, start_time, timing_info): result.execution_time = (time.time() - start_time) - result.timing.extend(t.serialize() for t in timing_info) + result.timing.extend(timing_info) return result def execute(self, compiled_node, manifest): @@ -426,10 +404,7 @@ def execute(self, compiled_node, manifest): manifest=manifest ) - status = self._calculate_status( - compiled_node.freshness, - freshness['age'] - ) + status = compiled_node.freshness.status(freshness['age']) return SourceFreshnessResult( node=compiled_node, @@ -474,7 +449,7 @@ def execute_test(self, test): num_cols = len(table.columns) raise RuntimeError( "Bad test {name}: Returned {rows} rows and {cols} cols" - .format(name=test.get('name'), rows=num_rows, cols=num_cols)) + .format(name=test.name, rows=num_rows, cols=num_cols)) return table[0][0] @@ -529,6 +504,7 @@ def __init__(self, config, adapter, node, node_index, num_nodes): super().__init__(config, adapter, node, node_index, num_nodes) def handle_exception(self, e, ctx): + logger.debug('Got an exception: {}'.format(e), exc_info=True) if isinstance(e, dbt.exceptions.Exception): if isinstance(e, dbt.exceptions.RuntimeException): e.node = ctx.node @@ -552,7 +528,8 @@ def execute(self, compiled_node, manifest): return RemoteCompileResult( raw_sql=compiled_node.raw_sql, compiled_sql=compiled_node.injected_sql, - node=compiled_node + node=compiled_node, + timing=[], # this will get added later ) def error_result(self, node, error, start_time, timing_info): @@ -564,37 +541,37 @@ def ephemeral_result(self, node, start_time, timing_info): ) def from_run_result(self, result, start_time, timing_info): - timing = [t.serialize() for t in timing_info] return RemoteCompileResult( raw_sql=result.raw_sql, compiled_sql=result.compiled_sql, node=result.node, - timing=timing + timing=timing_info, ) class RPCExecuteRunner(RPCCompileRunner): def from_run_result(self, result, start_time, timing_info): - timing = [t.serialize() for t in timing_info] return RemoteRunResult( raw_sql=result.raw_sql, compiled_sql=result.compiled_sql, node=result.node, table=result.table, - timing=timing + timing=timing_info, ) def execute(self, compiled_node, manifest): status, table = self.adapter.execute(compiled_node.injected_sql, fetch=True) - table = { - 
'column_names': list(table.column_names), - 'rows': [list(row) for row in table] - } + + table = ResultTable( + column_names=list(table.column_names), + rows=[list(row) for row in table], + ) return RemoteRunResult( raw_sql=compiled_node.raw_sql, compiled_sql=compiled_node.injected_sql, node=compiled_node, - table=table + table=table, + timing=[], ) diff --git a/core/dbt/node_types.py b/core/dbt/node_types.py index 63aca1efd1d..ffbc1eb9af8 100644 --- a/core/dbt/node_types.py +++ b/core/dbt/node_types.py @@ -1,21 +1,17 @@ -from enum import Enum +from hologram.helpers import StrEnum -class NodeType(str, Enum): - Base = 'base' +class NodeType(StrEnum): Model = 'model' Analysis = 'analysis' Test = 'test' Snapshot = 'snapshot' - Macro = 'macro' Operation = 'operation' Seed = 'seed' + RPCCall = 'rpc' Documentation = 'docs' Source = 'source' - RPCCall = 'rpc' - - def __str__(self): - return self._value_ + Macro = 'macro' @classmethod def executable(cls): @@ -39,9 +35,40 @@ def refable(cls): ]] -class RunHookType(str, Enum): +class UnparsedNodeType(StrEnum): + Model = str(NodeType.Model) + Analysis = str(NodeType.Analysis) + Test = str(NodeType.Test) + Snapshot = str(NodeType.Snapshot) + Operation = str(NodeType.Operation) + Seed = str(NodeType.Seed) + RPCCall = str(NodeType.RPCCall) + + +class DocumentationType(StrEnum): + Documentation = str(NodeType.Documentation) + + +class RunHookType(StrEnum): Start = 'on-run-start' End = 'on-run-end' - def __str__(self): - return self._value_ + +class OperationType(StrEnum): + Operation = str(NodeType.Operation) + + +class SnapshotType(StrEnum): + Snapshot = str(NodeType.Snapshot) + + +class MacroType(StrEnum): + Macro = str(NodeType.Macro) + + +class SourceType(StrEnum): + Source = str(NodeType.Source) + + +class TestType(StrEnum): + Test = str(NodeType.Test) diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index b750565f17c..646288e426e 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -12,14 +12,24 @@ from dbt.utils import coalesce from dbt.logger import GLOBAL_LOGGER as logger from dbt.contracts.graph.parsed import ParsedNode +from dbt.contracts.project import ProjectList from dbt.parser.source_config import SourceConfig from dbt import deprecations +from dbt import hooks class BaseParser: - def __init__(self, root_project_config, all_projects): + def __init__(self, root_project_config, all_projects: ProjectList): self.root_project_config = root_project_config self.all_projects = all_projects + if dbt.flags.STRICT_MODE: + dct = { + 'projects': { + name: project.to_project_config(with_packages=True) + for name, project in all_projects.items() + } + } + ProjectList.from_dict(dct, validate=True) @property def default_schema(self): @@ -137,6 +147,15 @@ def get_alias(custom_alias_name, node): self._get_alias_func = get_alias return self._get_alias_func + def _mangle_hooks(self, config): + """Given a config dict that may have `pre-hook`/`post-hook` keys, + convert it from the yucky maybe-a-string, maybe-a-dict to a dict. + """ + # Like most of parsing, this is a horrible hack :( + for key in hooks.ModelHookType: + if key in config: + config[key] = [hooks.get_hook_dict(h) for h in config[key]] + def _build_intermediate_node_dict(self, config, node_dict, node_path, package_project_config, tags, fqn, snapshot_config, column_name): @@ -149,10 +168,7 @@ def _build_intermediate_node_dict(self, config, node_dict, node_path, # TODO: Restructure this? 
config_dict = coalesce(snapshot_config, {}) config_dict.update(config.config) - - empty = ( - 'raw_sql' in node_dict and len(node_dict['raw_sql'].strip()) == 0 - ) + self._mangle_hooks(config_dict) node_dict.update({ 'refs': [], @@ -162,7 +178,6 @@ def _build_intermediate_node_dict(self, config, node_dict, node_path, 'macros': [], }, 'unique_id': node_path, - 'empty': empty, 'fqn': fqn, 'tags': tags, 'config': config_dict, @@ -192,7 +207,7 @@ def _render_with_context(self, parsed_node, config): config) dbt.clients.jinja.get_rendered( - parsed_node.raw_sql, context, parsed_node.to_shallow_dict(), + parsed_node.raw_sql, context, parsed_node, capture_macros=True) def _update_parsed_node_info(self, parsed_node, config): @@ -230,13 +245,14 @@ def _update_parsed_node_info(self, parsed_node, config): parsed_node.tags.extend(model_tags) # Overwrite node config - config_dict = parsed_node.get('config', {}) + config_dict = parsed_node.config.to_dict() config_dict.update(config.config) - parsed_node.config = config_dict + # re-mangle hooks, in case we got new ones + self._mangle_hooks(config_dict) + parsed_node.config = parsed_node.config.from_dict(config_dict) - for hook_type in dbt.hooks.ModelHookType: - parsed_node.config[hook_type] = dbt.hooks.get_hooks(parsed_node, - hook_type) + def _parse_from_dict(self, parsed_dict): + return ParsedNode.from_dict(parsed_dict) def parse_node(self, node, node_path, package_project_config, tags=None, fqn_extra=None, fqn=None, snapshot_config=None, @@ -262,15 +278,15 @@ def parse_node(self, node, node_path, package_project_config, tags=None, node.resource_type) parsed_dict = self._build_intermediate_node_dict( - config, node.serialize(), node_path, config, tags, fqn, + config, node.to_dict(), node_path, config, tags, fqn, snapshot_config, column_name ) - parsed_node = ParsedNode(**parsed_dict) + parsed_node = self._parse_from_dict(parsed_dict) self._render_with_context(parsed_node, config) self._update_parsed_node_info(parsed_node, config) - parsed_node.validate() + parsed_node.to_dict(validate=True) return parsed_node diff --git a/core/dbt/parser/base_sql.py b/core/dbt/parser/base_sql.py index cb788c94ed6..bc6a6b504ca 100644 --- a/core/dbt/parser/base_sql.py +++ b/core/dbt/parser/base_sql.py @@ -2,20 +2,27 @@ import os import dbt.contracts.project -import dbt.exceptions import dbt.clients.system import dbt.utils import dbt.flags +from dbt.exceptions import ( + CompilationException, InternalException, NotImplementedException, + raise_duplicate_resource_name, validator_error_message +) from dbt.contracts.graph.unparsed import UnparsedNode from dbt.parser.base import MacrosKnownParser from dbt.node_types import NodeType +from hologram import ValidationError + class BaseSqlParser(MacrosKnownParser): + UnparsedNodeType = UnparsedNode + @classmethod def get_compiled_path(cls, name, relative_path): - raise dbt.exceptions.NotImplementedException("Not implemented") + raise NotImplementedException("Not implemented") def load_and_parse(self, package_name, root_dir, relative_dirs, resource_type, tags=None): @@ -27,9 +34,6 @@ def load_and_parse(self, package_name, root_dir, relative_dirs, if tags is None: tags = [] - if dbt.flags.STRICT_MODE: - dbt.contracts.project.ProjectList(**self.all_projects) - file_matches = dbt.clients.system.find_matching( root_dir, relative_dirs, @@ -67,7 +71,7 @@ def parse_sql_node(self, node_dict, tags=None): if tags is None: tags = [] - node = UnparsedNode(**node_dict) + node = self.UnparsedNodeType.from_dict(node_dict) package_name = 
node.package_name unique_id = self.get_path(node.resource_type, @@ -82,11 +86,17 @@ def parse_sql_node(self, node_dict, tags=None): node.name, node.original_file_path, node.raw_sql ) - node_parsed = self.parse_node(node, unique_id, project, tags=tags) + try: + node_parsed = self.parse_node(node, unique_id, project, tags=tags) + except ValidationError as exc: + # we got a ValidationError - probably bad types in config() + msg = validator_error_message(exc) + raise CompilationException(msg, node=node) from exc + if not parse_ok: # if we had a parse error in parse_node, we would not get here. So # this means we rejected a good file :( - raise dbt.exceptions.InternalException( + raise InternalException( 'the block parser rejected a good node: {} was marked invalid ' 'but is actually valid!'.format(node.original_file_path) ) @@ -102,7 +112,7 @@ def parse_sql_nodes(self, nodes, tags=None): node_path, node_parsed = self.parse_sql_node(n, tags) # Ignore disabled nodes - if not node_parsed.config['enabled']: + if not node_parsed.config.enabled: results.disable(node_parsed) continue @@ -127,7 +137,7 @@ def disable(self, node): def keep(self, unique_id, node): if unique_id in self.parsed: - dbt.exceptions.raise_duplicate_resource_name( + raise_duplicate_resource_name( self.parsed[unique_id], node ) diff --git a/core/dbt/parser/docs.py b/core/dbt/parser/docs.py index 8fb0d1f3c9d..18db1ab0754 100644 --- a/core/dbt/parser/docs.py +++ b/core/dbt/parser/docs.py @@ -1,5 +1,4 @@ import dbt.exceptions -from dbt.node_types import NodeType from dbt.parser.base import BaseParser from dbt.contracts.graph.unparsed import UnparsedDocumentationFile from dbt.contracts.graph.parsed import ParsedDocumentation @@ -35,7 +34,6 @@ def load_file(cls, package_name, root_dir, relative_dirs): yield UnparsedDocumentationFile( root_path=root_dir, - resource_type=NodeType.Documentation, path=path, original_file_path=original_file_path, package_name=package_name, @@ -77,14 +75,14 @@ def _parse_template_docs(self, template, docfile): unique_id = '{}.{}'.format(docfile.package_name, name) merged = dbt.utils.deep_merge( - docfile.serialize(), + docfile.to_dict(), { 'name': name, 'unique_id': unique_id, 'block_contents': item().strip(), } ) - yield ParsedDocumentation(**merged) + yield ParsedDocumentation.from_dict(merged) def load_and_parse(self, package_name, root_dir, relative_dirs): to_return = {} diff --git a/core/dbt/parser/hooks.py b/core/dbt/parser/hooks.py index 93763a94bba..4d5d38a0652 100644 --- a/core/dbt/parser/hooks.py +++ b/core/dbt/parser/hooks.py @@ -5,11 +5,14 @@ import dbt.contracts.project import dbt.utils +from dbt.contracts.graph.unparsed import UnparsedRunHook from dbt.parser.base_sql import BaseSqlParser from dbt.node_types import NodeType, RunHookType class HookParser(BaseSqlParser): + UnparsedNodeType = UnparsedRunHook + @classmethod def get_hooks_from_project(cls, config, hook_type): if hook_type == RunHookType.Start: @@ -55,14 +58,14 @@ def load_and_parse_run_hook_type(self, hook_type): 'index': i }) - tags = [hook_type] + # hook_type is a RunHookType member, which "is a string", but it's also + # an enum, so hologram gets mad about that before even looking at if + # it's a string - bypass it by explicitly calling str(). 
+ tags = [str(hook_type)] results = self.parse_sql_nodes(result, tags=tags) return results.parsed def load_and_parse(self): - if dbt.flags.STRICT_MODE: - dbt.contracts.project.ProjectList(**self.all_projects) - hook_nodes = {} for hook_type in RunHookType: project_hooks = self.load_and_parse_run_hook_type( diff --git a/core/dbt/parser/macros.py b/core/dbt/parser/macros.py index 3ad2978dee7..e3976ce0d67 100644 --- a/core/dbt/parser/macros.py +++ b/core/dbt/parser/macros.py @@ -28,13 +28,13 @@ def parse_macro_file(self, macro_file_path, macro_file_contents, root_path, if tags is None: tags = [] - # change these to actual kwargs base_node = UnparsedMacro( path=macro_file_path, original_file_path=macro_file_path, package_name=package_name, raw_sql=macro_file_contents, root_path=root_path, + resource_type=resource_type, ) try: @@ -57,16 +57,15 @@ def parse_macro_file(self, macro_file_path, macro_file_contents, root_path, unique_id = self.get_path(resource_type, package_name, name) merged = dbt.utils.deep_merge( - base_node.serialize(), + base_node.to_dict(), { 'name': name, 'unique_id': unique_id, 'tags': tags, - 'resource_type': resource_type, 'depends_on': {'macros': []}, }) - new_node = ParsedMacro(**merged) + new_node = ParsedMacro.from_dict(merged) to_return[unique_id] = new_node @@ -79,9 +78,6 @@ def load_and_parse(self, package_name, root_dir, relative_dirs, if tags is None: tags = [] - if dbt.flags.STRICT_MODE: - dbt.contracts.project.ProjectList(**self.all_projects) - file_matches = dbt.clients.system.find_matching( root_dir, relative_dirs, diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 8cbd07f04c3..cbd908138c0 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -4,6 +4,8 @@ import re import hashlib +from hologram import ValidationError + import dbt.exceptions import dbt.flags import dbt.utils @@ -12,6 +14,7 @@ import dbt.context.parser import dbt.contracts.project +from dbt.contracts.graph.parsed import ColumnInfo, Docref from dbt.context.common import generate_config_context from dbt.clients.jinja import get_rendered from dbt.node_types import NodeType @@ -19,9 +22,13 @@ from dbt.utils import get_pseudo_test_path from dbt.contracts.graph.unparsed import UnparsedNode, UnparsedNodeUpdate, \ UnparsedSourceDefinition -from dbt.contracts.graph.parsed import ParsedNodePatch, ParsedSourceDefinition +from dbt.contracts.graph.parsed import ParsedNodePatch, ParsedTestNode, \ + ParsedSourceDefinition from dbt.parser.base import MacrosKnownParser from dbt.config.renderer import ConfigRenderer +from dbt.exceptions import JSONValidationException, validator_error_message + +from typing import Dict, List def get_nice_schema_test_name(test_type, test_name, args): @@ -179,11 +186,11 @@ def build_raw_sql(self): class RefTestBuilder(TestBuilder): def build_model_str(self): - return "ref('{}')".format(self.target['name']) + return "ref('{}')".format(self.target.name) def get_test_name(self): return get_nice_schema_test_name(self.name, - self.target['name'], + self.target.name, self.args) def describe_test_target(self): @@ -193,13 +200,13 @@ def describe_test_target(self): class SourceTestBuilder(TestBuilder): def build_model_str(self): return "source('{}', '{}')".format( - self.target['source']['name'], - self.target['table']['name'] + self.target['source'].name, + self.target['table'].name ) def get_test_name(self): - target_name = '{}_{}'.format(self.target['source']['name'], - self.target['table']['name']) + target_name = 
'{}_{}'.format(self.target['source'].name, + self.target['table'].name) return get_nice_schema_test_name( 'source_' + self.name, target_name, @@ -228,10 +235,14 @@ def _filter_validate(filepath, location, values, validate): warn_invalid(filepath, location, value, '(expected a dict)') continue try: - yield validate(**value) - except dbt.exceptions.JSONValidationException as exc: + yield validate(value) # we don't want to fail the full run, but we do want to fail # parsing this file + except ValidationError as exc: + msg = validator_error_message(exc) + warn_invalid(filepath, location, value, '- ' + msg) + continue + except JSONValidationException as exc: warn_invalid(filepath, location, value, '- ' + exc.msg) continue @@ -239,14 +250,12 @@ def _filter_validate(filepath, location, values, validate): class ParserRef: """A helper object to hold parse-time references.""" def __init__(self): - self.column_info = {} - self.docrefs = [] + self.column_info: Dict[str, ColumnInfo] = {} + self.docrefs: List[Docref] = [] def add(self, column_name, description): - self.column_info[column_name] = { - 'name': column_name, - 'description': description, - } + self.column_info[column_name] = ColumnInfo(name=column_name, + description=description) class SchemaBaseTestParser(MacrosKnownParser): @@ -255,8 +264,8 @@ class SchemaBaseTestParser(MacrosKnownParser): def _parse_column(self, target, column, package_name, root_dir, path, refs): # this should yield ParsedNodes where resource_type == NodeType.Test - column_name = column['name'] - description = column.get('description', '') + column_name = column.name + description = column.description refs.add(column_name, description) context = { @@ -264,7 +273,7 @@ def _parse_column(self, target, column, package_name, root_dir, path, } get_rendered(description, context) - for test in column.get('tests', []): + for test in column.tests: try: yield self.build_test_node( target, package_name, test, root_dir, @@ -277,6 +286,9 @@ def _parse_column(self, target, column, package_name, root_dir, path, ) continue + def _parse_from_dict(self, parsed_dict): + return ParsedTestNode.from_dict(parsed_dict) + def build_test_node(self, test_target, package_name, test, root_dir, path, column_name=None): """Build a test node against the given target (a model or a source). @@ -320,7 +332,7 @@ def build_test_node(self, test_target, package_name, test, root_dir, path, # supply our own fqn which overrides the hashed version from the path # TODO: is this necessary even a little bit for tests? 
- fqn_override = self.get_fqn(unparsed.incorporate(path=full_path), + fqn_override = self.get_fqn(unparsed.replace(path=full_path), source_package) node_path = self.get_path(NodeType.Test, unparsed.package_name, @@ -348,18 +360,18 @@ def build_test_node(self, test_target, package_name, test, root_dir, path, class SchemaModelParser(SchemaBaseTestParser): Builder = RefTestBuilder - def parse_models_entry(self, model_dict, path, package_name, root_dir): - model_name = model_dict['name'] + def parse_models_entry(self, model, path, package_name, root_dir): + model_name = model.name refs = ParserRef() - for column in model_dict.get('columns', []): - column_tests = self._parse_column(model_dict, column, package_name, + for column in model.columns: + column_tests = self._parse_column(model, column, package_name, root_dir, path, refs) for node in column_tests: yield 'test', node - for test in model_dict.get('tests', []): + for test in model.tests: try: - node = self.build_test_node(model_dict, package_name, test, + node = self.build_test_node(model, package_name, test, root_dir, path) except dbt.exceptions.CompilationException as exc: dbt.exceptions.warn_or_error( @@ -369,8 +381,8 @@ def parse_models_entry(self, model_dict, path, package_name, root_dir): continue yield 'test', node - context = {'doc': dbt.context.parser.docs(model_dict, refs.docrefs)} - description = model_dict.get('description', '') + context = {'doc': dbt.context.parser.docs(model, refs.docrefs)} + description = model.description get_rendered(description, context) patch = ParsedNodePatch( @@ -391,7 +403,8 @@ def parse_all(self, models, path, package_name, root_dir): :param str package_name: The name of the current package :param str root_dir: The root directory of the search """ - filtered = _filter_validate(path, 'models', models, UnparsedNodeUpdate) + filtered = _filter_validate(path, 'models', models, + UnparsedNodeUpdate.from_dict) nodes = itertools.chain.from_iterable( self.parse_models_entry(model, path, package_name, root_dir) for model in filtered @@ -436,28 +449,22 @@ def generate_source_node(self, source, table, path, package_name, root_dir, source.name, table.name) context = {'doc': dbt.context.parser.docs(source, refs.docrefs)} - description = table.get('description', '') - source_description = source.get('description', '') + description = table.description or '' + source_description = source.description or '' get_rendered(description, context) get_rendered(source_description, context) - freshness = dbt.utils.deep_merge(source.get('freshness', {}), - table.get('freshness', {})) - - loaded_at_field = table.get('loaded_at_field', - source.get('loaded_at_field')) + loaded_at_field = table.loaded_at_field or source.loaded_at_field + freshness = source.freshness.merged(table.freshness) - # use 'or {}' to allow quoting: null - source_quoting = source.get('quoting') or {} - table_quoting = table.get('quoting') or {} - quoting = dbt.utils.deep_merge(source_quoting, table_quoting) + quoting = source.quoting.merged(table.quoting) default_database = self.root_project_config.credentials.database return ParsedSourceDefinition( package_name=package_name, - database=source.get('database', default_database), - schema=source.get('schema', source.name), - identifier=table.get('identifier', table.name), + database=(source.database or default_database), + schema=(source.schema or source.name), + identifier=(table.identifier or table.name), root_path=root_dir, path=path, original_file_path=path, @@ -467,7 +474,7 @@ def 
generate_source_node(self, source, table, path, package_name, root_dir, description=description, source_name=source.name, source_description=source_description, - loader=source.get('loader', ''), + loader=source.loader, docrefs=refs.docrefs, loaded_at_field=loaded_at_field, freshness=freshness, @@ -479,14 +486,14 @@ def generate_source_node(self, source, table, path, package_name, root_dir, def parse_source_table(self, source, table, path, package_name, root_dir): refs = ParserRef() test_target = {'source': source, 'table': table} - for column in table.get('columns', []): + for column in table.columns: column_tests = self._parse_column(test_target, column, package_name, root_dir, path, refs) for node in column_tests: yield 'test', node - for test in table.get('tests', []): + for test in table.tests: try: node = self.build_test_node(test_target, package_name, test, root_dir, path) @@ -510,9 +517,9 @@ def parse_source_entry(self, source, path, package_name, root_dir): for node_type, node in nodes: yield node_type, node - def _sources_validate(self, **kwargs): + def _sources_validate(self, kwargs): kwargs = self._renderer.render_schema_source(kwargs) - return UnparsedSourceDefinition(**kwargs) + return UnparsedSourceDefinition.from_dict(kwargs) def parse_all(self, sources, path, package_name, root_dir): """Parse all the model dictionaries in sources. @@ -605,8 +612,6 @@ def _parse_format_version(self, path, test_yml): return version def load_and_parse(self, package_name, root_dir, relative_dirs): - if dbt.flags.STRICT_MODE: - dbt.contracts.project.ProjectList(**self.all_projects) new_tests = {} # test unique ID -> ParsedNode node_patches = {} # model name -> dict new_sources = {} # source unique ID -> ParsedSourceDefinition diff --git a/core/dbt/parser/seeds.py b/core/dbt/parser/seeds.py index be6f04e6355..f53f39cfa0f 100644 --- a/core/dbt/parser/seeds.py +++ b/core/dbt/parser/seeds.py @@ -41,8 +41,6 @@ def load_and_parse(self, package_name, root_dir, relative_dirs, tags=None): that maps unique ids onto ParsedNodes""" extension = "[!.#~]*.csv" - if dbt.flags.STRICT_MODE: - dbt.contracts.project.ProjectList(**self.all_projects) file_matches = dbt.clients.system.find_matching( root_dir, diff --git a/core/dbt/parser/snapshots.py b/core/dbt/parser/snapshots.py index bbe340bfd19..a80d63a1208 100644 --- a/core/dbt/parser/snapshots.py +++ b/core/dbt/parser/snapshots.py @@ -1,30 +1,20 @@ -from dbt.contracts.graph.parsed import ParsedSnapshotNode +from dbt.contracts.graph.parsed import ParsedSnapshotNode, \ + IntermediateSnapshotNode +from dbt.exceptions import CompilationException, validator_error_message from dbt.node_types import NodeType from dbt.parser.base_sql import BaseSqlParser, SQLParseResult import dbt.clients.jinja -import dbt.exceptions import dbt.utils +from hologram import ValidationError + def set_snapshot_attributes(node): - # Default the target database to the database specified in the target - # This line allows target_database to be optional in the snapshot config - if 'target_database' not in node.config: - node.config['target_database'] = node.database - - # Set the standard node configs (database+schema) to be the specified - # values from target_database and target_schema. 
This ensures that the
-    # database and schema names are interopolated correctly when snapshots
-    # are ref'd from other models
-    config_keys = {
-        'target_database': 'database',
-        'target_schema': 'schema'
-    }
-
-    for config_key, node_key in config_keys.items():
-        if config_key in node.config:
-            setattr(node, node_key, node.config[config_key])
+    if node.config.target_database:
+        node.database = node.config.target_database
+    if node.config.target_schema:
+        node.schema = node.config.target_schema
 
     return node
 
@@ -39,7 +29,7 @@ def parse_snapshots_from_file(self, file_node, tags=None):
                 allowed_blocks={'snapshot'},
                 collect_raw_data=False
             )
-        except dbt.exceptions.CompilationException as exc:
+        except CompilationException as exc:
             if exc.node is None:
                 exc.node = file_node
             raise
@@ -66,15 +56,18 @@ def get_fqn(cls, node, package_project_config, extra=[]):
 
         return fqn
 
+    def _parse_from_dict(self, parsed_dict):
+        return IntermediateSnapshotNode.from_dict(parsed_dict)
+
     @staticmethod
     def validate_snapshots(node):
         if node.resource_type == NodeType.Snapshot:
             try:
-                parsed_node = ParsedSnapshotNode(**node.to_shallow_dict())
+                parsed_node = ParsedSnapshotNode.from_dict(node.to_dict())
                 return set_snapshot_attributes(parsed_node)
 
-            except dbt.exceptions.JSONValidationException as exc:
-                raise dbt.exceptions.CompilationException(str(exc), node)
+            except ValidationError as exc:
+                raise CompilationException(validator_error_message(exc), node)
         else:
             return node
 
@@ -90,8 +83,8 @@ def parse_sql_nodes(self, nodes, tags=None):
                 self.parse_snapshots_from_file(file_node, tags=tags)
             )
         found = super().parse_sql_nodes(nodes=snapshot_nodes, tags=tags)
-        # make sure our blocks are going to work when we try to snapshot
-        # them!
+        # Our snapshots are all stored as IntermediateSnapshotNodes, so
+        # convert them to their final form
         found.parsed = {k: self.validate_snapshots(v) for k, v in
                         found.parsed.items()}
 
diff --git a/core/dbt/parser/util.py b/core/dbt/parser/util.py
index afe16807eef..5d93cf1e1d8 100644
--- a/core/dbt/parser/util.py
+++ b/core/dbt/parser/util.py
@@ -2,6 +2,7 @@
 import dbt.exceptions
 import dbt.utils
 from dbt.node_types import NodeType
+from dbt.contracts.graph.parsed import ColumnInfo
 
 
 def docs(node, manifest, config, column_name=None):
@@ -112,54 +113,49 @@ def _get_node_column(cls, node, column_name):
         reference to the dict that refers to the given column, creating it if
         it doesn't yet exist.
         """
-        if not hasattr(node, 'columns'):
-            node.set('columns', {})
+        # if not hasattr(node, 'columns'):
+        #     node.set('columns', {})
 
         if column_name in node.columns:
             column = node.columns[column_name]
         else:
-            column = {'name': column_name, 'description': ''}
+            column = ColumnInfo(name=column_name)
             node.columns[column_name] = column
 
         return column
 
     @classmethod
     def process_docs_for_node(cls, manifest, current_project, node):
-        for docref in node.get('docrefs', []):
-            column_name = docref.get('column_name')
+        for docref in node.docrefs:
+            column_name = docref.column_name
+
             if column_name is None:
-                description = node.get('description', '')
+                obj = node
             else:
-                column = cls._get_node_column(node, column_name)
-                description = column.get('description', '')
+                obj = cls._get_node_column(node, column_name)
+
             context = {
                 'doc': docs(node, manifest, current_project, column_name),
             }
-            # At this point, target_doc is a ParsedDocumentation, and we
-            # know that our documentation string has a 'docs("...")'
-            # pointing at it. We want to render it.
- description = dbt.clients.jinja.get_rendered(description, - context) - # now put it back. - if column_name is None: - node.set('description', description) - else: - column['description'] = description + raw = obj.description or '' + # At this point, we know that our documentation string has a + # 'docs("...")' pointing at it. We want to render it. + obj.description = dbt.clients.jinja.get_rendered(raw, context) @classmethod def process_docs_for_source(cls, manifest, current_project, source): context = { 'doc': docs(source, manifest, current_project), } - table_description = source.get('description', '') - source_description = source.get('source_description', '') + table_description = source.description + source_description = source.source_description table_description = dbt.clients.jinja.get_rendered(table_description, context) source_description = dbt.clients.jinja.get_rendered(source_description, context) - source.set('description', table_description) - source.set('source_description', source_description) + source.description = table_description + source.source_description = source_description @classmethod def process_docs(cls, manifest, current_project): @@ -188,12 +184,12 @@ def process_refs_for_node(cls, manifest, current_project, node): target_model_name, target_model_package, current_project, - node.get('package_name')) + node.package_name) if target_model is None or target_model is cls.DISABLED: # This may raise. Even if it doesn't, we don't want to add # this node to the graph b/c there is no destination node - node.config['enabled'] = False + node.config.enabled = False dbt.utils.invalid_ref_fail_unless_test( node, target_model_name, target_model_package, disabled=(target_model is cls.DISABLED) @@ -201,10 +197,10 @@ def process_refs_for_node(cls, manifest, current_project, node): continue - target_model_id = target_model.get('unique_id') + target_model_id = target_model.unique_id - node.depends_on['nodes'].append(target_model_id) - manifest.nodes[node['unique_id']] = node + node.depends_on.nodes.append(target_model_id) + manifest.nodes[node.unique_id] = node @classmethod def process_refs(cls, manifest, current_project): @@ -221,19 +217,19 @@ def process_sources_for_node(cls, manifest, current_project, node): source_name, table_name, current_project, - node.get('package_name')) + node.package_name) if target_source is None: # this folows the same pattern as refs - node.config['enabled'] = False + node.config.enabled = False dbt.utils.invalid_source_fail_unless_test( node, source_name, table_name) continue target_source_id = target_source.unique_id - node.depends_on['nodes'].append(target_source_id) - manifest.nodes[node['unique_id']] = node + node.depends_on.nodes.append(target_source_id) + manifest.nodes[node.unique_id] = node @classmethod def process_sources(cls, manifest, current_project): diff --git a/core/dbt/semver.py b/core/dbt/semver.py index bb7b9afba71..03e9bb162b5 100644 --- a/core/dbt/semver.py +++ b/core/dbt/semver.py @@ -1,12 +1,33 @@ -from enum import Enum +from dataclasses import dataclass import re import logging -from dbt.api.object import APIObject -from dbt.contracts.project import VERSION_SPECIFICATION_CONTRACT from dbt.exceptions import VersionsNotCompatibleException import dbt.utils +from hologram import JsonSchemaMixin +from hologram.helpers import StrEnum +from typing import Optional + + +class Matchers(StrEnum): + GREATER_THAN = '>' + GREATER_THAN_OR_EQUAL = '>=' + LESS_THAN = '<' + LESS_THAN_OR_EQUAL = '<=' + EXACT = '=' + + +@dataclass +class 
VersionSpecification(JsonSchemaMixin):
+    major: Optional[str]
+    minor: Optional[str]
+    patch: Optional[str]
+    prerelease: Optional[str]
+    build: Optional[str]
+    matcher: Matchers = Matchers.EXACT
+
+
 logger = logging.getLogger(__name__)
 
 _MATCHERS = r"(?P<matcher>\>=|\>|\<|\<=|=)?"
@@ -46,17 +67,6 @@
 _VERSION_REGEX = re.compile(_VERSION_REGEX_PAT_STR, re.VERBOSE)
 
 
-class Matchers(str, Enum):
-    GREATER_THAN = '>'
-    GREATER_THAN_OR_EQUAL = '>='
-    LESS_THAN = '<'
-    LESS_THAN_OR_EQUAL = '<='
-    EXACT = '='
-
-    def __str__(self):
-        return self._value_
-
-
 class VersionRange(dbt.utils.AttrDict):
 
     def _try_combine_exact(self, a, b):
@@ -166,15 +176,8 @@ def to_version_string_pair(self):
         return to_return
 
 
-class VersionSpecifier(APIObject):
-    SCHEMA = VERSION_SPECIFICATION_CONTRACT
-
-    def __init__(self, *args, **kwargs):
-        kwargs = dict(*args, **kwargs)
-        if kwargs.get('matcher') is None:
-            kwargs['matcher'] = Matchers.EXACT
-        super().__init__(**kwargs)
-
+@dataclass
+class VersionSpecifier(VersionSpecification):
     def to_version_string(self, skip_matcher=False):
         prerelease = ''
         build = ''
@@ -204,7 +207,9 @@ def from_version_string(cls, version_string):
             raise dbt.exceptions.SemverException(
                 'Could not parse version "{}"'.format(version_string))
 
-        return VersionSpecifier(match.groupdict())
+        matched = {k: v for k, v in match.groupdict().items() if v is not None}
+
+        return cls.from_dict(matched)
 
     def __str__(self):
         return self.to_version_string()
@@ -234,7 +239,7 @@ def compare(self, other):
             return 0
 
         for key in ['major', 'minor', 'patch']:
-            comparison = int(self[key]) - int(other[key])
+            comparison = int(getattr(self, key)) - int(getattr(other, key))
 
             if comparison > 0:
                 return 1
@@ -300,7 +305,7 @@ def is_exact(self):
 class UnboundedVersionSpecifier(VersionSpecifier):
     def __init__(self, *args, **kwargs):
         super().__init__(
-            matcher='=',
+            matcher=Matchers.EXACT,
             major=None,
             minor=None,
             patch=None,
diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py
index 7f4f634fb67..42ff1ce9a41 100644
--- a/core/dbt/task/compile.py
+++ b/core/dbt/task/compile.py
@@ -124,6 +124,7 @@ def _in_thread(self, node, thread_done):
         try:
             self.node_results.append(runner.safe_run(self.manifest))
         except Exception as exc:
+            rpc_logger.debug('Got exception {}'.format(exc), exc_info=True)
             self._raise_next_tick = exc
         finally:
             thread_done.set()
@@ -160,4 +161,4 @@ def handle_request(self, name, sql, macros=None):
             raise dbt.exceptions.RPCKilledException(signal.SIGINT)
 
         self._raise_set_error()
-        return self.node_results[0].serialize()
+        return self.node_results[0].to_dict()
diff --git a/core/dbt/task/deps.py b/core/dbt/task/deps.py
index 42eb51b5fb7..e9042759d61 100644
--- a/core/dbt/task/deps.py
+++ b/core/dbt/task/deps.py
@@ -1,24 +1,26 @@
+import abc
+import hashlib
 import os
 import shutil
-import hashlib
 import tempfile
-import yaml
+from dataclasses import dataclass, field
+from typing import Union, Dict, Optional, List
 
 import dbt.utils
 import dbt.deprecations
 import dbt.exceptions
-import dbt.clients.git
-import dbt.clients.system
-import dbt.clients.registry as registry
+from dbt import semver
+from dbt.ui import printer
 from dbt.logger import GLOBAL_LOGGER as logger
-from dbt.semver import VersionSpecifier, UnboundedVersionSpecifier
-from dbt.ui import printer
-from dbt.utils import AttrDict
-from dbt.api.object import APIObject
-from dbt.contracts.project import LOCAL_PACKAGE_CONTRACT, \
-    GIT_PACKAGE_CONTRACT, REGISTRY_PACKAGE_CONTRACT, \
-    REGISTRY_PACKAGE_METADATA_CONTRACT, PackageConfig
+from dbt.clients import git, 
registry, system +from dbt.contracts.project import ProjectPackageMetadata, \ + RegistryPackageMetadata, \ + LocalPackage as LocalPackageContract, \ + GitPackage as GitPackageContract, \ + RegistryPackage as RegistryPackageContract +from dbt.exceptions import raise_dependency_error, package_version_not_found, \ + VersionsNotCompatibleException, DependencyException from dbt.task.base import ProjectOnlyTask @@ -40,63 +42,77 @@ def _initialize_downloads(): DOWNLOADS_PATH = tempfile.mkdtemp(prefix='dbt-downloads-') REMOVE_DOWNLOADS = True - dbt.clients.system.make_directory(DOWNLOADS_PATH) + system.make_directory(DOWNLOADS_PATH) logger.debug("Set downloads directory='{}'".format(DOWNLOADS_PATH)) -class Package(APIObject): - SCHEMA = NotImplemented +PackageContract = Union[LocalPackageContract, GitPackageContract, + RegistryPackageContract] - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._cached_metadata = None - @property - def name(self): - raise NotImplementedError +def _parse_package(dict_: dict) -> PackageContract: + only_1_keys = ['package', 'git', 'local'] + specified = [k for k in only_1_keys if dict_.get(k)] + if len(specified) > 1: + dbt.exceptions.raise_dependency_error( + 'Packages should not contain more than one of {}; ' + 'yours has {} of them - {}' + .format(only_1_keys, len(specified), specified)) + if dict_.get('package'): + return RegistryPackageContract.from_dict(dict_) + if dict_.get('git'): + if dict_.get('version'): + msg = ("Keyword 'version' specified for git package {}.\nDid " + "you mean 'revision'?".format(dict_.get('git'))) + dbt.exceptions.raise_dependency_error(msg) + return GitPackageContract.from_dict(dict_) + if dict_.get('local'): + return LocalPackageContract.from_dict(dict_) + dbt.exceptions.raise_dependency_error( + 'Malformed package definition. 
Must contain package, git, or local.') + + +def md5sum(s: str): + return hashlib.md5(s.encode('latin-1')).hexdigest() + + +@dataclass +class Pinned(metaclass=abc.ABCMeta): + _cached_metadata: Optional[ProjectPackageMetadata] = field(init=False) + + def __post_init__(self): + self._cached_metadata = None def __str__(self): - version = getattr(self, 'version', None) + version = self.get_version() if not version: return self.name - version_str = version[0] \ - if len(version) == 1 else '' - return '{}@{}'.format(self.name, version_str) - @classmethod - def version_to_list(cls, version): - if version is None: - return [] - if not isinstance(version, (list, str)): - dbt.exceptions.raise_dependency_error( - 'version must be list or string, got {}' - .format(type(version))) - if not isinstance(version, list): - version = [version] - return version + return '{}@{}'.format(self.name, version) - def _resolve_version(self): - pass - - def resolve_version(self): - try: - self._resolve_version() - except dbt.exceptions.VersionsNotCompatibleException as e: - new_msg = ('Version error for package {}: {}' - .format(self.name, e)) - raise dbt.exceptions.DependencyException(new_msg) from e + @abc.abstractproperty + def name(self): + raise NotImplementedError + @abc.abstractmethod def source_type(self): - raise NotImplementedError() - - def version_name(self): - raise NotImplementedError() + raise NotImplementedError - def nice_version_name(self): - raise NotImplementedError() + @abc.abstractmethod + def get_version(self) -> Optional[str]: + raise NotImplementedError + @abc.abstractmethod def _fetch_metadata(self, project): - raise NotImplementedError() + raise NotImplementedError + + @abc.abstractmethod + def install(self, project): + raise NotImplementedError + + @abc.abstractmethod + def nice_version_name(self): + raise NotImplementedError def fetch_metadata(self, project): if not self._cached_metadata: @@ -112,189 +128,89 @@ def get_installation_path(self, project): return os.path.join(project.modules_path, dest_dirname) -class RegistryPackage(Package): - SCHEMA = REGISTRY_PACKAGE_CONTRACT - - def __init__(self, *args, **kwargs): - if 'version' not in kwargs: - dbt.exceptions.raise_dependency_error( - 'package dependency {} is missing a "version" field' - .format(kwargs.get('package')) - ) - super().__init__(*args, **kwargs) - self._version = self._sanitize_version(self._contents['version']) +@dataclass +class LocalPinned(Pinned): + local: str @property def name(self): - return self.package - - @classmethod - def _sanitize_version(cls, version): - version = [v if isinstance(v, VersionSpecifier) - else VersionSpecifier.from_version_string(v) - for v in cls.version_to_list(version)] - return version or [UnboundedVersionSpecifier()] + return self.local def source_type(self): - return 'hub' - - @property - def version(self): - return self._version - - @version.setter - def version(self, version): - self._version = self._sanitize_version(version) + return 'local' - def version_name(self): - self._check_version_pinned() - version_string = self.version[0].to_version_string(skip_matcher=True) - return version_string + def get_version(self): + return None def nice_version_name(self): - return "version {}".format(self.version_name()) + return ''.format(self.local) - def incorporate(self, other): - return RegistryPackage( - package=self.package, - version=[x.to_version_string() for x in - self.version + other.version] + def resolve_path(self, project): + return system.resolve_path_from_base( + self.local, + 
project.project_root, ) - def _check_in_index(self): - index = registry.index_cached() - if self.package not in index: - dbt.exceptions.package_not_found(self.package) - - def _resolve_version(self): - self._check_in_index() - range_ = dbt.semver.reduce_versions(*self.version) - available = registry.get_available_versions(self.package) - # for now, pick a version and then recurse. later on, - # we'll probably want to traverse multiple options - # so we can match packages. not going to make a difference - # right now. - target = dbt.semver.resolve_to_specific_version(range_, available) - if not target: - dbt.exceptions.package_version_not_found( - self.package, range_, available) - self.version = target - - def _check_version_pinned(self): - if len(self.version) != 1: - dbt.exceptions.raise_dependency_error( - 'Cannot fetch metadata until the version is pinned.') - def _fetch_metadata(self, project): - version_string = self.version_name() - dct = registry.package_version(self.package, version_string) - return RegistryPackageMetadata(**dct) + loaded = project.from_project_root(self.resolve_path(project), {}) + return ProjectPackageMetadata.from_project(loaded) def install(self, project): - version_string = self.version_name() - metadata = self.fetch_metadata(project) - - tar_name = '{}.{}.tar.gz'.format(self.package, version_string) - tar_path = os.path.realpath(os.path.join(DOWNLOADS_PATH, tar_name)) - dbt.clients.system.make_directory(os.path.dirname(tar_path)) - - download_url = metadata['downloads']['tarball'] - dbt.clients.system.download(download_url, tar_path) - deps_path = project.modules_path - package_name = self.get_project_name(project) - dbt.clients.system.untar_package(tar_path, deps_path, package_name) - - -# the metadata is a package config with extra attributes we don't care about. 
-class RegistryPackageMetadata(PackageConfig): - SCHEMA = REGISTRY_PACKAGE_METADATA_CONTRACT - + src_path = self.resolve_path(project) + dest_path = self.get_installation_path(project) -class ProjectPackageMetadata: - def __init__(self, project): - self.name = project.project_name - self.packages = project.packages.packages + can_create_symlink = system.supports_symlinks() + if system.path_exists(dest_path): + if not system.path_is_symlink(dest_path): + system.rmdir(dest_path) + else: + system.remove_file(dest_path) -def md5sum(s: str): - return hashlib.md5(s.encode('latin-1')).hexdigest() + if can_create_symlink: + logger.debug(' Creating symlink to local dependency.') + system.make_symlink(src_path, dest_path) + else: + logger.debug(' Symlinks are not available on this ' + 'OS, copying dependency.') + shutil.copytree(src_path, dest_path) -class GitPackage(Package): - SCHEMA = GIT_PACKAGE_CONTRACT - def __init__(self, *args, **kwargs): - if 'warn_unpinned' in kwargs: - kwargs['warn-unpinned'] = kwargs.pop('warn_unpinned') - super().__init__(*args, **kwargs) +@dataclass +class GitPinned(Pinned): + git: str + revision: str + warn_unpinned: bool = True + _checkout_name: str = field(init=False) + def __post_init__(self): + super().__post_init__() self._checkout_name = md5sum(self.git) - self.version = self._contents.get('revision') - - @property - def other_name(self): - if self.git.endswith('.git'): - return self.git[:-4] - else: - return self.git + '.git' @property def name(self): return self.git - @classmethod - def _sanitize_version(cls, version): - return cls.version_to_list(version) or ['master'] - def source_type(self): return 'git' - @property - def version(self): - return self._version - - @version.setter - def version(self, version): - self._version = self._sanitize_version(version) - - def version_name(self): - return self._version[0] + def get_version(self): + return self.revision def nice_version_name(self): - return "revision {}".format(self.version_name()) + return 'revision {}'.format(self.revision) - def incorporate(self, other): - # if one is False, make both be False. - warn_unpinned = self.warn_unpinned and other.warn_unpinned - - return GitPackage(git=self.git, - revision=(self.version + other.version), - warn_unpinned=warn_unpinned) - - def _resolve_version(self): - requested = set(self.version) - if len(requested) != 1: - dbt.exceptions.raise_dependency_error( - 'git dependencies should contain exactly one version. ' - '{} contains: {}'.format(self.git, requested)) - self.version = requested.pop() - - @property - def warn_unpinned(self): - return self.get('warn-unpinned', True) - - def _checkout(self, project): + def _checkout(self): """Performs a shallow clone of the repository into the downloads directory. This function can be called repeatedly. If the project has already been checked out at this version, it will be a no-op. 
Returns the path to the checked out directory.""" - if len(self.version) != 1: - dbt.exceptions.raise_dependency_error( - 'Cannot checkout repository until the version is pinned.') try: - dir_ = dbt.clients.git.clone_and_checkout( - self.git, DOWNLOADS_PATH, branch=self.version[0], - dirname=self._checkout_name) + dir_ = git.clone_and_checkout( + self.git, DOWNLOADS_PATH, branch=self.revision, + dirname=self._checkout_name + ) except dbt.exceptions.ExecutableError as exc: if exc.cmd and exc.cmd[0] == 'git': logger.error( @@ -305,9 +221,9 @@ def _checkout(self, project): raise return os.path.join(DOWNLOADS_PATH, dir_) - def _fetch_metadata(self, project): - path = self._checkout(project) - if self.version[0] == 'master' and self.warn_unpinned: + def _fetch_metadata(self, project) -> ProjectPackageMetadata: + path = self._checkout() + if self.revision == 'master' and self.warn_unpinned: dbt.exceptions.warn_or_error( 'The git package "{}" is not pinned.\n\tThis can introduce ' 'breaking changes into your project without warning!\n\nSee {}' @@ -315,187 +231,307 @@ def _fetch_metadata(self, project): log_fmt=printer.yellow('WARNING: {}') ) loaded = project.from_project_root(path, {}) - return ProjectPackageMetadata(loaded) + return ProjectPackageMetadata.from_project(loaded) def install(self, project): dest_path = self.get_installation_path(project) if os.path.exists(dest_path): - if dbt.clients.system.path_is_symlink(dest_path): - dbt.clients.system.remove_file(dest_path) + if system.path_is_symlink(dest_path): + system.remove_file(dest_path) else: - dbt.clients.system.rmdir(dest_path) - dbt.clients.system.move(self._checkout(project), dest_path) + system.rmdir(dest_path) + + system.move(self._checkout(), dest_path) +@dataclass +class RegistryPinned(Pinned): + package: str + version: str + + @property + def name(self): + return self.package + + def source_type(self): + return 'hub' + + def get_version(self): + return self.version + + def nice_version_name(self): + return 'version {}'.format(self.version) + + def _fetch_metadata(self, project) -> RegistryPackageMetadata: + dct = registry.package_version(self.package, self.version) + return RegistryPackageMetadata.from_dict(dct) + + def install(self, project): + metadata = self.fetch_metadata(project) + + tar_name = '{}.{}.tar.gz'.format(self.package, self.version) + tar_path = os.path.realpath(os.path.join(DOWNLOADS_PATH, tar_name)) + system.make_directory(os.path.dirname(tar_path)) + + download_url = metadata.downloads.tarball + system.download(download_url, tar_path) + deps_path = project.modules_path + package_name = self.get_project_name(project) + system.untar_package(tar_path, deps_path, package_name) + + +class Package(metaclass=abc.ABCMeta): + @abc.abstractclassmethod + def from_contract(cls, contract): + raise NotImplementedError + + @abc.abstractproperty + def name(self): + raise NotImplementedError + + def all_names(self): + return [self.name] + + def _typecheck(self, other): + if not isinstance(other, self.__class__): + raise_dependency_error( + 'Cannot incorporate {0} ({0.__class__.__name__}) into ' + '{1} ({1.__class__.__name__}): mismatched types' + .format(other, self)) + + +@dataclass class LocalPackage(Package): - SCHEMA = LOCAL_PACKAGE_CONTRACT + local: str + + def source_type(self): + return 'local' @property def name(self): return self.local - def incorporate(self, _): + @classmethod + def from_contract(cls, contract: LocalPackageContract) -> 'LocalPackage': + return cls(local=contract.local) + + def incorporate( + self, 
other: Union['LocalPackage', LocalPinned] + ) -> 'LocalPackage': + if isinstance(other, LocalPinned): + other = LocalPackage(local=other.local) + self._typecheck(other) return LocalPackage(local=self.local) + def resolved(self) -> LocalPinned: + return LocalPinned(local=self.local) + + +@dataclass +class GitPackage(Package): + git: str + revisions: List[str] + warn_unpinned: bool = True + + @classmethod + def from_contract(cls, contract: GitPackageContract) -> 'GitPackage': + revisions = [contract.revision] if contract.revision else [] + + # we want to map None -> True + warn_unpinned = contract.warn_unpinned is not False + return cls(git=contract.git, revisions=revisions, + warn_unpinned=warn_unpinned) + + @property + def name(self): + return self.git + def source_type(self): - return 'local' + return 'git' - def version_name(self): - return ''.format(self.local) + def all_names(self): + if self.git.endswith('.git'): + other = self.git[:-4] + else: + other = self.git + '.git' + return [self.git, other] - def nice_version_name(self): - return self.version_name() + def incorporate( + self, other: Union['GitPackage', GitPinned] + ) -> 'GitPackage': - def _fetch_metadata(self, project): - project_file_path = dbt.clients.system.resolve_path_from_base( - self.local, - project.project_root) + if isinstance(other, GitPinned): + other = GitPackage(git=other.git, revisions=[other.revision], + warn_unpinned=other.warn_unpinned) - loaded = project.from_project_root(project_file_path, {}) - return ProjectPackageMetadata(loaded) + self._typecheck(other) - def install(self, project): - src_path = dbt.clients.system.resolve_path_from_base( - self.local, - project.project_root) + warn_unpinned = self.warn_unpinned and other.warn_unpinned - dest_path = self.get_installation_path(project) + return GitPackage(git=self.git, + revisions=self.revisions + other.revisions, + warn_unpinned=warn_unpinned) - can_create_symlink = dbt.clients.system.supports_symlinks() + def resolved(self) -> GitPinned: + requested = set(self.revisions) + if len(requested) == 0: + requested = {'master'} + elif len(requested) > 1: + dbt.exceptions.raise_dependency_error( + 'git dependencies should contain exactly one version. 
' + '{} contains: {}'.format(self.git, requested)) - if dbt.clients.system.path_exists(dest_path): - if not dbt.clients.system.path_is_symlink(dest_path): - dbt.clients.system.rmdir(dest_path) - else: - dbt.clients.system.remove_file(dest_path) + return GitPinned( + git=self.git, revision=requested.pop(), + warn_unpinned=self.warn_unpinned + ) - if can_create_symlink: - logger.debug(' Creating symlink to local dependency.') - dbt.clients.system.make_symlink(src_path, dest_path) - else: - logger.debug(' Symlinks are not available on this ' - 'OS, copying dependency.') - shutil.copytree(src_path, dest_path) +@dataclass +class RegistryPackage(Package): + package: str + versions: List[semver.VersionSpecifier] + @property + def name(self): + return self.package -def _parse_package(dict_): - only_1_keys = ['package', 'git', 'local'] - specified = [k for k in only_1_keys if dict_.get(k)] - if len(specified) > 1: - dbt.exceptions.raise_dependency_error( - 'Packages should not contain more than one of {}; ' - 'yours has {} of them - {}' - .format(only_1_keys, len(specified), specified)) - if dict_.get('package'): - return RegistryPackage(**dict_) - if dict_.get('git'): - if dict_.get('version'): - msg = ("Keyword 'version' specified for git package {}.\nDid " - "you mean 'revision'?".format(dict_.get('git'))) - dbt.exceptions.raise_dependency_error(msg) - return GitPackage(**dict_) - if dict_.get('local'): - return LocalPackage(**dict_) - dbt.exceptions.raise_dependency_error( - 'Malformed package definition. Must contain package, git, or local.') + def _check_in_index(self): + index = registry.index_cached() + if self.package not in index: + dbt.exceptions.package_not_found(self.package) + + @classmethod + def from_contract( + cls, contract: RegistryPackageContract + ) -> 'RegistryPackage': + raw_version = contract.version + if isinstance(raw_version, str): + raw_version = [raw_version] + + versions = [ + semver.VersionSpecifier.from_version_string(v) + for v in raw_version + ] + return cls(package=contract.package, versions=versions) + + def incorporate( + self, other: ['RegistryPackage', RegistryPinned] + ) -> 'RegistryPackage': + if isinstance(other, RegistryPinned): + versions = [ + semver.VersionSpecifier.from_version_string(other.version) + ] + other = RegistryPackage(package=other.package, versions=versions) + self._typecheck(other) + return RegistryPackage(package=self.package, + versions=self.versions + other.versions) + + def resolved(self) -> RegistryPinned: + self._check_in_index() + try: + range_ = semver.reduce_versions(*self.versions) + except VersionsNotCompatibleException as e: + new_msg = ('Version error for package {}: {}' + .format(self.name, e)) + raise DependencyException(new_msg) from e + available = registry.get_available_versions(self.package) -class PackageListing(AttrDict): - def __contains__(self, package): - if isinstance(package, str): - return super().__contains__(package) - elif isinstance(package, GitPackage): - return package.name in self or package.other_name in self - else: - return package.name in self + # for now, pick a version and then recurse. later on, + # we'll probably want to traverse multiple options + # so we can match packages. not going to make a difference + # right now. 
+ target = semver.resolve_to_specific_version(range_, available) + if not target: + package_version_not_found(self.package, range_, available) + return RegistryPinned(package=self.package, version=target) - def __setitem__(self, key, value): - if isinstance(key, str): - super().__setitem__(key, value) - elif isinstance(key, GitPackage) and key.other_name in self: - self[key.other_name] = value - else: - self[key.name] = value - def __getitem__(self, key): - if isinstance(key, str): - return super().__getitem__(key) - elif isinstance(key, GitPackage) and key.other_name in self: - return self[key.other_name] - else: - return self[key.name] +PackageResolver = Union[LocalPackage, GitPackage, RegistryPackage] +PinnedPackages = Union[LocalPinned, GitPinned, RegistryPinned] + + +@dataclass +class PackageListing: + packages: Dict[str, PackageResolver] = field(default_factory=dict) - def incorporate(self, package): - if not isinstance(package, Package): - package = _parse_package(package) + def __len__(self): + return len(self.packages) - if package in self: - self[package] = self[package].incorporate(package) + def __bool__(self): + return bool(self.packages) + + def _pick_key(self, key: Package): + for name in key.all_names(): + if name in self.packages: + return name + return key.name + + def __contains__(self, key: Package): + for name in key.all_names(): + if name in self.packages: + return True + + def __getitem__(self, key: Package): + key = self._pick_key(key) + return self.packages[key] + + def __setitem__(self, key: Package, value): + key = self._pick_key(key) + self.packages[key] = value + + def incorporate(self, package: Package): + key = self._pick_key(package) + if key in self.packages: + self.packages[key] = self.packages[key].incorporate(package) else: - self[package] = package + self.packages[key] = package + + def update_from(self, src: List[PackageContract]) -> 'PackageListing': + for contract in src: + if isinstance(contract, LocalPackageContract): + pkg = LocalPackage.from_contract(contract) + elif isinstance(contract, GitPackageContract): + pkg = GitPackage.from_contract(contract) + elif isinstance(contract, RegistryPackageContract): + pkg = RegistryPackage.from_contract(contract) + else: + raise dbt.exceptions.InternalException( + 'Invalid package type {}'.format(type(contract)) + ) + self.incorporate(pkg) @classmethod - def create(cls, parsed_yaml): - to_return = cls({}) - if not isinstance(parsed_yaml, list): - dbt.exceptions.raise_dependency_error( - 'Package definitions must be a list, got: {}' - .format(type(parsed_yaml))) - for package in parsed_yaml: - to_return.incorporate(package) - return to_return - - def incorporate_from_yaml(self, parsed_yaml): - listing = self.create(parsed_yaml) - for _, package in listing.items(): - self.incorporate(package) - - -def _split_at_branch(repo_spec): - parts = repo_spec.split('@') - error = RuntimeError( - "Invalid dep specified: '{}' -- not a repo we can clone".format( - repo_spec - ) - ) - repo = None - if repo_spec.startswith('git@'): - if len(parts) == 1: - raise error - if len(parts) == 2: - repo, branch = repo_spec, None - elif len(parts) == 3: - repo, branch = '@'.join(parts[:2]), parts[2] - else: - if len(parts) == 1: - repo, branch = parts[0], None - elif len(parts) == 2: - repo, branch = parts - if repo is None: - raise error - return repo, branch - - -def _convert_repo(repo_spec): - repo, branch = _split_at_branch(repo_spec) - return { - 'git': repo, - 'revision': branch, - } - - -def _read_packages(project_yaml): - 
packages = project_yaml.get('packages', []) - repos = project_yaml.get('repositories', []) - if repos: - bad_packages = [_convert_repo(r) for r in repos] - packages += bad_packages - - fixed_packages = {"packages": bad_packages} - recommendation = yaml.dump(fixed_packages, default_flow_style=False) - dbt.deprecations.warn('repositories', recommendation=recommendation) - return packages + def from_contracts( + cls: 'PackageListing', src: List[PackageContract] + ) -> 'PackageListing': + self = cls({}) + self.update_from(src) + return self + + def resolved(self) -> List[PinnedPackages]: + return [p.resolved() for p in self.packages.values()] + + def __iter__(self): + return iter(self.packages.values()) + + +def resolve_packages( + packages: List[PackageContract], config +) -> List[PinnedPackages]: + pending = PackageListing.from_contracts(packages) + final = PackageListing() + + while pending: + next_pending = PackageListing() + # resolve the dependency in question + for package in pending: + final.incorporate(package) + target = final[package].resolved().fetch_metadata(config) + next_pending.update_from(target.packages) + pending = next_pending + return final.resolved() class DepsTask(ProjectOnlyTask): @@ -511,7 +547,7 @@ def downloads_path(self): def _check_for_duplicate_project_names(self, final_deps): seen = set() - for _, package in final_deps.items(): + for package in final_deps: project_name = package.get_project_name(self.config) if project_name in seen: dbt.exceptions.raise_dependency_error( @@ -533,7 +569,7 @@ def track_package_install(self, package_name, source_type, version): }) def run(self): - dbt.clients.system.make_directory(self.config.modules_path) + system.make_directory(self.config.modules_path) _initialize_downloads() packages = self.config.packages.packages @@ -541,21 +577,11 @@ def run(self): logger.info('Warning: No packages were found in packages.yml') return - pending_deps = PackageListing.create(packages) - final_deps = PackageListing.create([]) - - while pending_deps: - sub_deps = PackageListing.create([]) - for package in pending_deps.values(): - final_deps.incorporate(package) - final_deps[package].resolve_version() - target_config = final_deps[package].fetch_metadata(self.config) - sub_deps.incorporate_from_yaml(target_config.packages) - pending_deps = sub_deps + final_deps = resolve_packages(packages, self.config) self._check_for_duplicate_project_names(final_deps) - for package in final_deps.values(): + for package in final_deps: logger.info('Installing %s', package) package.install(self.config) logger.info(' Installed from %s\n', package.nice_version_name()) @@ -563,7 +589,7 @@ def run(self): self.track_package_install( package_name=package.name, source_type=package.source_type(), - version=package.version_name()) + version=package.get_version()) if REMOVE_DOWNLOADS: - dbt.clients.system.rmtree(DOWNLOADS_PATH) + system.rmtree(DOWNLOADS_PATH) diff --git a/core/dbt/task/list.py b/core/dbt/task/list.py index fa83e27f1b9..7427a1a909b 100644 --- a/core/dbt/task/list.py +++ b/core/dbt/task/list.py @@ -77,13 +77,13 @@ def generate_json(self): for node in self._iterate_selected_nodes(): yield json.dumps({ k: v - for k, v in node.serialize().items() + for k, v in node.to_dict(omit_none=False).items() if k in self.ALLOWED_KEYS }) def generate_paths(self): for node in self._iterate_selected_nodes(): - yield node.get('original_file_path') + yield node.original_file_path def run(self): ManifestTask._runtime_initialize(self) diff --git a/core/dbt/task/run.py 
b/core/dbt/task/run.py index 2678006ef5e..be1e2306601 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -9,8 +9,7 @@ import dbt.exceptions import dbt.flags -from dbt.contracts.graph.parsed import Hook -from dbt.hooks import get_hook_dict +from dbt.hooks import get_hook from dbt.ui.printer import \ print_hook_start_line, \ print_hook_end_line, \ @@ -66,11 +65,9 @@ def get_hook_sql(self, adapter, hook, idx, num_hooks, extra_context): compiled = compile_node(adapter, self.config, hook, self.manifest, extra_context) statement = compiled.wrapped_sql - hook_index = hook.get('index', num_hooks) - hook_dict = get_hook_dict(statement, index=hook_index) - if dbt.flags.STRICT_MODE: - Hook(**hook_dict) - return hook_dict.get('sql', '') + hook_index = hook.index or num_hooks + hook_obj = get_hook(statement, index=hook_index) + return hook_obj.sql or '' def _hook_keyfunc(self, hook): package_name = hook.package_name @@ -156,7 +153,7 @@ def after_run(self, adapter, results): # errored failed skipped schemas = list(set( r.node.schema for r in results - if not any((r.error is not None, r.failed, r.skipped)) + if not any((r.error is not None, r.fail, r.skipped)) )) with adapter.connection_named('master'): self.safe_run_hooks(adapter, RunHookType.End, diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 7a073f856b5..a8c65ad669b 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -2,6 +2,7 @@ import os import time from abc import abstractmethod +from datetime import datetime from multiprocessing.dummy import Pool as ThreadPool from dbt import rpc @@ -9,7 +10,6 @@ from dbt.adapters.factory import get_adapter from dbt.logger import GLOBAL_LOGGER as logger from dbt.compilation import compile_manifest -from dbt.contracts.graph.manifest import CompileResultNode from dbt.contracts.results import ExecutionResult from dbt.loader import GraphLoader @@ -174,7 +174,7 @@ def _handle_result(self, result): if not is_ephemeral: self.node_results.append(result) - node = CompileResultNode(**result.node) + node = result.node node_id = node.unique_id self.manifest.nodes[node_id] = node @@ -261,7 +261,7 @@ def execute_with_hooks(self, selected_uids): result = self.get_result( results=res, elapsed_time=elapsed, - generated_at=dbt.utils.timestring() + generated_at=datetime.utcnow() ) return result @@ -290,7 +290,7 @@ def interpret_results(self, results): if results is None: return False - failures = [r for r in results if r.error or r.failed] + failures = [r for r in results if r.error or r.fail] return len(failures) == 0 def get_model_schemas(self, selected_uids): diff --git a/core/dbt/types.py b/core/dbt/types.py new file mode 100644 index 00000000000..82af2e9652d --- /dev/null +++ b/core/dbt/types.py @@ -0,0 +1,17 @@ +from hologram import FieldEncoder, JsonSchemaMixin +from typing import Type, NewType + + +def NewRangedInteger(name: str, minimum: int, maximum: int) -> Type: + ranged = NewType(name, int) + + class RangeEncoder(FieldEncoder): + @property + def json_schema(self): + return {'type': 'integer', 'minimum': minimum, 'maximum': maximum} + + JsonSchemaMixin.register_field_encoders({ranged: RangeEncoder()}) + return ranged + + +Port = NewRangedInteger('Port', minimum=0, maximum=65535) diff --git a/core/dbt/ui/printer.py b/core/dbt/ui/printer.py index cd9c4b8f2d3..4cee0c74a69 100644 --- a/core/dbt/ui/printer.py +++ b/core/dbt/ui/printer.py @@ -90,11 +90,11 @@ def get_counts(flat_nodes): counts = {} for node in flat_nodes: - t = node.get('resource_type') + t = 
node.resource_type - if node.get('resource_type') == NodeType.Model: + if node.resource_type == NodeType.Model: t = '{} {}'.format(get_materialization(node), t) - elif node.get('resource_type') == NodeType.Operation: + elif node.resource_type == NodeType.Operation: t = 'hook' counts[t] = counts.get(t, 0) + 1 @@ -152,7 +152,7 @@ def print_test_result_line(result, schema_name, index, total): color = red elif result.status > 0: - severity = result.node.config['severity'].upper() + severity = result.node.config.severity.upper() if severity == 'ERROR' or dbt.flags.WARN_ERROR: info = 'FAIL {}'.format(result.status) color = red @@ -168,7 +168,7 @@ def print_test_result_line(result, schema_name, index, total): raise RuntimeError("unexpected status: {}".format(result.status)) print_fancy_output_line( - "{info} {name}".format(info=info, name=model.get('name')), + "{info} {name}".format(info=info, name=model.name), color(info), index, total, @@ -190,7 +190,7 @@ def print_snapshot_result_line(result, index, total): model = result.node info, status = get_printable_result(result, 'snapshotted', 'snapshotting') - cfg = model.get('config', {}) + cfg = model.config.to_dict() msg = "{info} {name}".format( info=info, name=model.name, **cfg) @@ -211,7 +211,7 @@ def print_seed_result_line(result, schema_name, index, total): "{info} seed file {schema}.{relation}".format( info=info, schema=schema_name, - relation=model.get('alias')), + relation=model.alias), status, index, total, @@ -255,7 +255,7 @@ def print_freshness_result_line(result, index, total): def interpret_run_result(result): - if result.error is not None or result.failed: + if result.error is not None or result.fail: return 'error' elif result.skipped: return 'skip' @@ -284,17 +284,17 @@ def print_run_result_error(result, newline=True): if newline: logger.info("") - if result.failed: + if result.fail: logger.info(yellow("Failure in {} {} ({})").format( - result.node.get('resource_type'), - result.node.get('name'), - result.node.get('original_file_path'))) + result.node.resource_type, + result.node.name, + result.node.original_file_path)) logger.info(" Got {} results, expected 0.".format(result.status)) - if result.node.get('build_path') is not None: + if result.node.build_path is not None: logger.info("") logger.info(" compiled SQL at {}".format( - result.node.get('build_path'))) + result.node.build_path)) else: first = True @@ -327,7 +327,7 @@ def print_end_of_run_summary(num_errors, early_exit=False): def print_run_end_messages(results, early_exit=False): - errors = [r for r in results if r.error is not None or r.failed] + errors = [r for r in results if r.error is not None or r.fail] print_end_of_run_summary(len(errors), early_exit) for error in errors: diff --git a/core/dbt/utils.py b/core/dbt/utils.py index c4abdb12272..30c338d4f3e 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -60,8 +60,10 @@ def get_model_name_or_none(model): name = model elif isinstance(model, dict): name = model.get('alias', model.get('name')) - elif hasattr(model, 'nice_name'): - name = model.nice_name + elif hasattr(model, 'alias'): + name = model.alias + elif hasattr(model, 'name'): + name = model.name else: name = str(model) return name @@ -82,7 +84,7 @@ def id_matches(unique_id, target_name, target_package, nodetypes, model): nodetypes should be a container of NodeTypes that implements the 'in' operator. 
""" - node_type = model.get('resource_type', 'node') + node_type = model.resource_type node_parts = unique_id.split('.', 2) if len(node_parts) != 3: msg = "unique_id {} is malformed".format(unique_id) @@ -127,7 +129,7 @@ def find_in_subgraph_by_name(subgraph, target_name, target_package, nodetype): def find_in_list_by_name(haystack, target_name, target_package, nodetype): """Find an entry in the given list by name.""" for model in haystack: - name = model.get('unique_id') + name = model.unique_id if id_matches(name, target_name, target_package, nodetype, model): return model @@ -139,10 +141,14 @@ def find_in_list_by_name(haystack, target_name, target_package, nodetype): def get_dbt_macro_name(name): + if name is None: + raise dbt.exceptions.InternalException('Got None for a macro name!') return '{}{}'.format(MACRO_PREFIX, name) def get_dbt_docs_name(name): + if name is None: + raise dbt.exceptions.InternalException('Got None for a doc name!') return '{}{}'.format(DOCS_PREFIX, name) @@ -301,17 +307,17 @@ def to_string(s): def get_materialization(node): - return node.get('config', {}).get('materialized') + return node.config.materialized def is_enabled(node): - return node.get('config', {}).get('enabled') is True + return node.config.enabled def is_type(node, _type): if hasattr(_type, 'value'): _type = _type.value - return node.get('resource_type') == _type + return node.resource_type == _type def get_pseudo_test_path(node_name, source_path, test_type): @@ -331,7 +337,7 @@ def get_pseudo_hook_path(hook_name): def get_nodes_by_tags(nodes, match_tags, resource_type): matched_nodes = [] for node in nodes: - node_tags = node.get('tags', []) + node_tags = node.tags if len(set(node_tags) & match_tags): matched_nodes.append(node) return matched_nodes @@ -342,11 +348,11 @@ def md5(string): def get_hash(model): - return hashlib.md5(model.get('unique_id').encode('utf-8')).hexdigest() + return hashlib.md5(model.unique_id.encode('utf-8')).hexdigest() def get_hashed_contents(model): - return hashlib.md5(model.get('raw_sql').encode('utf-8')).hexdigest() + return hashlib.md5(model.raw_sql.encode('utf-8')).hexdigest() def flatten_nodes(dep_list): diff --git a/core/dbt/writer.py b/core/dbt/writer.py index 040469184a5..e36519b2c60 100644 --- a/core/dbt/writer.py +++ b/core/dbt/writer.py @@ -4,13 +4,10 @@ def write_node(node, target_path, subdirectory, payload): - node_path = node.get('path') + node_path = node.path - full_path = os.path.join( - target_path, - subdirectory, - node.get('package_name'), - node_path) + full_path = os.path.join(target_path, subdirectory, node.package_name, + node_path) dbt.clients.system.write_file(full_path, payload) diff --git a/core/setup.py b/core/setup.py index 3a82f183fe1..80256b4dd9b 100644 --- a/core/setup.py +++ b/core/setup.py @@ -54,5 +54,7 @@ def read(fname): 'jsonschema>=3.0.1,<4', 'json-rpc>=1.12,<2', 'werkzeug>=0.14.1,<0.15', + 'dataclasses;python_version<"3.7"', + 'hologram @ git+https://github.com/fishtown-analytics/hologram.git@master#egg=hologram', ] ) diff --git a/plugins/bigquery/dbt/adapters/bigquery/connections.py b/plugins/bigquery/dbt/adapters/bigquery/connections.py index f211ba1d2f5..21bb21d8fd7 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/connections.py +++ b/plugins/bigquery/dbt/adapters/bigquery/connections.py @@ -11,40 +11,28 @@ from dbt.adapters.base import BaseConnectionManager, Credentials from dbt.logger import GLOBAL_LOGGER as logger +from hologram.helpers import StrEnum -BIGQUERY_CREDENTIALS_CONTRACT = { - 'type': 'object', - 
'additionalProperties': False, - 'properties': { - 'method': { - 'enum': ['oauth', 'service-account', 'service-account-json'], - }, - 'database': { - 'type': 'string', - }, - 'schema': { - 'type': 'string', - }, - 'keyfile': { - 'type': 'string', - }, - 'keyfile_json': { - 'type': 'object', - }, - 'timeout_seconds': { - 'type': 'integer', - }, - 'location': { - 'type': 'string', - }, - }, - 'required': ['method', 'database', 'schema'], -} +from dataclasses import dataclass +from typing import Optional, Any, Dict +class BigQueryConnectionMethod(StrEnum): + OAUTH = 'oauth' + SERVICE_ACCOUNT = 'service-account' + SERVICE_ACCOUNT_JSON = 'service-account-json' + + +@dataclass class BigQueryCredentials(Credentials): - SCHEMA = BIGQUERY_CREDENTIALS_CONTRACT - ALIASES = { + method: BigQueryConnectionMethod + database: str + schema: str + keyfile: Optional[str] = None + keyfile_json: Optional[Dict[str, Any]] = None + timeout_seconds: Optional[int] = 300 + location: Optional[str] = None + _ALIASES = { 'project': 'database', 'dataset': 'schema', } @@ -121,15 +109,15 @@ def get_bigquery_credentials(cls, profile_credentials): method = profile_credentials.method creds = google.oauth2.service_account.Credentials - if method == 'oauth': + if method == BigQueryConnectionMethod.OAUTH: credentials, project_id = google.auth.default(scopes=cls.SCOPE) return credentials - elif method == 'service-account': + elif method == BigQueryConnectionMethod.SERVICE_ACCOUNT: keyfile = profile_credentials.keyfile return creds.from_service_account_file(keyfile, scopes=cls.SCOPE) - elif method == 'service-account-json': + elif method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON: details = profile_credentials.keyfile_json return creds.from_service_account_info(details, scopes=cls.SCOPE) @@ -175,8 +163,8 @@ def open(cls, connection): @classmethod def get_timeout(cls, conn): - credentials = conn['credentials'] - return credentials.get('timeout_seconds', cls.QUERY_TIMEOUT) + credentials = conn.credentials + return credentials.timeout_seconds @classmethod def get_table_from_response(cls, resp): diff --git a/plugins/postgres/dbt/adapters/postgres/connections.py b/plugins/postgres/dbt/adapters/postgres/connections.py index f66fb7c12d4..1915df00557 100644 --- a/plugins/postgres/dbt/adapters/postgres/connections.py +++ b/plugins/postgres/dbt/adapters/postgres/connections.py @@ -7,45 +7,23 @@ from dbt.adapters.sql import SQLConnectionManager from dbt.logger import GLOBAL_LOGGER as logger - -POSTGRES_CREDENTIALS_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'database': { - 'type': 'string', - }, - 'host': { - 'type': 'string', - }, - 'user': { - 'type': 'string', - }, - 'password': { - 'type': 'string', - }, - 'port': { - 'type': 'integer', - 'minimum': 0, - 'maximum': 65535, - }, - 'schema': { - 'type': 'string', - }, - 'search_path': { - 'type': 'string', - }, - 'keepalives_idle': { - 'type': 'integer', - }, - }, - 'required': ['database', 'host', 'user', 'password', 'port', 'schema'], -} +from dbt.types import Port +from dataclasses import dataclass +from typing import Optional +@dataclass class PostgresCredentials(Credentials): - SCHEMA = POSTGRES_CREDENTIALS_CONTRACT - ALIASES = { + database: str + host: str + user: str + password: str + port: Port + schema: str + search_path: Optional[str] + keepalives_idle: Optional[int] = 0 # 0 means to use the default value + + _ALIASES = { 'dbname': 'database', 'pass': 'password' } @@ -59,7 +37,6 @@ def _connection_keys(self): class 
PostgresConnectionManager(SQLConnectionManager): - DEFAULT_TCP_KEEPALIVE = 0 # 0 means to use the default value TYPE = 'postgres' @contextmanager @@ -97,18 +74,16 @@ def open(cls, connection): logger.debug('Connection is already open, skipping open.') return connection - credentials = cls.get_credentials(connection.credentials.incorporate()) + credentials = cls.get_credentials(connection.credentials) kwargs = {} - keepalives_idle = credentials.get('keepalives_idle', - cls.DEFAULT_TCP_KEEPALIVE) # we don't want to pass 0 along to connect() as postgres will try to # call an invalid setsockopt() call (contrary to the docs). - if keepalives_idle: - kwargs['keepalives_idle'] = keepalives_idle + if credentials.keepalives_idle: + kwargs['keepalives_idle'] = credentials.keepalives_idle # psycopg2 doesn't support search_path officially, # see https://github.com/psycopg/psycopg2/issues/465 - search_path = credentials.get('search_path') + search_path = credentials.search_path if search_path is not None and search_path != '': # see https://postgresql.org/docs/9.5/libpq-connect.html kwargs['options'] = '-c search_path={}'.format( diff --git a/plugins/redshift/dbt/adapters/redshift/connections.py b/plugins/redshift/dbt/adapters/redshift/connections.py index ecb2a57ca58..4b93061bc41 100644 --- a/plugins/redshift/dbt/adapters/redshift/connections.py +++ b/plugins/redshift/dbt/adapters/redshift/connections.py @@ -8,69 +8,33 @@ import boto3 +from dbt.types import NewRangedInteger +from hologram.helpers import StrEnum + +from dataclasses import dataclass, field +from typing import Optional + drop_lock = multiprocessing.Lock() -REDSHIFT_CREDENTIALS_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'method': { - 'enum': ['database', 'iam'], - 'description': ( - 'database: use user/pass creds; iam: use temporary creds' - ), - }, - 'database': { - 'type': 'string', - }, - 'host': { - 'type': 'string', - }, - 'user': { - 'type': 'string', - }, - 'password': { - 'type': 'string', - }, - 'port': { - 'type': 'integer', - 'minimum': 0, - 'maximum': 65535, - }, - 'schema': { - 'type': 'string', - }, - 'cluster_id': { - 'type': 'string', - 'description': ( - 'If using IAM auth, the name of the cluster' - ) - }, - 'iam_duration_seconds': { - 'type': 'integer', - 'minimum': 900, - 'maximum': 3600, - 'description': ( - 'If using IAM auth, the ttl for the temporary credentials' - ) - }, - 'search_path': { - 'type': 'string', - }, - 'keepalives_idle': { - 'type': 'integer', - }, - 'required': ['database', 'host', 'user', 'port', 'schema'] - } -} +IAMDuration = NewRangedInteger('IAMDuration', minimum=900, maximum=3600) -class RedshiftCredentials(PostgresCredentials): - SCHEMA = REDSHIFT_CREDENTIALS_CONTRACT - def __init__(self, *args, **kwargs): - kwargs.setdefault('method', 'database') - super().__init__(*args, **kwargs) +class RedshiftConnectionMethod(StrEnum): + DATABASE = 'database' + IAM = 'iam' + + +@dataclass +class RedshiftCredentials(PostgresCredentials): + method: RedshiftConnectionMethod = RedshiftConnectionMethod.DATABASE + cluster_id: Optional[str] = field( + default=None, + metadata={'description': 'If using IAM auth, the name of the cluster'}, + ) + iam_duration_seconds: Optional[int] = None + search_path: Optional[str] = None + keepalives_idle: Optional[int] = 240 @property def type(self): @@ -83,7 +47,6 @@ def _connection_keys(self): class RedshiftConnectionManager(PostgresConnectionManager): - DEFAULT_TCP_KEEPALIVE = 240 TYPE = 'redshift' @contextmanager @@ -131,11 +94,11 
@@ def fetch_cluster_credentials(cls, db_user, db_name, cluster_id, @classmethod def get_tmp_iam_cluster_credentials(cls, credentials): - cluster_id = credentials.get('cluster_id') + cluster_id = credentials.cluster_id # default via: # boto3.readthedocs.io/en/latest/reference/services/redshift.html - iam_duration_s = credentials.get('iam_duration_seconds', 900) + iam_duration_s = credentials.iam_duration_seconds if not cluster_id: raise dbt.exceptions.FailedToConnectException( @@ -150,10 +113,8 @@ def get_tmp_iam_cluster_credentials(cls, credentials): ) # replace username and password with temporary redshift credentials - return credentials.incorporate( - user=cluster_creds.get('DbUser'), - password=cluster_creds.get('DbPassword') - ) + return credentials.replace(user=cluster_creds.get('DbUser'), + password=cluster_creds.get('DbPassword')) @classmethod def get_credentials(cls, credentials): diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index b8e05764d10..197e5a09113 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -12,52 +12,23 @@ from dbt.adapters.sql import SQLConnectionManager from dbt.logger import GLOBAL_LOGGER as logger - -SNOWFLAKE_CREDENTIALS_CONTRACT = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'account': { - 'type': 'string', - }, - 'user': { - 'type': 'string', - }, - 'password': { - 'type': 'string', - }, - 'authenticator': { - 'type': 'string', - 'description': "Either 'externalbrowser', or a valid Okta url" - }, - 'private_key_path': { - 'type': 'string', - }, - 'private_key_passphrase': { - 'type': 'string', - }, - 'database': { - 'type': 'string', - }, - 'schema': { - 'type': 'string', - }, - 'warehouse': { - 'type': 'string', - }, - 'role': { - 'type': 'string', - }, - 'client_session_keep_alive': { - 'type': 'boolean', - } - }, - 'required': ['account', 'user', 'database', 'schema'], -} +from dataclasses import dataclass +from typing import Optional +@dataclass class SnowflakeCredentials(Credentials): - SCHEMA = SNOWFLAKE_CREDENTIALS_CONTRACT + account: str + user: str + database: str + schema: str + warehouse: Optional[str] + role: Optional[str] + password: Optional[str] + authenticator: Optional[str] + private_key_path: Optional[str] + private_key_passphrase: Optional[str] + client_session_keep_alive: bool = False @property def type(self): @@ -66,6 +37,33 @@ def type(self): def _connection_keys(self): return ('account', 'user', 'database', 'schema', 'warehouse', 'role') + def auth_args(self): + # Pull all of the optional authentication args for the connector, + # let connector handle the actual arg validation + result = {} + if self.password: + result['password'] = self.password + if self.authenticator: + result['authenticator'] = self.authenticator + result['private_key'] = self._get_private_key() + return result + + def _get_private_key(self): + """Get Snowflake private key by path or None.""" + if not self.private_key_path or self.private_key_passphrase is None: + return None + + with open(self.private_key_path, 'rb') as key: + p_key = serialization.load_pem_private_key( + key.read(), + password=self.private_key_passphrase.encode(), + backend=default_backend()) + + return p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption()) + class 
SnowflakeConnectionManager(SQLConnectionManager): TYPE = 'snowflake' @@ -110,27 +108,18 @@ def open(cls, connection): return connection try: - credentials = connection.credentials - # Pull all of the optional authentication args for the connector, - # let connector handle the actual arg validation - auth_args = {auth_key: credentials[auth_key] - for auth_key in ['user', 'password', 'authenticator'] - if auth_key in credentials} - - auth_args['private_key'] = cls._get_private_key( - credentials.get('private_key_path'), - credentials.get('private_key_passphrase')) + creds = connection.credentials handle = snowflake.connector.connect( - account=credentials.account, - database=credentials.database, - schema=credentials.schema, - warehouse=credentials.warehouse, - role=credentials.get('role', None), + account=creds.account, + user=creds.user, + database=creds.database, + schema=creds.schema, + warehouse=creds.warehouse, + role=creds.role, autocommit=False, - client_session_keep_alive=credentials.get( - 'client_session_keep_alive', False), - **auth_args + client_session_keep_alive=creds.client_session_keep_alive, + **creds.auth_args() ) connection.handle = handle @@ -178,23 +167,6 @@ def _split_queries(cls, sql): split_query = snowflake.connector.util_text.split_statements(sql_buf) return [part[0] for part in split_query] - @classmethod - def _get_private_key(cls, private_key_path, private_key_passphrase): - """Get Snowflake private key by path or None.""" - if private_key_path is None or private_key_passphrase is None: - return None - - with open(private_key_path, 'rb') as key: - p_key = serialization.load_pem_private_key( - key.read(), - password=private_key_passphrase.encode(), - backend=default_backend()) - - return p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption()) - def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False): diff --git a/plugins/snowflake/dbt/adapters/snowflake/relation.py b/plugins/snowflake/dbt/adapters/snowflake/relation.py index 0532e0ffb11..0c6b8555484 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/relation.py +++ b/plugins/snowflake/dbt/adapters/snowflake/relation.py @@ -44,11 +44,3 @@ class SnowflakeRelation(BaseRelation): 'required': ['metadata', 'type', 'path', 'include_policy', 'quote_policy', 'quote_character', 'dbt_created'] } - - @classmethod - def _create_from_node(cls, config, node, **kwargs): - return cls.create( - database=node.get('database'), - schema=node.get('schema'), - identifier=node.get('alias'), - **kwargs) diff --git a/test/integration/004_simple_snapshot_test/test-check-col-snapshots-bq/snapshot.sql b/test/integration/004_simple_snapshot_test/test-check-col-snapshots-bq/snapshot.sql index 33c3e4e5fff..310f18280fd 100644 --- a/test/integration/004_simple_snapshot_test/test-check-col-snapshots-bq/snapshot.sql +++ b/test/integration/004_simple_snapshot_test/test-check-col-snapshots-bq/snapshot.sql @@ -1,11 +1,13 @@ {% snapshot snapshot_actual %} + {# this used to be check_cols=('email',), which ought to be totally valid, + but isn't because type systems are hard. 
#} {{ config( target_database=var('target_database', database), target_schema=schema, unique_key='concat(cast(id as string) , "-", first_name)', strategy='check', - check_cols=('email',), + check_cols=['email'], ) }} select * from `{{target.database}}`.`{{schema}}`.seed diff --git a/test/integration/007_graph_selection_tests/test_graph_selection.py b/test/integration/007_graph_selection_tests/test_graph_selection.py index 5c3f92ae9ba..2de1f1edeee 100644 --- a/test/integration/007_graph_selection_tests/test_graph_selection.py +++ b/test/integration/007_graph_selection_tests/test_graph_selection.py @@ -229,8 +229,8 @@ def test__snowflake__skip_intermediate(self): # make sure that users_rollup_dependency and users don't interleave users = [r for r in results if r.node.name == 'users'][0] dep = [r for r in results if r.node.name == 'users_rollup_dependency'][0] - user_last_end = users.timing[1]['completed_at'] - dep_first_start = dep.timing[0]['started_at'] + user_last_end = users.timing[1].completed_at + dep_first_start = dep.timing[0].started_at self.assertTrue( user_last_end <= dep_first_start, 'dependency started before its transitive parent ({} > {})'.format(user_last_end, dep_first_start) diff --git a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py index 1337d92e320..c1f78335bac 100644 --- a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py +++ b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py @@ -34,7 +34,7 @@ def run_schema_and_assert(self, include, exclude, expected_tests): test_task = TestTask(args, self.config) test_results = test_task.run() - ran_tests = sorted([test.node.get('name') for test in test_results]) + ran_tests = sorted([test.node.name for test in test_results]) expected_sorted = sorted(expected_tests) self.assertEqual(ran_tests, expected_sorted) diff --git a/test/integration/007_graph_selection_tests/test_tag_selection.py b/test/integration/007_graph_selection_tests/test_tag_selection.py index 9876cf31814..d322d4032cc 100644 --- a/test/integration/007_graph_selection_tests/test_tag_selection.py +++ b/test/integration/007_graph_selection_tests/test_tag_selection.py @@ -33,7 +33,7 @@ def test__postgres__select_tag(self): results = self.run_dbt(['run', '--models', 'tag:specified_as_string']) self.assertEqual(len(results), 1) - models_run = [r.node['name'] for r in results] + models_run = [r.node.name for r in results] self.assertTrue('users' in models_run) @use_profile('postgres') @@ -43,7 +43,7 @@ def test__postgres__select_tag_and_children(self): results = self.run_dbt(['run', '--models', '+tag:specified_in_project+']) self.assertEqual(len(results), 3) - models_run = [r.node['name'] for r in results] + models_run = [r.node.name for r in results] self.assertTrue('users' in models_run) self.assertTrue('users_rollup' in models_run) @@ -55,7 +55,7 @@ def test__postgres__select_tag_in_model_with_project_Config(self): results = self.run_dbt(['run', '--models', 'tag:bi']) self.assertEqual(len(results), 2) - models_run = [r.node['name'] for r in results] + models_run = [r.node.name for r in results] self.assertTrue('users' in models_run) self.assertTrue('users_rollup' in models_run) @@ -67,7 +67,7 @@ def test__postgres__select_tag_in_model_with_project_Config(self): results = self.run_dbt(['run', '--models', '@tag:users']) self.assertEqual(len(results), 4) - models_run = set(r.node['name'] for r in results) + 
models_run = set(r.node.name for r in results) self.assertEqual( {'users', 'users_rollup', 'emails_alt', 'users_rollup_dependency'}, models_run diff --git a/test/integration/008_schema_tests_test/test_schema_v2_tests.py b/test/integration/008_schema_tests_test/test_schema_v2_tests.py index 53ad7e82857..920ca291113 100644 --- a/test/integration/008_schema_tests_test/test_schema_v2_tests.py +++ b/test/integration/008_schema_tests_test/test_schema_v2_tests.py @@ -36,12 +36,12 @@ def test_schema_tests(self): for result in test_results: # assert that all deliberately failing tests actually fail - if 'failure' in result.node.get('name'): + if 'failure' in result.node.name: self.assertIsNone(result.error) self.assertFalse(result.skipped) self.assertTrue( result.status > 0, - 'test {} did not fail'.format(result.node.get('name')) + 'test {} did not fail'.format(result.node.name) ) # assert that actual tests pass @@ -51,7 +51,7 @@ def test_schema_tests(self): # status = # of failing rows self.assertEqual( result.status, 0, - 'test {} failed'.format(result.node.get('name')) + 'test {} failed'.format(result.node.name) ) self.assertEqual(sum(x.status for x in test_results), 6) @@ -122,7 +122,7 @@ def test_hooks_dont_run_for_tests(self): # status = # of failing rows self.assertEqual( result.status, 0, - 'test {} failed'.format(result.node.get('name')) + 'test {} failed'.format(result.node.name) ) class TestCustomSchemaTests(DBTIntegrationTest): @@ -217,12 +217,12 @@ def test_schema_tests_bigquery(self): for result in test_results: # assert that all deliberately failing tests actually fail - if 'failure' in result.node.get('name'): + if 'failure' in result.node.name: self.assertIsNone(result.error) self.assertFalse(result.skipped) self.assertTrue( result.status > 0, - 'test {} did not fail'.format(result.node.get('name')) + 'test {} did not fail'.format(result.node.name) ) # assert that actual tests pass @@ -232,7 +232,7 @@ def test_schema_tests_bigquery(self): # status = # of failing rows self.assertEqual( result.status, 0, - 'test {} failed'.format(result.node.get('name')) + 'test {} failed'.format(result.node.name) ) self.assertEqual(sum(x.status for x in test_results), 0) diff --git a/test/integration/009_data_tests_test/test_data_tests.py b/test/integration/009_data_tests_test/test_data_tests.py index 80afcef6be9..23e4271b905 100644 --- a/test/integration/009_data_tests_test/test_data_tests.py +++ b/test/integration/009_data_tests_test/test_data_tests.py @@ -41,7 +41,7 @@ def test_postgres_data_tests(self): for result in test_results: # assert that all deliberately failing tests actually fail - if 'fail' in result.node.get('name'): + if 'fail' in result.node.name: self.assertIsNone(result.error) self.assertFalse(result.skipped) self.assertTrue(result.status > 0) @@ -70,7 +70,7 @@ def test_snowflake_data_tests(self): for result in test_results: # assert that all deliberately failing tests actually fail - if 'fail' in result.node.get('name'): + if 'fail' in result.node.name: self.assertIsNone(result.error) self.assertFalse(result.skipped) self.assertTrue(result.status > 0) diff --git a/test/integration/011_invalid_model_tests/test_invalid_models.py b/test/integration/011_invalid_model_tests/test_invalid_models.py index d5957377b6a..5e54743144f 100644 --- a/test/integration/011_invalid_model_tests/test_invalid_models.py +++ b/test/integration/011_invalid_model_tests/test_invalid_models.py @@ -1,7 +1,5 @@ from test.integration.base import DBTIntegrationTest, use_profile -from dbt.exceptions import 
ValidationException - class TestInvalidDisabledModels(DBTIntegrationTest): @@ -20,13 +18,10 @@ def models(self): @use_profile('postgres') def test_view_with_incremental_attributes(self): - - try: + with self.assertRaises(RuntimeError) as exc: self.run_dbt() - # should throw - self.assertTrue(False) - except RuntimeError as e: - self.assertTrue("enabled" in str(e)) + + self.assertIn('enabled', str(exc.exception)) class TestInvalidModelReference(DBTIntegrationTest): @@ -46,10 +41,7 @@ def models(self): @use_profile('postgres') def test_view_with_incremental_attributes(self): - - try: + with self.assertRaises(RuntimeError) as exc: self.run_dbt() - # should throw - self.assertTrue(False) - except RuntimeError as e: - self.assertTrue("which was not found" in str(e)) + + self.assertIn('which was not found', str(exc.exception)) diff --git a/test/integration/015_cli_invocation_tests/test_cli_invocation.py b/test/integration/015_cli_invocation_tests/test_cli_invocation.py index 77f37548ebf..92057959872 100644 --- a/test/integration/015_cli_invocation_tests/test_cli_invocation.py +++ b/test/integration/015_cli_invocation_tests/test_cli_invocation.py @@ -111,4 +111,4 @@ def test_toplevel_dbt_run_with_profile_dir_arg(self): # make sure the test runs against `custom_schema` for test_result in res: self.assertTrue(self.custom_schema, - test_result.node.get('wrapped_sql')) + test_result.node.wrapped_sql) diff --git a/test/integration/022_bigquery_test/test_simple_bigquery_view.py b/test/integration/022_bigquery_test/test_simple_bigquery_view.py index a364c2edb93..b492e1b37f0 100644 --- a/test/integration/022_bigquery_test/test_simple_bigquery_view.py +++ b/test/integration/022_bigquery_test/test_simple_bigquery_view.py @@ -28,7 +28,7 @@ def assert_nondupes_pass(self): test_results = self.run_dbt(['test'], expect_pass=False) for result in test_results: - if 'dupe' in result.node.get('name'): + if 'dupe' in result.node.name: self.assertIsNone(result.error) self.assertFalse(result.skipped) self.assertTrue(result.status > 0) diff --git a/test/integration/029_docs_generate_tests/test_docs_generate.py b/test/integration/029_docs_generate_tests/test_docs_generate.py index de4e7f41775..68deb74ae84 100644 --- a/test/integration/029_docs_generate_tests/test_docs_generate.py +++ b/test/integration/029_docs_generate_tests/test_docs_generate.py @@ -803,7 +803,7 @@ def verify_manifest_macros(self, manifest): 'unique_id': 'macro.dbt.column_list', 'tags': [], 'resource_type': 'macro', - 'depends_on': {'macros': []} + 'depends_on': {'macros': []}, } ) @@ -834,6 +834,7 @@ def expected_seeded_manifest(self, model_database=None): return { 'nodes': { 'model.test.model': { + 'build_path': None, 'name': 'model', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'resource_type': 'model', @@ -845,8 +846,8 @@ def expected_seeded_manifest(self, model_database=None): 'sources': [], 'depends_on': {'nodes': ['seed.test.seed'], 'macros': []}, 'unique_id': 'model.test.model', - 'empty': False, 'fqn': ['test', 'model'], + 'index': None, 'tags': [], 'config': model_config, 'schema': my_schema_name, @@ -879,6 +880,20 @@ def expected_seeded_manifest(self, model_database=None): 'docrefs': [], }, 'seed.test.seed': { + 'build_path': None, + 'config': { + 'enabled': True, + 'materialized': 'seed', + 'persist_docs': {}, + 'pre-hook': [], + 'post-hook': [], + 'vars': {}, + 'column_types': {}, + 'quoting': {}, + 'tags': [], + }, + 'index': None, + 'patch_path': None, 'path': 'seed.csv', 'name': 'seed', 'root_path': 
OneOf(self.test_root_dir, self.initial_dir), @@ -891,28 +906,18 @@ def expected_seeded_manifest(self, model_database=None): 'sources': [], 'depends_on': {'nodes': [], 'macros': []}, 'unique_id': 'seed.test.seed', - 'empty': False, 'fqn': ['test', 'seed'], 'tags': [], - 'config': { - 'enabled': True, - 'materialized': 'seed', - 'persist_docs': {}, - 'pre-hook': [], - 'post-hook': [], - 'vars': {}, - 'column_types': {}, - 'quoting': {}, - 'tags': [], - }, 'schema': my_schema_name, 'database': self.default_database, 'alias': 'seed', 'description': '', 'columns': {}, + 'docrefs': [], }, 'test.test.not_null_model_id': { 'alias': 'not_null_model_id', + 'build_path': None, 'column_name': 'id', 'columns': {}, 'config': { @@ -930,11 +935,11 @@ def expected_seeded_manifest(self, model_database=None): 'sources': [], 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', - 'empty': False, 'fqn': ['test', 'schema_test', 'not_null_model_id'], 'name': 'not_null_model_id', 'original_file_path': schema_yml_path, 'package_name': 'test', + 'patch_path': None, 'path': _normalize('schema_test/not_null_model_id.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test_not_null(model=ref('model'), column_name='id') }}", 'refs': [['model']], @@ -943,10 +948,13 @@ def expected_seeded_manifest(self, model_database=None): 'schema': my_schema_name, 'database': self.default_database, 'tags': ['schema'], - 'unique_id': 'test.test.not_null_model_id' + 'unique_id': 'test.test.not_null_model_id', + 'docrefs': [], }, 'test.test.nothing_model_': { 'alias': 'nothing_model_', + 'build_path': None, + 'column_name': None, 'columns': {}, 'config': { 'column_types': {}, @@ -963,11 +971,11 @@ def expected_seeded_manifest(self, model_database=None): 'sources': [], 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', - 'empty': False, 'fqn': ['test', 'schema_test', 'nothing_model_'], 'name': 'nothing_model_', 'original_file_path': schema_yml_path, 'package_name': 'test', + 'patch_path': None, 'path': _normalize('schema_test/nothing_model_.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test.test_nothing(model=ref('model'), ) }}", 'refs': [['model']], @@ -976,10 +984,12 @@ def expected_seeded_manifest(self, model_database=None): 'schema': my_schema_name, 'database': self.default_database, 'tags': ['schema'], - 'unique_id': 'test.test.nothing_model_' + 'unique_id': 'test.test.nothing_model_', + 'docrefs': [], }, 'test.test.unique_model_id': { 'alias': 'unique_model_id', + 'build_path': None, 'column_name': 'id', 'columns': {}, 'config': { @@ -997,11 +1007,11 @@ def expected_seeded_manifest(self, model_database=None): 'sources': [], 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', - 'empty': False, 'fqn': ['test', 'schema_test', 'unique_model_id'], 'name': 'unique_model_id', 'original_file_path': schema_yml_path, 'package_name': 'test', + 'patch_path': None, 'path': _normalize('schema_test/unique_model_id.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test_unique(model=ref('model'), column_name='id') }}", 'refs': [['model']], @@ -1011,6 +1021,7 @@ def expected_seeded_manifest(self, model_database=None): 'database': self.default_database, 'tags': ['schema'], 'unique_id': 'test.test.unique_model_id', + 'docrefs': [], }, }, 'parent_map': { @@ -1036,8 +1047,8 @@ def expected_seeded_manifest(self, model_database=None): }, 'metadata': { 'project_id': '098f6bcd4621d373cade4e832627b4f6', - 'user_id': None, 'send_anonymous_usage_stats': False, + 
'user_id': None, }, 'disabled': [], } @@ -1053,6 +1064,7 @@ def expected_postgres_references_manifest(self, model_database=None): 'nodes': { 'model.test.ephemeral_copy': { 'alias': 'ephemeral_copy', + 'build_path': None, 'columns': {}, 'config': { 'column_types': {}, @@ -1071,11 +1083,13 @@ def expected_postgres_references_manifest(self, model_database=None): 'nodes': ['source.test.my_source.my_table'] }, 'description': '', - 'empty': False, + 'docrefs': [], 'fqn': ['test', 'ephemeral_copy'], + 'index': None, 'name': 'ephemeral_copy', 'original_file_path': self.dir('ref_models/ephemeral_copy.sql'), 'package_name': 'test', + 'patch_path': None, 'path': 'ephemeral_copy.sql', 'raw_sql': LineIndifferent( '{{\n config(\n materialized = "ephemeral"\n )\n}}' @@ -1087,10 +1101,11 @@ def expected_postgres_references_manifest(self, model_database=None): 'schema': my_schema_name, 'database': self.default_database, 'tags': [], - 'unique_id': 'model.test.ephemeral_copy' + 'unique_id': 'model.test.ephemeral_copy', }, 'model.test.ephemeral_summary': { 'alias': 'ephemeral_summary', + 'build_path': None, 'columns': { 'first_name': { 'description': 'The first name being summarized', @@ -1130,12 +1145,13 @@ def expected_postgres_references_manifest(self, model_database=None): 'documentation_package': '' }, { + 'column_name': None, 'documentation_name': 'ephemeral_summary', 'documentation_package': '' } ], - 'empty': False, 'fqn': ['test', 'ephemeral_summary'], + 'index': None, 'name': 'ephemeral_summary', 'original_file_path': self.dir('ref_models/ephemeral_summary.sql'), 'package_name': 'test', @@ -1156,6 +1172,7 @@ def expected_postgres_references_manifest(self, model_database=None): 'unique_id': 'model.test.ephemeral_summary'}, 'model.test.view_summary': { 'alias': 'view_summary', + 'build_path': None, 'columns': { 'first_name': { 'description': 'The first name being summarized', @@ -1177,7 +1194,7 @@ def expected_postgres_references_manifest(self, model_database=None): 'vars': config_vars, 'tags': [], }, - 'sources': [], + 'database': self.default_database, 'depends_on': { 'macros': [], 'nodes': ['model.test.ephemeral_summary'] @@ -1195,12 +1212,13 @@ def expected_postgres_references_manifest(self, model_database=None): 'documentation_package': '' }, { + 'column_name': None, 'documentation_name': 'view_summary', 'documentation_package': '' } ], - 'empty': False, 'fqn': ['test', 'view_summary'], + 'index': None, 'name': 'view_summary', 'original_file_path': self.dir('ref_models/view_summary.sql'), 'package_name': 'test', @@ -1215,12 +1233,13 @@ def expected_postgres_references_manifest(self, model_database=None): 'resource_type': 'model', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'schema': my_schema_name, - 'database': self.default_database, + 'sources': [], 'tags': [], 'unique_id': 'model.test.view_summary' }, 'seed.test.seed': { 'alias': 'seed', + 'build_path': None, 'columns': {}, 'config': { 'column_types': {}, @@ -1236,11 +1255,13 @@ def expected_postgres_references_manifest(self, model_database=None): 'sources': [], 'depends_on': {'macros': [], 'nodes': []}, 'description': '', - 'empty': False, + 'docrefs': [], 'fqn': ['test', 'seed'], + 'index': None, 'name': 'seed', 'original_file_path': self.dir('seed/seed.csv'), 'package_name': 'test', + 'patch_path': None, 'path': 'seed.csv', 'raw_sql': '-- csv --', 'refs': [], @@ -1260,21 +1281,24 @@ def expected_postgres_references_manifest(self, model_database=None): }, 'quoting': { 'database': False, + 'schema': None, 'identifier': True, }, 
'database': self.default_database, 'description': 'My table', 'docrefs': [ { - 'documentation_name': 'table_info', - 'documentation_package': '' + 'column_name': None, + 'documentation_name': 'table_info', + 'documentation_package': '', }, { - 'documentation_name': 'source_info', - 'documentation_package': '' + 'column_name': None, + 'documentation_name': 'source_info', + 'documentation_package': '', } ], - 'freshness': {}, + 'freshness': {'error_after': None, 'warn_after': None}, 'identifier': 'seed', 'loaded_at_field': None, 'loader': 'a_loader', @@ -1302,7 +1326,6 @@ def expected_postgres_references_manifest(self, model_database=None): 'original_file_path': docs_path, 'package_name': 'test', 'path': 'docs.md', - 'resource_type': 'docs', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'unique_id': 'test.ephemeral_summary' }, @@ -1313,7 +1336,6 @@ def expected_postgres_references_manifest(self, model_database=None): 'original_file_path': docs_path, 'package_name': 'test', 'path': 'docs.md', - 'resource_type': 'docs', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'unique_id': 'test.source_info', }, @@ -1324,7 +1346,6 @@ def expected_postgres_references_manifest(self, model_database=None): 'original_file_path': docs_path, 'package_name': 'test', 'path': 'docs.md', - 'resource_type': 'docs', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'unique_id': 'test.summary_count' }, @@ -1335,7 +1356,6 @@ def expected_postgres_references_manifest(self, model_database=None): 'original_file_path': docs_path, 'package_name': 'test', 'path': 'docs.md', - 'resource_type': 'docs', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'unique_id': 'test.summary_first_name' }, @@ -1346,7 +1366,6 @@ def expected_postgres_references_manifest(self, model_database=None): 'original_file_path': docs_path, 'package_name': 'test', 'path': 'docs.md', - 'resource_type': 'docs', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'unique_id': 'test.table_info' }, @@ -1360,7 +1379,6 @@ def expected_postgres_references_manifest(self, model_database=None): 'original_file_path': docs_path, 'package_name': 'test', 'path': 'docs.md', - 'resource_type': 'docs', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'unique_id': 'test.view_summary' }, @@ -1381,8 +1399,8 @@ def expected_postgres_references_manifest(self, model_database=None): }, 'metadata': { 'project_id': '098f6bcd4621d373cade4e832627b4f6', - 'user_id': None, 'send_anonymous_usage_stats': False, + 'user_id': None, }, 'disabled': [], } @@ -1413,8 +1431,9 @@ def expected_bigquery_complex_manifest(self): }, 'sources': [], 'depends_on': {'macros': [], 'nodes': ['seed.test.seed']}, - 'empty': False, 'fqn': ['test', 'clustered'], + 'index': None, + 'build_path': None, 'name': 'clustered', 'original_file_path': clustered_sql_path, 'package_name': 'test', @@ -1455,6 +1474,7 @@ def expected_bigquery_complex_manifest(self): }, 'model.test.multi_clustered': { 'alias': 'multi_clustered', + 'build_path': None, 'config': { 'cluster_by': ['first_name', 'email'], 'column_types': {}, @@ -1470,8 +1490,8 @@ def expected_bigquery_complex_manifest(self): }, 'sources': [], 'depends_on': {'macros': [], 'nodes': ['seed.test.seed']}, - 'empty': False, 'fqn': ['test', 'multi_clustered'], + 'index': None, 'name': 'multi_clustered', 'original_file_path': multi_clustered_sql_path, 'package_name': 'test', @@ -1512,6 +1532,7 @@ def expected_bigquery_complex_manifest(self): }, 'model.test.nested_view': { 'alias': 'nested_view', + 'build_path': 
None, 'config': { 'column_types': {}, 'enabled': True, @@ -1528,8 +1549,8 @@ def expected_bigquery_complex_manifest(self): 'macros': [], 'nodes': ['model.test.nested_table'] }, - 'empty': False, 'fqn': ['test', 'nested_view'], + 'index': None, 'name': 'nested_view', 'original_file_path': nested_view_sql_path, 'package_name': 'test', @@ -1570,6 +1591,7 @@ def expected_bigquery_complex_manifest(self): }, 'model.test.nested_table': { 'alias': 'nested_table', + 'build_path': None, 'config': { 'column_types': {}, 'enabled': True, @@ -1586,11 +1608,12 @@ def expected_bigquery_complex_manifest(self): 'macros': [], 'nodes': [] }, - 'empty': False, 'fqn': ['test', 'nested_table'], + 'index': None, 'name': 'nested_table', 'original_file_path': nested_table_sql_path, 'package_name': 'test', + 'patch_path': None, 'path': 'nested_table.sql', 'raw_sql': LineIndifferent(_read_file(nested_table_sql_path).rstrip('\r\n')), 'refs': [], @@ -1602,8 +1625,12 @@ def expected_bigquery_complex_manifest(self): 'unique_id': 'model.test.nested_table', 'columns': {}, 'description': '', + 'docrefs': [], }, 'seed.test.seed': { + 'build_path': None, + 'index': None, + 'patch_path': None, 'path': 'seed.csv', 'name': 'seed', 'root_path': OneOf(self.test_root_dir, self.initial_dir), @@ -1618,7 +1645,6 @@ def expected_bigquery_complex_manifest(self): 'macros': [], }, 'unique_id': 'seed.test.seed', - 'empty': False, 'fqn': ['test', 'seed'], 'tags': [], 'config': { @@ -1637,6 +1663,7 @@ def expected_bigquery_complex_manifest(self): 'alias': 'seed', 'columns': {}, 'description': '', + 'docrefs': [], }, }, 'child_map': { @@ -1658,8 +1685,8 @@ def expected_bigquery_complex_manifest(self): }, 'metadata': { 'project_id': '098f6bcd4621d373cade4e832627b4f6', - 'user_id': None, 'send_anonymous_usage_stats': False, + 'user_id': None, }, 'disabled': [], } @@ -1671,6 +1698,8 @@ def expected_redshift_incremental_view_manifest(self): return { "nodes": { "model.test.model": { + "build_path": None, + "index": None, "name": "model", "root_path": self.test_root_dir, "resource_type": "model", @@ -1685,7 +1714,6 @@ def expected_redshift_incremental_view_manifest(self): "macros": [], }, "unique_id": "model.test.model", - "empty": False, "fqn": ["test", "model"], "tags": [], "config": { @@ -1730,6 +1758,9 @@ def expected_redshift_incremental_view_manifest(self): 'docrefs': [], }, "seed.test.seed": { + "build_path": None, + "index": None, + "patch_path": None, "path": "seed.csv", "name": "seed", "root_path": self.test_root_dir, @@ -1744,7 +1775,6 @@ def expected_redshift_incremental_view_manifest(self): "macros": [], }, "unique_id": "seed.test.seed", - "empty": False, "fqn": ["test", "seed"], "tags": [], "config": { @@ -1763,6 +1793,7 @@ def expected_redshift_incremental_view_manifest(self): "alias": "seed", 'columns': {}, 'description': '', + 'docrefs': [], }, }, "parent_map": { @@ -1778,8 +1809,8 @@ def expected_redshift_incremental_view_manifest(self): }, 'metadata': { 'project_id': '098f6bcd4621d373cade4e832627b4f6', - 'user_id': None, 'send_anonymous_usage_stats': False, + 'user_id': None, }, 'disabled': [], } @@ -1847,7 +1878,6 @@ def expected_run_results(self, quote_schema=True, quote_model=False, compiled_seed = self._quote('seed') if quote_model else 'seed' if self.adapter_type == 'bigquery': - status = 'OK' compiled_sql = '\n\nselect * from `{}`.`{}`.seed'.format( self.default_database, schema ) @@ -1884,10 +1914,10 @@ def expected_run_results(self, quote_schema=True, quote_model=False, }, 'description': 'The test model', 'docrefs': [], - 
'empty': False, 'extra_ctes': [], 'extra_ctes_injected': True, 'fqn': ['test', 'model'], + 'index': None, 'injected_sql': compiled_sql, 'name': 'model', 'original_file_path': model_sql_path, @@ -1935,14 +1965,16 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'sources': [], 'depends_on': {'macros': [], 'nodes': []}, 'description': '', - 'empty': False, + 'docrefs': [], 'extra_ctes': [], 'extra_ctes_injected': True, 'fqn': ['test', 'seed'], + 'index': None, 'injected_sql': '-- csv --', 'name': 'seed', 'original_file_path': self.dir('seed/seed.csv'), 'package_name': 'test', + 'patch_path': None, 'path': 'seed.csv', 'raw_sql': '-- csv --', 'refs': [], @@ -1952,7 +1984,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'database': self.default_database, 'tags': [], 'unique_id': 'seed.test.seed', - 'wrapped_sql': 'None' + 'wrapped_sql': 'None', }, 'thread_id': ANY, 'timing': [ANY, ANY], @@ -1985,7 +2017,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'sources': [], 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', - 'empty': False, + 'docrefs': [], 'extra_ctes': [], 'extra_ctes_injected': True, 'fqn': ['test', 'schema_test', 'not_null_model_id'], @@ -1993,6 +2025,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'name': 'not_null_model_id', 'original_file_path': schema_yml_path, 'package_name': 'test', + 'patch_path': None, 'path': _normalize('schema_test/not_null_model_id.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test_not_null(model=ref('model'), column_name='id') }}", 'refs': [['model']], @@ -2016,6 +2049,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'node': { 'alias': 'nothing_model_', 'build_path': _normalize('target/compiled/test/schema_test/nothing_model_.sql'), + 'column_name': None, 'columns': {}, 'compiled': True, 'compiled_sql': AnyStringWith('select 0'), @@ -2031,10 +2065,10 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'tags': [], 'severity': 'ERROR', }, - 'sources': [], + 'database': self.default_database, 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', - 'empty': False, + 'docrefs': [], 'extra_ctes': [], 'extra_ctes_injected': True, 'fqn': ['test', 'schema_test', 'nothing_model_'], @@ -2042,13 +2076,14 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'name': 'nothing_model_', 'original_file_path': schema_yml_path, 'package_name': 'test', + 'patch_path': None, 'path': _normalize('schema_test/nothing_model_.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test.test_nothing(model=ref('model'), ) }}", 'refs': [['model']], 'resource_type': 'test', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'schema': schema, - 'database': self.default_database, + 'sources': [], 'tags': ['schema'], 'unique_id': 'test.test.nothing_model_', 'wrapped_sql': AnyStringWith('select 0'), @@ -2081,10 +2116,10 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'tags': [], 'severity': 'ERROR', }, - 'sources': [], + 'database': self.default_database, 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', - 'empty': False, + 'docrefs': [], 'extra_ctes': [], 'extra_ctes_injected': True, 'fqn': ['test', 'schema_test', 'unique_model_id'], @@ -2092,13 +2127,14 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'name': 'unique_model_id', 'original_file_path': schema_yml_path, 'package_name': 
'test', + 'patch_path': None, 'path': _normalize('schema_test/unique_model_id.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test_unique(model=ref('model'), column_name='id') }}", 'refs': [['model']], 'resource_type': 'test', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'schema': schema, - 'database': self.default_database, + 'sources': [], 'tags': ['schema'], 'unique_id': 'test.test.unique_model_id', 'wrapped_sql': AnyStringWith('count(*)') @@ -2187,16 +2223,17 @@ def expected_postgres_references_run_results(self): 'documentation_package': '' }, { + 'column_name': None, 'documentation_name': 'ephemeral_summary', 'documentation_package': '' } ], - 'empty': False, 'extra_ctes': [ {'id': 'model.test.ephemeral_copy', 'sql': cte_sql}, ], 'extra_ctes_injected': True, 'fqn': ['test', 'ephemeral_summary'], + 'index': None, 'injected_sql': ephemeral_injected_sql, 'name': 'ephemeral_summary', 'original_file_path': self.dir('ref_models/ephemeral_summary.sql'), @@ -2277,14 +2314,15 @@ def expected_postgres_references_run_results(self): 'documentation_package': '' }, { + 'column_name': None, 'documentation_name': 'view_summary', 'documentation_package': '' } ], - 'empty': False, 'extra_ctes': [], 'extra_ctes_injected': True, 'fqn': ['test', 'view_summary'], + 'index': None, 'injected_sql': view_compiled_sql, 'name': 'view_summary', 'original_file_path': self.dir('ref_models/view_summary.sql'), @@ -2336,14 +2374,16 @@ def expected_postgres_references_run_results(self): 'sources': [], 'depends_on': {'macros': [], 'nodes': []}, 'description': '', - 'empty': False, + 'docrefs': [], 'extra_ctes': [], 'extra_ctes_injected': True, 'fqn': ['test', 'seed'], + 'index': None, 'injected_sql': '-- csv --', 'name': 'seed', 'original_file_path': self.dir('seed/seed.csv'), 'package_name': 'test', + 'patch_path': None, 'path': 'seed.csv', 'raw_sql': '-- csv --', 'refs': [], diff --git a/test/integration/035_changing_relation_type_test/test_changing_relation_type.py b/test/integration/035_changing_relation_type_test/test_changing_relation_type.py index e4360112428..c8d6307b441 100644 --- a/test/integration/035_changing_relation_type_test/test_changing_relation_type.py +++ b/test/integration/035_changing_relation_type_test/test_changing_relation_type.py @@ -20,23 +20,23 @@ def swap_types_and_test(self): # between materializations that create tables and views. 
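The expected-manifest and run-results fixtures above now spell out defaulted keys such as 'build_path': None, 'index': None, 'patch_path': None and 'docrefs': []. A minimal sketch of why, assuming a hologram-style JsonSchemaMixin dataclass serialized without omitting None values (the field names below are illustrative, not the real ParsedNode definition):

    from dataclasses import dataclass, field
    from typing import List, Optional

    from hologram import JsonSchemaMixin


    @dataclass
    class ExampleNode(JsonSchemaMixin):
        name: str
        # optional metadata now has explicit defaults...
        build_path: Optional[str] = None
        index: Optional[int] = None
        patch_path: Optional[str] = None
        docrefs: List[str] = field(default_factory=list)

    # ...so a serialized node carries those keys (assuming to_dict() keeps None values):
    # ExampleNode(name='seed').to_dict() ->
    # {'name': 'seed', 'build_path': None, 'index': None,
    #  'patch_path': None, 'docrefs': []}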
results = self.run_dbt(['run', '--vars', 'materialized: view']) - self.assertEqual(results[0].node['config']['materialized'], 'view') + self.assertEqual(results[0].node.config.materialized, 'view') self.assertEqual(len(results), 1) results = self.run_dbt(['run', '--vars', 'materialized: table']) - self.assertEqual(results[0].node['config']['materialized'], 'table') + self.assertEqual(results[0].node.config.materialized, 'table') self.assertEqual(len(results), 1) results = self.run_dbt(['run', '--vars', 'materialized: view']) - self.assertEqual(results[0].node['config']['materialized'], 'view') + self.assertEqual(results[0].node.config.materialized, 'view') self.assertEqual(len(results), 1) results = self.run_dbt(['run', '--vars', 'materialized: incremental']) - self.assertEqual(results[0].node['config']['materialized'], 'incremental') + self.assertEqual(results[0].node.config.materialized, 'incremental') self.assertEqual(len(results), 1) results = self.run_dbt(['run', '--vars', 'materialized: view']) - self.assertEqual(results[0].node['config']['materialized'], 'view') + self.assertEqual(results[0].node.config.materialized, 'view') self.assertEqual(len(results), 1) @use_profile("postgres") @@ -59,23 +59,23 @@ def test__bigquery__switch_materialization(self): # and then remove these bq-specific tests results = self.run_dbt(['run', '--vars', 'materialized: view']) - self.assertEqual(results[0].node['config']['materialized'], 'view') + self.assertEqual(results[0].node.config.materialized, 'view') self.assertEqual(len(results), 1) results = self.run_dbt(['run', '--vars', 'materialized: table']) - self.assertEqual(results[0].node['config']['materialized'], 'table') + self.assertEqual(results[0].node.config.materialized, 'table') self.assertEqual(len(results), 1) results = self.run_dbt(['run', '--vars', 'materialized: view', "--full-refresh"]) - self.assertEqual(results[0].node['config']['materialized'], 'view') + self.assertEqual(results[0].node.config.materialized, 'view') self.assertEqual(len(results), 1) results = self.run_dbt(['run', '--vars', 'materialized: incremental']) - self.assertEqual(results[0].node['config']['materialized'], 'incremental') + self.assertEqual(results[0].node.config.materialized, 'incremental') self.assertEqual(len(results), 1) results = self.run_dbt(['run', '--vars', 'materialized: view', "--full-refresh"]) - self.assertEqual(results[0].node['config']['materialized'], 'view') + self.assertEqual(results[0].node.config.materialized, 'view') self.assertEqual(len(results), 1) @use_profile('presto') @@ -83,13 +83,13 @@ def test__presto__switch_materialization(self): # presto can't do incremental materializations so there's less to this results = self.run_dbt(['run', '--vars', 'materialized: view']) - self.assertEqual(results[0].node['config']['materialized'], 'view') + self.assertEqual(results[0].node.config.materialized, 'view') self.assertEqual(len(results), 1) results = self.run_dbt(['run', '--vars', 'materialized: table']) - self.assertEqual(results[0].node['config']['materialized'], 'table') + self.assertEqual(results[0].node.config.materialized, 'table') self.assertEqual(len(results), 1) results = self.run_dbt(['run', '--vars', 'materialized: view']) - self.assertEqual(results[0].node['config']['materialized'], 'view') + self.assertEqual(results[0].node.config.materialized, 'view') self.assertEqual(len(results), 1) diff --git a/test/integration/036_snowflake_view_dependency_test/test_view_binding_dependency.py 
b/test/integration/036_snowflake_view_dependency_test/test_view_binding_dependency.py index 1a11a912016..56096a2821f 100644 --- a/test/integration/036_snowflake_view_dependency_test/test_view_binding_dependency.py +++ b/test/integration/036_snowflake_view_dependency_test/test_view_binding_dependency.py @@ -68,7 +68,7 @@ def test__snowflake__changed_table_schema_for_downstream_view_changed_to_table(s # ensure that the model actually was materialized as a table for result in results: node_name = result.node.name - self.assertEqual(result.node.config['materialized'], expected_types[node_name]) + self.assertEqual(result.node.config.materialized, expected_types[node_name]) results = self.run_dbt(["run", "--vars", "{add_table_field: true, dependent_type: table}"]) self.assertEqual(len(results), 2) @@ -82,7 +82,7 @@ def test__snowflake__changed_table_schema_for_downstream_view_changed_to_table(s # ensure that the model actually was materialized as a table for result in results: node_name = result.node.name - self.assertEqual(result.node.config['materialized'], expected_types[node_name]) + self.assertEqual(result.node.config.materialized, expected_types[node_name]) @use_profile('presto') def test__presto__changed_table_schema_for_downstream_view(self): @@ -115,7 +115,7 @@ def test__presto__changed_table_schema_for_downstream_view_changed_to_table(self # ensure that the model actually was materialized as a table for result in results: node_name = result.node.name - self.assertEqual(result.node.config['materialized'], expected_types[node_name]) + self.assertEqual(result.node.config.materialized, expected_types[node_name]) results = self.run_dbt(["run", "--vars", "{add_table_field: true, dependent_type: table}"]) self.assertEqual(len(results), 2) @@ -129,4 +129,4 @@ def test__presto__changed_table_schema_for_downstream_view_changed_to_table(self # ensure that the model actually was materialized as a table for result in results: node_name = result.node.name - self.assertEqual(result.node.config['materialized'], expected_types[node_name]) + self.assertEqual(result.node.config.materialized, expected_types[node_name]) diff --git a/test/integration/041_presto_test/test_simple_presto_view.py b/test/integration/041_presto_test/test_simple_presto_view.py index 595cd6c21f9..18871872a09 100644 --- a/test/integration/041_presto_test/test_simple_presto_view.py +++ b/test/integration/041_presto_test/test_simple_presto_view.py @@ -29,7 +29,7 @@ def assert_nondupes_pass(self): test_results = self.run_dbt(['test'], expect_pass=False) for result in test_results: - if 'dupe' in result.node.get('name'): + if 'dupe' in result.node.name: self.assertIsNone(result.error) self.assertFalse(result.skipped) self.assertTrue(result.status > 0) diff --git a/test/integration/042_sources_test/models/schema.yml b/test/integration/042_sources_test/models/schema.yml index d7ee5c93ded..0b7dfe7263d 100644 --- a/test/integration/042_sources_test/models/schema.yml +++ b/test/integration/042_sources_test/models/schema.yml @@ -20,7 +20,8 @@ sources: - name: test_table identifier: source loaded_at_field: updated_at - error_after: {count: 18, period: hour} + freshness: + error_after: {count: 18, period: hour} columns: - name: favorite_color description: The favorite color diff --git a/test/integration/042_sources_test/test_sources.py b/test/integration/042_sources_test/test_sources.py index 19a9ded3e64..c23dddaef63 100644 --- a/test/integration/042_sources_test/test_sources.py +++ b/test/integration/042_sources_test/test_sources.py @@ -179,7 
+179,7 @@ def setUp(self): def _set_updated_at_to(self, delta): insert_time = datetime.utcnow() + delta timestr = insert_time.strftime("%Y-%m-%d %H:%M:%S") - #favorite_color,id,first_name,email,ip_address,updated_at + # favorite_color,id,first_name,email,ip_address,updated_at insert_id = self._id self._id += 1 raw_sql = """INSERT INTO {schema}.{source} @@ -211,8 +211,6 @@ def _assert_freshness_results(self, path, state): self.freshness_start_time) last_inserted_time = self.last_inserted_time - if last_inserted_time is None: - last_inserted_time = "2016-09-19T14:45:51+00:00" self.assertEqual(data['sources'], { 'source.test.test_source.test_table': { @@ -222,7 +220,7 @@ def _assert_freshness_results(self, path, state): 'state': state, 'criteria': { 'warn_after': {'count': 10, 'period': 'hour'}, - 'error_after': {'count': 1, 'period': 'day'}, + 'error_after': {'count': 18, 'period': 'hour'}, }, } }) @@ -235,7 +233,7 @@ def _run_source_freshness(self): ) self.assertEqual(len(results), 1) self.assertEqual(results[0].status, 'error') - self.assertTrue(results[0].failed) + self.assertTrue(results[0].fail) self.assertIsNone(results[0].error) self._assert_freshness_results('target/error_source.json', 'error') @@ -246,7 +244,7 @@ def _run_source_freshness(self): ) self.assertEqual(len(results), 1) self.assertEqual(results[0].status, 'warn') - self.assertFalse(results[0].failed) + self.assertFalse(results[0].fail) self.assertIsNone(results[0].error) self._assert_freshness_results('target/warn_source.json', 'warn') @@ -257,7 +255,7 @@ def _run_source_freshness(self): ) self.assertEqual(len(results), 1) self.assertEqual(results[0].status, 'pass') - self.assertFalse(results[0].failed) + self.assertFalse(results[0].fail) self.assertIsNone(results[0].error) self._assert_freshness_results('target/pass_source.json', 'pass') @@ -291,7 +289,7 @@ def test_postgres_error(self): ) self.assertEqual(len(results), 1) self.assertEqual(results[0].status, 'error') - self.assertFalse(results[0].failed) + self.assertFalse(results[0].fail) self.assertIsNotNone(results[0].error) diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index 8e7cba54bcf..21f697e3091 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -79,7 +79,7 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection): adapter = self.get_adapter('oauth') try: connection = adapter.acquire_connection('dummy') - self.assertEqual(connection.get('type'), 'bigquery') + self.assertEqual(connection.type, 'bigquery') except dbt.exceptions.ValidationException as e: self.fail('got ValidationException: {}'.format(str(e))) @@ -94,7 +94,7 @@ def test_acquire_connection_service_account_validations(self, mock_open_connecti adapter = self.get_adapter('service_account') try: connection = adapter.acquire_connection('dummy') - self.assertEqual(connection.get('type'), 'bigquery') + self.assertEqual(connection.type, 'bigquery') except dbt.exceptions.ValidationException as e: self.fail('got ValidationException: {}'.format(str(e))) diff --git a/test/unit/test_compiler.py b/test/unit/test_compiler.py index a3d67db5005..96184b82eb5 100644 --- a/test/unit/test_compiler.py +++ b/test/unit/test_compiler.py @@ -4,12 +4,15 @@ import dbt.flags import dbt.compilation -from collections import OrderedDict from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.graph.compiled import CompiledNode +from dbt.contracts.graph.parsed import NodeConfig, DependsOn +from dbt.contracts.graph.compiled 
import CompiledNode, InjectedCTE +from dbt.node_types import NodeType + +from datetime import datetime -class CompilerTest(unittest.TestCase): +class CompilerTest(unittest.TestCase): def assertEqualIgnoreWhitespace(self, a, b): self.assertEqual( "".join(a.split()), @@ -33,7 +36,7 @@ def setUp(self): 'project-root': os.path.abspath('./dbt_modules/snowplow'), } - self.model_config = { + self.model_config = NodeConfig.from_dict({ 'enabled': True, 'materialized': 'view', 'persist_docs': {}, @@ -43,11 +46,10 @@ def setUp(self): 'quoting': {}, 'column_types': {}, 'tags': [], - } + }) def test__prepend_ctes__already_has_cte(self): - ephemeral_config = self.model_config.copy() - ephemeral_config['materialized'] = 'ephemeral' + ephemeral_config = self.model_config.replace(materialized='ephemeral') input_graph = Manifest( macros={}, @@ -57,20 +59,14 @@ def test__prepend_ctes__already_has_cte(self): database='dbt', schema='analytics', alias='view', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.view', fqn=['root_project', 'view'], - empty=False, package_name='root', root_path='/usr/src/app', refs=[], sources=[], - depends_on={ - 'nodes': [ - 'model.root.ephemeral' - ], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.ephemeral']), config=self.model_config, tags=[], path='view.sql', @@ -78,9 +74,7 @@ def test__prepend_ctes__already_has_cte(self): raw_sql='select * from {{ref("ephemeral")}}', compiled=True, extra_ctes_injected=False, - extra_ctes=[ - {'id': 'model.root.ephemeral', 'sql': None} - ], + extra_ctes=[InjectedCTE(id='model.root.ephemeral')], injected_sql='', compiled_sql=( 'with cte as (select * from something_else) ' @@ -91,18 +85,14 @@ def test__prepend_ctes__already_has_cte(self): database='dbt', schema='analytics', alias='view', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.ephemeral', fqn=['root_project', 'ephemeral'], - empty=False, package_name='root', root_path='/usr/src/app', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=ephemeral_config, tags=[], path='ephemeral.sql', @@ -116,7 +106,8 @@ def test__prepend_ctes__already_has_cte(self): ), }, docs={}, - generated_at='2018-02-14T09:15:13Z', + # '2018-02-14T09:15:13Z' + generated_at=datetime(2018, 2, 14, 9, 15, 13), disabled=[] ) @@ -125,18 +116,16 @@ def test__prepend_ctes__already_has_cte(self): input_graph) self.assertEqual(result, output_graph.nodes['model.root.view']) - self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqual(result.extra_ctes_injected, True) self.assertEqualIgnoreWhitespace( - result.get('injected_sql'), + result.injected_sql, ('with __dbt__CTE__ephemeral as (' 'select * from source_table' '), cte as (select * from something_else) ' 'select * from __dbt__CTE__ephemeral')) self.assertEqual( - (input_graph.nodes - .get('model.root.ephemeral', {}) - .get('extra_ctes_injected')), + input_graph.nodes['model.root.ephemeral'].extra_ctes_injected, True) def test__prepend_ctes__no_ctes(self): @@ -148,48 +137,40 @@ def test__prepend_ctes__no_ctes(self): database='dbt', schema='analytics', alias='view', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.view', fqn=['root_project', 'view'], - empty=False, package_name='root', root_path='/usr/src/app', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='view.sql', original_file_path='view.sql', raw_sql=('with 
cte as (select * from something_else) ' - 'select * from source_table'), + 'select * from source_table'), compiled=True, extra_ctes_injected=False, extra_ctes=[], injected_sql='', compiled_sql=('with cte as (select * from something_else) ' - 'select * from source_table') + 'select * from source_table') ), 'model.root.view_no_cte': CompiledNode( name='view_no_cte', database='dbt', schema='analytics', alias='view_no_cte', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.view_no_cte', fqn=['root_project', 'view_no_cte'], - empty=False, package_name='root', root_path='/usr/src/app', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='view.sql', @@ -214,12 +195,10 @@ def test__prepend_ctes__no_ctes(self): self.assertEqual( result, output_graph.nodes.get('model.root.view')) - self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertTrue(result.extra_ctes_injected) self.assertEqualIgnoreWhitespace( - result.get('injected_sql'), - (output_graph.nodes - .get('model.root.view') - .get('compiled_sql'))) + result.injected_sql, + output_graph.nodes.get('model.root.view').compiled_sql) result, output_graph = dbt.compilation.prepend_ctes( input_graph.nodes.get('model.root.view_no_cte'), @@ -228,16 +207,13 @@ def test__prepend_ctes__no_ctes(self): self.assertEqual( result, output_graph.nodes.get('model.root.view_no_cte')) - self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertTrue(result.extra_ctes_injected) self.assertEqualIgnoreWhitespace( - result.get('injected_sql'), - (output_graph.nodes - .get('model.root.view_no_cte') - .get('compiled_sql'))) + result.injected_sql, + output_graph.nodes.get('model.root.view_no_cte').compiled_sql) def test__prepend_ctes(self): - ephemeral_config = self.model_config.copy() - ephemeral_config['materialized'] = 'ephemeral' + ephemeral_config = self.model_config.replace(materialized='ephemeral') input_graph = Manifest( macros={}, @@ -247,20 +223,14 @@ def test__prepend_ctes(self): database='dbt', schema='analytics', alias='view', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.view', fqn=['root_project', 'view'], - empty=False, package_name='root', root_path='/usr/src/app', refs=[], sources=[], - depends_on={ - 'nodes': [ - 'model.root.ephemeral' - ], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.ephemeral']), config=self.model_config, tags=[], path='view.sql', @@ -268,9 +238,7 @@ def test__prepend_ctes(self): raw_sql='select * from {{ref("ephemeral")}}', compiled=True, extra_ctes_injected=False, - extra_ctes=[ - {'id': 'model.root.ephemeral', 'sql': None} - ], + extra_ctes=[InjectedCTE(id='model.root.ephemeral')], injected_sql='', compiled_sql='select * from __dbt__CTE__ephemeral' ), @@ -279,18 +247,14 @@ def test__prepend_ctes(self): database='dbt', schema='analytics', alias='ephemeral', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.ephemeral', fqn=['root_project', 'ephemeral'], - empty=False, package_name='root', root_path='/usr/src/app', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=ephemeral_config, tags=[], path='ephemeral.sql', @@ -313,26 +277,20 @@ def test__prepend_ctes(self): input_graph) self.assertEqual(result, - (output_graph.nodes - .get('model.root.view'))) + output_graph.nodes.get('model.root.view')) - self.assertEqual(result.get('extra_ctes_injected'), True) + 
self.assertTrue(result.extra_ctes_injected) self.assertEqualIgnoreWhitespace( - result.get('injected_sql'), + result.injected_sql, ('with __dbt__CTE__ephemeral as (' 'select * from source_table' ') ' 'select * from __dbt__CTE__ephemeral')) - self.assertEqual( - (output_graph.nodes - .get('model.root.ephemeral', {}) - .get('extra_ctes_injected')), - True) + self.assertTrue(output_graph.nodes['model.root.ephemeral'].extra_ctes_injected) def test__prepend_ctes__multiple_levels(self): - ephemeral_config = self.model_config.copy() - ephemeral_config['materialized'] = 'ephemeral' + ephemeral_config = self.model_config.replace(materialized='ephemeral') input_graph = Manifest( macros={}, @@ -342,20 +300,14 @@ def test__prepend_ctes__multiple_levels(self): database='dbt', schema='analytics', alias='view', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.view', fqn=['root_project', 'view'], - empty=False, package_name='root', root_path='/usr/src/app', refs=[], sources=[], - depends_on={ - 'nodes': [ - 'model.root.ephemeral' - ], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.ephemeral']), config=self.model_config, tags=[], path='view.sql', @@ -363,9 +315,7 @@ def test__prepend_ctes__multiple_levels(self): raw_sql='select * from {{ref("ephemeral")}}', compiled=True, extra_ctes_injected=False, - extra_ctes=[ - {'id': 'model.root.ephemeral', 'sql': None} - ], + extra_ctes=[InjectedCTE(id='model.root.ephemeral')], injected_sql='', compiled_sql='select * from __dbt__CTE__ephemeral' ), @@ -374,18 +324,14 @@ def test__prepend_ctes__multiple_levels(self): database='dbt', schema='analytics', alias='ephemeral', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.ephemeral', fqn=['root_project', 'ephemeral'], - empty=False, package_name='root', root_path='/usr/src/app', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=ephemeral_config, tags=[], path='ephemeral.sql', @@ -393,9 +339,7 @@ def test__prepend_ctes__multiple_levels(self): raw_sql='select * from {{ref("ephemeral_level_two")}}', compiled=True, extra_ctes_injected=False, - extra_ctes=[ - {'id': 'model.root.ephemeral_level_two', 'sql': None} - ], + extra_ctes=[InjectedCTE(id='model.root.ephemeral_level_two')], injected_sql='', compiled_sql='select * from __dbt__CTE__ephemeral_level_two' # noqa ), @@ -404,18 +348,14 @@ def test__prepend_ctes__multiple_levels(self): database='dbt', schema='analytics', alias='ephemeral_level_two', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.ephemeral_level_two', fqn=['root_project', 'ephemeral_level_two'], - empty=False, package_name='root', root_path='/usr/src/app', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=ephemeral_config, tags=[], path='ephemeral_level_two.sql', @@ -438,9 +378,9 @@ def test__prepend_ctes__multiple_levels(self): input_graph) self.assertEqual(result, input_graph.nodes['model.root.view']) - self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertTrue(result.extra_ctes_injected) self.assertEqualIgnoreWhitespace( - result.get('injected_sql'), + result.injected_sql, ('with __dbt__CTE__ephemeral_level_two as (' 'select * from source_table' '), __dbt__CTE__ephemeral as (' @@ -448,13 +388,5 @@ def test__prepend_ctes__multiple_levels(self): ') ' 'select * from __dbt__CTE__ephemeral')) - self.assertEqual( - (output_graph.nodes - .get('model.root.ephemeral') - 
.get('extra_ctes_injected')), - True) - self.assertEqual( - (output_graph.nodes - .get('model.root.ephemeral_level_two') - .get('extra_ctes_injected')), - True) + self.assertTrue(output_graph.nodes['model.root.ephemeral'].extra_ctes_injected) + self.assertTrue(output_graph.nodes['model.root.ephemeral_level_two'].extra_ctes_injected) diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 3770bc768a1..86bca572c7e 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -13,7 +13,7 @@ import dbt.exceptions from dbt.adapters.postgres import PostgresCredentials from dbt.adapters.redshift import RedshiftCredentials -from dbt.contracts.project import PackageConfig +from dbt.contracts.project import PackageConfig, LocalPackage, GitPackage from dbt.semver import VersionSpecifier from dbt.task.run_operation import RunOperationTask @@ -692,15 +692,12 @@ def test_all_overrides(self): }) self.assertEqual(project.dbt_version, [VersionSpecifier.from_version_string('>=0.1.0')]) - self.assertEqual(project.packages, PackageConfig(packages=[ - { - 'local': 'foo', - }, - { - 'git': 'git@example.com:fishtown-analytics/dbt-utils.git', - 'revision': 'test-rev' - }, - ])) + self.assertEqual( + project.packages, + PackageConfig(packages=[ + LocalPackage(local='foo'), + GitPackage(git='git@example.com:fishtown-analytics/dbt-utils.git', revision='test-rev') + ])) str(project) json.dumps(project.to_project_config()) @@ -1010,7 +1007,7 @@ def test_validate_fails(self): project = self.get_project() profile = self.get_profile() # invalid - must be boolean - profile.config.use_colors = None + profile.config.use_colors = 100 with self.assertRaises(dbt.exceptions.DbtProjectError): dbt.config.RuntimeConfig.from_parts(project, profile, {}) diff --git a/test/unit/test_context.py b/test/unit/test_context.py index 55721c026a2..2032d25cc4a 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from dbt.contracts.graph.parsed import ParsedNode +from dbt.contracts.graph.parsed import ParsedNode, NodeConfig, DependsOn from dbt.context import parser, runtime import dbt.exceptions from .mock_adapter import adapter_factory @@ -17,17 +17,13 @@ def setUp(self): resource_type='model', unique_id='model.root.model_one', fqn=['root', 'model_one'], - empty=False, package_name='root', original_file_path='model_one.sql', root_path='/usr/src/app', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, - config={ + depends_on=DependsOn(), + config=NodeConfig.from_dict({ 'enabled': True, 'materialized': 'view', 'persist_docs': {}, @@ -37,7 +33,7 @@ def setUp(self): 'quoting': {}, 'column_types': {}, 'tags': [], - }, + }), tags=[], path='model_one.sql', raw_sql='', diff --git a/test/unit/test_contracts_graph_compiled.py b/test/unit/test_contracts_graph_compiled.py new file mode 100644 index 00000000000..c403d826553 --- /dev/null +++ b/test/unit/test_contracts_graph_compiled.py @@ -0,0 +1,395 @@ +from dbt.contracts.graph.compiled import ( + CompiledNode, InjectedCTE, CompiledTestNode +) +from dbt.contracts.graph.parsed import ( + DependsOn, NodeConfig, TestConfig +) +from dbt.node_types import NodeType + +from .utils import ContractTestCase + + +class TestCompiledNode(ContractTestCase): + ContractType = CompiledNode + + def test_basic_uncompiled(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 
'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + }, + 'docrefs': [], + 'columns': {}, + 'compiled': False, + 'extra_ctes': [], + 'extra_ctes_injected': False, + } + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Model, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=NodeConfig(), + compiled=False, + extra_ctes=[], + extra_ctes_injected=False, + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + self.assertTrue(node.is_refable) + self.assertFalse(node.is_ephemeral) + self.assertEqual(node.local_vars(), {}) + + minimum = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'database': 'test_db', + 'schema': 'test_schema', + 'alias': 'bar', + } + self.assert_from_dict(node, minimum) + + def test_basic_compiled(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from {{ ref("other") }}', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + }, + 'docrefs': [], + 'columns': {}, + 'compiled': True, + 'compiled_sql': 'select * from whatever', + 'extra_ctes': [{'id': 'whatever', 'sql': 'select * from other'}], + 'extra_ctes_injected': True, + 'injected_sql': 'with whatever as (select * from other) select * from whatever', + 'wrapped_sql': 'None', + } + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from {{ ref("other") }}', + name='foo', + resource_type=NodeType.Model, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=NodeConfig(), + compiled=True, + compiled_sql='select * from whatever', + extra_ctes=[InjectedCTE('whatever', 'select * from other')], + extra_ctes_injected=True, + injected_sql='with whatever as (select * from other) select * from whatever', + wrapped_sql='None', + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + 
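assert_symmetric and assert_from_dict above come from the ContractTestCase helper imported from .utils, which this hunk does not show. A plausible minimal sketch, assuming hologram's to_dict/from_dict round-tripping (the real helper lives in test/unit/utils.py and may differ):

    import unittest


    class ContractTestCase(unittest.TestCase):
        # Hypothetical sketch; subclasses set e.g. ContractType = CompiledNode.
        ContractType = None

        def assert_symmetric(self, obj, dct):
            # serializing the contract object should yield exactly the expected dict...
            self.assertEqual(obj.to_dict(), dct)
            # ...and parsing that dict should rebuild an equal object
            self.assertEqual(self.ContractType.from_dict(dct), obj)

        def assert_from_dict(self, obj, dct):
            # a minimal dict should still parse, with defaults filling the gaps
            self.assertEqual(self.ContractType.from_dict(dct), obj)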
self.assertTrue(node.is_refable) + self.assertFalse(node.is_ephemeral) + self.assertEqual(node.local_vars(), {}) + + def test_invalid_extra_fields(self): + bad_extra = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'database': 'test_db', + 'schema': 'test_schema', + 'alias': 'bar', + 'notvalid': 'nope', + } + self.assert_fails_validation(bad_extra) + + def test_invalid_bad_type(self): + bad_type = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Macro), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'database': 'test_db', + 'schema': 'test_schema', + 'alias': 'bar', + } + self.assert_fails_validation(bad_type) + + +class TestCompiledTestNode(ContractTestCase): + ContractType = CompiledTestNode + + def test_basic_uncompiled(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Test), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'severity': 'error', + }, + 'docrefs': [], + 'columns': {}, + 'compiled': False, + 'extra_ctes': [], + 'extra_ctes_injected': False, + } + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Test, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=TestConfig(), + compiled=False, + extra_ctes=[], + extra_ctes_injected=False, + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + self.assertFalse(node.is_refable) + self.assertFalse(node.is_ephemeral) + self.assertEqual(node.local_vars(), {}) + + minimum = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Test), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'database': 'test_db', + 'schema': 'test_schema', + 'alias': 'bar', + } + self.assert_from_dict(node, minimum) + + def test_basic_compiled(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Test), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from {{ ref("other") }}', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 
'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'severity': 'warn', + }, + 'docrefs': [], + 'columns': {}, + 'compiled': True, + 'compiled_sql': 'select * from whatever', + 'extra_ctes': [{'id': 'whatever', 'sql': 'select * from other'}], + 'extra_ctes_injected': True, + 'injected_sql': 'with whatever as (select * from other) select * from whatever', + 'wrapped_sql': 'select count(*) from (with whatever as (select * from other) select * from whatever) sbq', + 'column_name': 'id', + } + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from {{ ref("other") }}', + name='foo', + resource_type=NodeType.Test, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=TestConfig(severity='warn'), + compiled=True, + compiled_sql='select * from whatever', + extra_ctes=[InjectedCTE('whatever', 'select * from other')], + extra_ctes_injected=True, + injected_sql='with whatever as (select * from other) select * from whatever', + wrapped_sql='select count(*) from (with whatever as (select * from other) select * from whatever) sbq', + column_name='id', + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + self.assertFalse(node.is_refable) + self.assertFalse(node.is_ephemeral) + self.assertEqual(node.local_vars(), {}) + + def test_invalid_extra_fields(self): + bad_extra = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Test), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'database': 'test_db', + 'schema': 'test_schema', + 'alias': 'bar', + 'extra': 'extra value', + } + self.assert_fails_validation(bad_extra) + + def test_invalid_resource_type(self): + bad_type = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'database': 'test_db', + 'schema': 'test_schema', + 'alias': 'bar', + } + self.assert_fails_validation(bad_type) diff --git a/test/unit/test_contracts_graph_parsed.py b/test/unit/test_contracts_graph_parsed.py new file mode 100644 index 00000000000..dd48e9add13 --- /dev/null +++ b/test/unit/test_contracts_graph_parsed.py @@ -0,0 +1,1678 @@ +from dbt.node_types import NodeType +from dbt.contracts.graph.parsed import ( + ParsedNode, DependsOn, NodeConfig, ColumnInfo, Hook, ParsedTestNode, + TestConfig, ParsedSnapshotNode, TimestampSnapshotConfig, All, Docref, + GenericSnapshotConfig, CheckSnapshotConfig, TimestampStrategy, + CheckStrategy, IntermediateSnapshotNode, ParsedNodePatch, ParsedMacro, + MacroDependsOn, ParsedSourceDefinition, ParsedDocumentation, +) +from dbt.contracts.graph.unparsed import Quoting, FreshnessThreshold + +from hologram import ValidationError +from .utils import ContractTestCase + + +class TestNodeConfig(ContractTestCase): + ContractType = NodeConfig + + def 
test_basics(self): + cfg_dict = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + } + cfg = self.ContractType() + self.assert_symmetric(cfg, cfg_dict) + + def test_populated(self): + cfg_dict = { + 'column_types': {'a': 'text'}, + 'enabled': True, + 'materialized': 'table', + 'persist_docs': {}, + 'post-hook': [{'sql': 'insert into blah(a, b) select "1", 1', 'transaction': True}], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'extra': 'even more', + } + cfg = self.ContractType( + column_types={'a': 'text'}, + materialized='table', + post_hook=[Hook(sql='insert into blah(a, b) select "1", 1')] + ) + cfg._extra['extra'] = 'even more' + + self.assert_symmetric(cfg, cfg_dict) + + +class TestParsedNode(ContractTestCase): + ContractType = ParsedNode + + def test_ok(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + }, + 'docrefs': [], + 'columns': {}, + } + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Model, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=NodeConfig(), + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + self.assertTrue(node.is_refable) + self.assertFalse(node.is_ephemeral) + self.assertEqual(node.local_vars(), {}) + + minimum = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'database': 'test_db', + 'schema': 'test_schema', + 'alias': 'bar', + } + self.assert_from_dict(node, minimum) + + def test_complex(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from {{ ref("bar") }}', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': ['model.test.bar']}, + 'database': 'test_db', + 'description': 'My parsed node', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': ['tag'], + 'config': { + 'column_types': {'a': 'text'}, + 'enabled': True, + 'materialized': 'ephemeral', + 'persist_docs': {}, + 'post-hook': [{'sql': 'insert into blah(a, b) select "1", 1', 'transaction': True}], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {'foo': 100}, + }, + 'docrefs': [], + 
'columns': {'a': {'name': 'a', 'description': 'a text field'}}, + } + + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from {{ ref("bar") }}', + name='foo', + resource_type=NodeType.Model, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(nodes=['model.test.bar']), + description='My parsed node', + database='test_db', + schema='test_schema', + alias='bar', + tags=['tag'], + config=NodeConfig( + column_types={'a': 'text'}, + materialized='ephemeral', + post_hook=[Hook(sql='insert into blah(a, b) select "1", 1')], + vars={'foo': 100}, + ), + columns={'a': ColumnInfo('a', 'a text field')}, + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + self.assertTrue(node.is_refable) + self.assertTrue(node.is_ephemeral) + self.assertEqual(node.local_vars(), {'foo': 100}) + + def test_invalid_bad_tags(self): + # bad top-level field + bad_tags = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': 'My parsed node', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': 100, + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': None, + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + }, + 'docrefs': [], + 'columns': {}, + } + self.assert_fails_validation(bad_tags) + + def test_invalid_bad_materialized(self): + # bad nested field + bad_materialized = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': 'My parsed node', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': ['tag'], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': None, + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + }, + 'docrefs': [], + 'columns': {}, + } + self.assert_fails_validation(bad_materialized) + + def test_patch_ok(self): + initial = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Model, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=NodeConfig(), + ) + patch = ParsedNodePatch( + name='foo', + description='The foo model', + original_file_path='/path/to/schema.yml', + columns={'a': ColumnInfo(name='a', description='a text field')}, + docrefs=[ + Docref(documentation_name='foo', documentation_package='test'), + ], + ) + + initial.patch(patch) + + expected_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': 
'/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': 'The foo model', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + }, + 'patch_path': '/path/to/schema.yml', + 'columns': {'a': {'name': 'a', 'description': 'a text field'}}, + 'docrefs': [ + { + 'documentation_name': 'foo', + 'documentation_package': 'test', + } + ], + } + + expected = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Model, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='The foo model', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=NodeConfig(), + patch_path='/path/to/schema.yml', + columns={'a': ColumnInfo(name='a', description='a text field')}, + docrefs=[ + Docref(documentation_name='foo', documentation_package='test'), + ], + ) + self.assert_symmetric(expected, expected_dict) # sanity check + self.assertEqual(initial, expected) + self.assert_symmetric(initial, expected_dict) + + def patch_invalid(self): + initial = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Model, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=NodeConfig(), + ) + # invalid patch: description can't be None + patch = ParsedNodePatch( + name='foo', + description=None, + original_file_path='/path/to/schema.yml', + columns={}, + docrefs=[], + ) + with self.assertRaises(ValidationError): + initial.patch(patch) + + +class TestParsedHookNode(ContractTestCase): + ContractType = ParsedNode + + def test_ok(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Operation), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + }, + 'docrefs': [], + 'columns': {}, + } + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Operation, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + 
alias='bar', + tags=[], + config=NodeConfig(), + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + self.assertFalse(node.is_refable) + self.assertEqual(node.get_materialization(), 'view') + + minimum = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Operation), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'database': 'test_db', + 'schema': 'test_schema', + 'alias': 'bar', + } + self.assert_from_dict(node, minimum) + + def test_complex(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Operation), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from {{ ref("bar") }}', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': ['model.test.bar']}, + 'database': 'test_db', + 'description': 'My parsed node', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': ['tag'], + 'config': { + 'column_types': {'a': 'text'}, + 'enabled': True, + 'materialized': 'table', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + }, + 'docrefs': [], + 'columns': {'a': {'name': 'a', 'description': 'a text field'}}, + 'index': 13, + } + + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from {{ ref("bar") }}', + name='foo', + resource_type=NodeType.Operation, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(nodes=['model.test.bar']), + description='My parsed node', + database='test_db', + schema='test_schema', + alias='bar', + tags=['tag'], + config=NodeConfig( + column_types={'a': 'text'}, + materialized='table', + post_hook=[] + ), + columns={'a': ColumnInfo('a', 'a text field')}, + index=13, + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + self.assertFalse(node.is_refable) + self.assertEqual(node.get_materialization(), 'table') + + def test_invalid_index_type(self): + # bad top-level field + bad_index = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Operation), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': 'My parsed node', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': None, + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + }, + 'docrefs': [], + 'columns': {}, + 'index': 'a string!?', + } + self.assert_fails_validation(bad_index) + + +class TestParsedTestNode(ContractTestCase): + ContractType = ParsedTestNode + + def test_ok(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Test), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'test.test.foo', + 'fqn': ['test', 
'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'severity': 'error', + }, + 'docrefs': [], + 'columns': {}, + } + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Test, + unique_id='test.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=TestConfig(), + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + self.assertFalse(node.is_ephemeral) + self.assertFalse(node.is_refable) + self.assertEqual(node.get_materialization(), 'view') + + minimum = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Test), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'test.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'database': 'test_db', + 'schema': 'test_schema', + 'alias': 'bar', + } + self.assert_from_dict(node, minimum) + + def test_complex(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Test), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from {{ ref("bar") }}', + 'unique_id': 'test.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': ['model.test.bar']}, + 'database': 'test_db', + 'description': 'My parsed node', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': ['tag'], + 'config': { + 'column_types': {'a': 'text'}, + 'enabled': True, + 'materialized': 'table', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'severity': 'WARN', + 'extra_key': 'extra value' + }, + 'docrefs': [], + 'columns': {'a': {'name': 'a', 'description': 'a text field'}}, + 'column_name': 'id', + } + + cfg = TestConfig( + column_types={'a': 'text'}, + materialized='table', + severity='WARN' + ) + cfg._extra.update({'extra_key': 'extra value'}) + + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from {{ ref("bar") }}', + name='foo', + resource_type=NodeType.Test, + unique_id='test.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(nodes=['model.test.bar']), + description='My parsed node', + database='test_db', + schema='test_schema', + alias='bar', + tags=['tag'], + config=cfg, + columns={'a': ColumnInfo('a', 'a text field')}, + column_name='id', + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + + def test_invalid_column_name_type(self): + # bad top-level field + bad_column_name = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Test), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 
'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': 'My parsed node', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': 100, + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': None, + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'severity': 'ERROR', + }, + 'docrefs': [], + 'columns': {}, + 'column_name': {}, + } + self.assert_fails_validation(bad_column_name) + + def test_invalid_missing_severity(self): + # note the typo ('severtiy') + missing_config_value = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Test), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': 'My parsed node', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': ['tag'], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': None, + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'severtiy': 'WARN', + }, + 'docrefs': [], + 'columns': {}, + } + self.assert_fails_validation(missing_config_value) + + def test_invalid_severity(self): + invalid_config_value = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Test), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': 'My parsed node', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': ['tag'], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': None, + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'severity': 'WERROR', # invalid severity + }, + 'docrefs': [], + 'columns': {}, + } + self.assert_fails_validation(invalid_config_value) + + +class TestTimestampSnapshotConfig(ContractTestCase): + ContractType = TimestampSnapshotConfig + + def test_basics(self): + cfg_dict = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'unique_key': 'id', + 'strategy': 'timestamp', + 'updated_at': 'last_update', + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + } + cfg = self.ContractType( + strategy=TimestampStrategy.Timestamp, + updated_at='last_update', + unique_key='id', + target_database='some_snapshot_db', + target_schema='some_snapshot_schema', + ) + self.assert_symmetric(cfg, cfg_dict) + + def test_populated(self): + cfg_dict = { + 'column_types': {'a': 'text'}, + 'enabled': True, + 'materialized': 'table', + 'persist_docs': {}, + 'post-hook': [{'sql': 'insert into blah(a, b) select "1", 1', 'transaction': True}], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'unique_key': 'id', + 'extra': 'even more', + 'strategy': 'timestamp', + 'updated_at': 'last_update', + } + cfg = 
self.ContractType( + column_types={'a': 'text'}, + materialized='table', + post_hook=[Hook(sql='insert into blah(a, b) select "1", 1')], + strategy=TimestampStrategy.Timestamp, + target_database='some_snapshot_db', + target_schema='some_snapshot_schema', + updated_at='last_update', + unique_key='id', + ) + cfg._extra['extra'] = 'even more' + + self.assert_symmetric(cfg, cfg_dict) + + def test_invalid_wrong_strategy(self): + bad_type = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'unique_key': 'id', + 'strategy': 'check', + 'updated_at': 'last_update', + } + self.assert_fails_validation(bad_type) + + def test_invalid_missing_updated_at(self): + bad_fields = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'unique_key': 'id', + 'strategy': 'timestamp', + 'check_cols': 'all' + } + self.assert_fails_validation(bad_fields) + + +class TestCheckSnapshotConfig(ContractTestCase): + ContractType = CheckSnapshotConfig + + def test_basics(self): + cfg_dict = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'unique_key': 'id', + 'strategy': 'check', + 'check_cols': 'all', + } + cfg = self.ContractType( + strategy=CheckStrategy.Check, + check_cols=All.All, + unique_key='id', + target_database='some_snapshot_db', + target_schema='some_snapshot_schema', + ) + self.assert_symmetric(cfg, cfg_dict) + + def test_populated(self): + cfg_dict = { + 'column_types': {'a': 'text'}, + 'enabled': True, + 'materialized': 'table', + 'persist_docs': {}, + 'post-hook': [{'sql': 'insert into blah(a, b) select "1", 1', 'transaction': True}], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'unique_key': 'id', + 'extra': 'even more', + 'strategy': 'check', + 'check_cols': ['a', 'b'], + } + cfg = self.ContractType( + column_types={'a': 'text'}, + materialized='table', + post_hook=[Hook(sql='insert into blah(a, b) select "1", 1')], + strategy=CheckStrategy.Check, + check_cols=['a', 'b'], + target_database='some_snapshot_db', + target_schema='some_snapshot_schema', + unique_key='id', + ) + cfg._extra['extra'] = 'even more' + + self.assert_symmetric(cfg, cfg_dict) + + def test_invalid_wrong_strategy(self): + wrong_strategy = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'unique_key': 'id', + 'strategy': 'timestamp', + 'check_cols': 'all', + } + self.assert_fails_validation(wrong_strategy) + + def test_invalid_missing_check_cols(self): + wrong_fields = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 
'some_snapshot_schema', + 'unique_key': 'id', + 'strategy': 'check', + 'updated_at': 'last_update' + } + self.assert_fails_validation(wrong_fields) + + def test_invalid_check_value(self): + invalid_check_type = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'unique_key': 'id', + 'strategy': 'timestamp', + 'check_cols': 'some', + } + self.assert_fails_validation(invalid_check_type) + + +class TestGenericSnapshotConfig(ContractTestCase): + ContractType = GenericSnapshotConfig + + def test_ok(self): + cfg_dict = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'unique_key': 'id', + 'strategy': 'hello', + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'magic_key': 'magic', + } + cfg = self.ContractType( + strategy='hello', + unique_key='id', + target_database='some_snapshot_db', + target_schema='some_snapshot_schema', + ) + cfg._extra.update({'magic_key': 'magic'}) + self.assert_symmetric(cfg, cfg_dict) + + +class TestParsedSnapshotNode(ContractTestCase): + ContractType = ParsedSnapshotNode + + def test_timestamp_ok(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Snapshot), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'unique_key': 'id', + 'strategy': 'timestamp', + 'updated_at': 'last_update', + }, + 'docrefs': [], + 'columns': {}, + } + + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Snapshot, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=TimestampSnapshotConfig( + strategy=TimestampStrategy.Timestamp, + unique_key='id', + updated_at='last_update', + target_database='some_snapshot_db', + target_schema='some_snapshot_schema', + ), + ) + + cfg = NodeConfig() + cfg._extra.update({ + 'unique_key': 'id', + 'strategy': 'timestamp', + 'updated_at': 'last_update', + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + }) + + inter = IntermediateSnapshotNode( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Snapshot, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', 
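+ # the intermediate node carries a plain NodeConfig with the snapshot keys stashed in _extra; the assertion below checks that ParsedSnapshotNode.from_dict upgrades it to the typed TimestampSnapshotConfig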
+ alias='bar', + tags=[], + config=cfg, + ) + self.assert_symmetric(node, node_dict) + self.assert_symmetric(inter, node_dict, cls=IntermediateSnapshotNode) + self.assertEqual( + self.ContractType.from_dict(inter.to_dict()), + node + ) + self.assertTrue(node.is_refable) + self.assertFalse(node.is_ephemeral) + + def test_check_ok(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Snapshot), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'unique_key': 'id', + 'strategy': 'check', + 'check_cols': 'all', + }, + 'docrefs': [], + 'columns': {}, + } + + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Snapshot, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=CheckSnapshotConfig( + strategy=CheckStrategy.Check, + unique_key='id', + check_cols=All.All, + target_database='some_snapshot_db', + target_schema='some_snapshot_schema', + ), + ) + cfg = NodeConfig() + cfg._extra.update({ + 'unique_key': 'id', + 'strategy': 'check', + 'check_cols': 'all', + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + }) + + inter = IntermediateSnapshotNode( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Snapshot, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=cfg, + ) + self.assert_symmetric(node, node_dict) + self.assert_symmetric(inter, node_dict, cls=IntermediateSnapshotNode) + self.assertEqual( + self.ContractType.from_dict(inter.to_dict()), + node + ) + self.assertTrue(node.is_refable) + self.assertFalse(node.is_ephemeral) + + def test_ok_unknown_strategy(self): + unknown_strategy = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Snapshot), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 
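+ # a strategy that is neither 'timestamp' nor 'check' should round-trip through GenericSnapshotConfig, with unknown keys such as magic_key preserved in _extra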
'unique_key': 'id', + 'strategy': 'unknown', + 'magic_key': 'something', + }, + 'docrefs': [], + 'columns': {}, + } + config = GenericSnapshotConfig( + strategy='unknown', + unique_key='id', + target_database='some_snapshot_db', + target_schema='some_snapshot_schema', + ) + config._extra.update({'magic_key': 'something'}) + + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from wherever', + name='foo', + resource_type=NodeType.Snapshot, + unique_id='model.test.foo', + fqn=['test', 'models', 'foo'], + refs=[], + sources=[], + depends_on=DependsOn(), + description='', + database='test_db', + schema='test_schema', + alias='bar', + tags=[], + config=config, + ) + self.assert_symmetric(node, unknown_strategy) + + def test_invalid_bad_resource_type(self): + bad_resource_type = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': str(NodeType.Model), + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from wherever', + 'unique_id': 'model.test.foo', + 'fqn': ['test', 'models', 'foo'], + 'refs': [], + 'sources': [], + 'depends_on': {'macros': [], 'nodes': []}, + 'database': 'test_db', + 'description': '', + 'schema': 'test_schema', + 'alias': 'bar', + 'tags': [], + 'config': { + 'column_types': {}, + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'quoting': {}, + 'tags': [], + 'vars': {}, + 'target_database': 'some_snapshot_db', + 'target_schema': 'some_snapshot_schema', + 'unique_key': 'id', + 'strategy': 'timestamp', + 'updated_at': 'last_update', + }, + 'docrefs': [], + 'columns': {}, + } + self.assert_fails_validation(bad_resource_type) + + +class TestParsedNodePatch(ContractTestCase): + ContractType = ParsedNodePatch + + def test_empty(self): + dct = { + 'name': 'foo', + 'description': 'The foo model', + 'original_file_path': '/path/to/schema.yml', + 'columns': {}, + 'docrefs': [], + } + patch = ParsedNodePatch( + name='foo', + description='The foo model', + original_file_path='/path/to/schema.yml', + columns={}, + docrefs=[], + ) + self.assert_symmetric(patch, dct) + + def test_populated(self): + dct = { + 'name': 'foo', + 'description': 'The foo model', + 'original_file_path': '/path/to/schema.yml', + 'columns': {'a': {'name': 'a', 'description': 'a text field'}}, + 'docrefs': [ + { + 'documentation_name': 'foo', + 'documentation_package': 'test', + } + ], + } + patch = ParsedNodePatch( + name='foo', + description='The foo model', + original_file_path='/path/to/schema.yml', + columns={'a': ColumnInfo(name='a', description='a text field')}, + docrefs=[ + Docref(documentation_name='foo', documentation_package='test'), + ], + ) + self.assert_symmetric(patch, dct) + + +class TestParsedMacro(ContractTestCase): + ContractType = ParsedMacro + + def test_ok(self): + macro_dict = { + 'name': 'foo', + 'path': '/root/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': '{% macro foo() %}select 1 as id{% endmacro %}', + 'root_path': '/root/', + 'resource_type': 'macro', + 'unique_id': 'macro.test.foo', + 'tags': [], + 'depends_on': {'macros': []} + } + macro = ParsedMacro( + name='foo', + path='/root/path.sql', + original_file_path='/root/path.sql', + package_name='test', + raw_sql='{% macro foo() %}select 1 as id{% endmacro %}', + root_path='/root/', + resource_type=NodeType.Macro, + unique_id='macro.test.foo', + tags=[], 
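+ # macro dependencies are macros-only, so depends_on uses MacroDependsOn (serialized as {'macros': []}) rather than the node-level DependsOn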
+ depends_on=MacroDependsOn() + ) + self.assert_symmetric(macro, macro_dict) + self.assertEqual(macro.local_vars(), {}) + + def test_invalid_missing_unique_id(self): + bad_missing_uid = { + 'name': 'foo', + 'path': '/root/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': '{% macro foo() %}select 1 as id{% endmacro %}', + 'root_path': '/root/', + 'resource_type': 'macro', + 'tags': [], + 'depends_on': {'macros': []} + } + self.assert_fails_validation(bad_missing_uid) + + def test_invalid_extra_field(self): + bad_extra_field = { + 'name': 'foo', + 'path': '/root/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': '{% macro foo() %}select 1 as id{% endmacro %}', + 'root_path': '/root/', + 'resource_type': 'macro', + 'unique_id': 'macro.test.foo', + 'tags': [], + 'depends_on': {'macros': []}, + 'extra': 'too many fields' + } + self.assert_fails_validation(bad_extra_field) + + +class TestParsedDocumentation(ContractTestCase): + ContractType = ParsedDocumentation + + def test_ok(self): + doc_dict = { + 'block_contents': 'some doc contents', + 'file_contents': '{% doc foo %}some doc contents{% enddoc %}', + 'name': 'foo', + 'original_file_path': '/root/docs/doc.md', + 'package_name': 'test', + 'path': '/root/docs', + 'root_path': '/root', + 'unique_id': 'test.foo', + } + doc = self.ContractType( + package_name='test', + root_path='/root', + path='/root/docs', + original_file_path='/root/docs/doc.md', + file_contents='{% doc foo %}some doc contents{% enddoc %}', + name='foo', + unique_id='test.foo', + block_contents='some doc contents' + ) + self.assert_symmetric(doc, doc_dict) + + def test_invalid_missing(self): + bad_missing_contents = { + # 'block_contents': 'some doc contents', + 'file_contents': '{% doc foo %}some doc contents{% enddoc %}', + 'name': 'foo', + 'original_file_path': '/root/docs/doc.md', + 'package_name': 'test', + 'path': '/root/docs', + 'root_path': '/root', + 'unique_id': 'test.foo', + } + self.assert_fails_validation(bad_missing_contents) + + def test_invalid_extra(self): + bad_extra_field = { + 'block_contents': 'some doc contents', + 'file_contents': '{% doc foo %}some doc contents{% enddoc %}', + 'name': 'foo', + 'original_file_path': '/root/docs/doc.md', + 'package_name': 'test', + 'path': '/root/docs', + 'root_path': '/root', + 'unique_id': 'test.foo', + + 'extra': 'more', + } + self.assert_fails_validation(bad_extra_field) + + +class TestParsedSourceDefinition(ContractTestCase): + ContractType = ParsedSourceDefinition + + def test_basic(self): + source_def_dict = { + 'package_name': 'test', + 'root_path': '/root', + 'path': '/root/models/sources.yml', + 'original_file_path': '/root/models/sources.yml', + 'database': 'some_db', + 'schema': 'some_schema', + 'fqn': ['test', 'source', 'my_source', 'my_source_table'], + 'source_name': 'my_source', + 'name': 'my_source_table', + 'source_description': 'my source description', + 'loader': 'stitch', + 'identifier': 'my_source_table', + 'resource_type': str(NodeType.Source), + 'description': '', + 'freshness': {}, + 'docrefs': [], + 'columns': {}, + 'quoting': {}, + 'unique_id': 'test.source.my_source.my_source_table', + } + source_def = self.ContractType( + columns={}, + docrefs=[], + database='some_db', + description='', + fqn=['test', 'source', 'my_source', 'my_source_table'], + freshness=FreshnessThreshold(), + identifier='my_source_table', + loader='stitch', + name='my_source_table', + original_file_path='/root/models/sources.yml', + 
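+ # sources carry no raw_sql; the empty FreshnessThreshold() and Quoting() here serialize to the bare {} values in the dict above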
package_name='test', + path='/root/models/sources.yml', + quoting=Quoting(), + resource_type=NodeType.Source, + root_path='/root', + schema='some_schema', + source_description='my source description', + source_name='my_source', + unique_id='test.source.my_source.my_source_table', + ) + self.assert_symmetric(source_def, source_def_dict) + minimum = { + 'package_name': 'test', + 'root_path': '/root', + 'path': '/root/models/sources.yml', + 'original_file_path': '/root/models/sources.yml', + 'database': 'some_db', + 'schema': 'some_schema', + 'fqn': ['test', 'source', 'my_source', 'my_source_table'], + 'source_name': 'my_source', + 'name': 'my_source_table', + 'source_description': 'my source description', + 'loader': 'stitch', + 'identifier': 'my_source_table', + 'resource_type': str(NodeType.Source), + 'unique_id': 'test.source.my_source.my_source_table', + } + self.assert_from_dict(source_def, minimum) + + def test_invalid_missing(self): + bad_missing_name = { + 'package_name': 'test', + 'root_path': '/root', + 'path': '/root/models/sources.yml', + 'original_file_path': '/root/models/sources.yml', + 'database': 'some_db', + 'schema': 'some_schema', + 'fqn': ['test', 'source', 'my_source', 'my_source_table'], + 'source_name': 'my_source', + # 'name': 'my_source_table', + 'source_description': 'my source description', + 'loader': 'stitch', + 'identifier': 'my_source_table', + 'resource_type': str(NodeType.Source), + 'unique_id': 'test.source.my_source.my_source_table', + } + self.assert_fails_validation(bad_missing_name) + + def test_invalid_bad_resource_type(self): + bad_resource_type = { + 'package_name': 'test', + 'root_path': '/root', + 'path': '/root/models/sources.yml', + 'original_file_path': '/root/models/sources.yml', + 'database': 'some_db', + 'schema': 'some_schema', + 'fqn': ['test', 'source', 'my_source', 'my_source_table'], + 'source_name': 'my_source', + 'name': 'my_source_table', + 'source_description': 'my source description', + 'loader': 'stitch', + 'identifier': 'my_source_table', + 'resource_type': str(NodeType.Model), + 'unique_id': 'test.source.my_source.my_source_table', + } + self.assert_fails_validation(bad_resource_type) diff --git a/test/unit/test_contracts_graph_unparsed.py b/test/unit/test_contracts_graph_unparsed.py new file mode 100644 index 00000000000..41eb6e75582 --- /dev/null +++ b/test/unit/test_contracts_graph_unparsed.py @@ -0,0 +1,447 @@ +from datetime import timedelta + +from dbt.contracts.graph.unparsed import ( + UnparsedNode, UnparsedRunHook, UnparsedMacro, Time, TimePeriod, + FreshnessStatus, FreshnessThreshold, Quoting, UnparsedSourceDefinition, + UnparsedSourceTableDefinition, UnparsedDocumentationFile, NamedTested, + UnparsedNodeUpdate +) +from dbt.node_types import NodeType +from .utils import ContractTestCase + + +class TestUnparsedMacro(ContractTestCase): + ContractType = UnparsedMacro + + def test_ok(self): + macro_dict = { + 'path': '/root/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': '{% macro foo() %}select 1 as id{% endmacro %}', + 'root_path': '/root/', + 'resource_type': 'macro', + } + macro = self.ContractType( + path='/root/path.sql', + original_file_path='/root/path.sql', + package_name='test', + raw_sql='{% macro foo() %}select 1 as id{% endmacro %}', + root_path='/root/', + resource_type=NodeType.Macro, + ) + self.assert_symmetric(macro, macro_dict) + + def test_invalid_missing_field(self): + macro_dict = { + 'path': '/root/path.sql', + 'original_file_path': '/root/path.sql', + # 
'package_name': 'test', + 'raw_sql': '{% macro foo() %}select 1 as id{% endmacro %}', + 'root_path': '/root/', + 'resource_type': 'macro', + } + self.assert_fails_validation(macro_dict) + + def test_invalid_extra_field(self): + macro_dict = { + 'path': '/root/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': '{% macro foo() %}select 1 as id{% endmacro %}', + 'root_path': '/root/', + 'extra': 'extra', + 'resource_type': 'macro', + } + self.assert_fails_validation(macro_dict) + + +class TestUnparsedNode(ContractTestCase): + ContractType = UnparsedNode + + def test_ok(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': NodeType.Model, + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from {{ ref("thing") }}', + } + node = self.ContractType( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql='select * from {{ ref("thing") }}', + name='foo', + resource_type=NodeType.Model, + ) + self.assert_symmetric(node, node_dict) + self.assertFalse(node.empty) + + self.assert_fails_validation(node_dict, cls=UnparsedRunHook) + self.assert_fails_validation(node_dict, cls=UnparsedMacro) + + def test_empty(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': NodeType.Model, + 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': ' \n', + } + node = UnparsedNode( + package_name='test', + root_path='/root/', + path='/root/x/path.sql', + original_file_path='/root/path.sql', + raw_sql=' \n', + name='foo', + resource_type=NodeType.Model, + ) + self.assert_symmetric(node, node_dict) + self.assertTrue(node.empty) + + self.assert_fails_validation(node_dict, cls=UnparsedRunHook) + self.assert_fails_validation(node_dict, cls=UnparsedMacro) + + def test_bad_type(self): + node_dict = { + 'name': 'foo', + 'root_path': '/root/', + 'resource_type': NodeType.Source, # not valid! 
+ 'path': '/root/x/path.sql', + 'original_file_path': '/root/path.sql', + 'package_name': 'test', + 'raw_sql': 'select * from {{ ref("thing") }}', + } + self.assert_fails_validation(node_dict) + + +class TestUnparsedRunHook(ContractTestCase): + ContractType = UnparsedRunHook + + def test_ok(self): + node_dict = { + 'name': 'foo', + 'root_path': 'test/dbt_project.yml', + 'resource_type': NodeType.Operation, + 'path': '/root/dbt_project.yml', + 'original_file_path': '/root/dbt_project.yml', + 'package_name': 'test', + 'raw_sql': 'GRANT select on dbt_postgres', + 'index': 4 + } + node = self.ContractType( + package_name='test', + root_path='test/dbt_project.yml', + path='/root/dbt_project.yml', + original_file_path='/root/dbt_project.yml', + raw_sql='GRANT select on dbt_postgres', + name='foo', + resource_type=NodeType.Operation, + index=4, + ) + self.assert_symmetric(node, node_dict) + self.assert_fails_validation(node_dict, cls=UnparsedNode) + + def test_bad_type(self): + node_dict = { + 'name': 'foo', + 'root_path': 'test/dbt_project.yml', + 'resource_type': NodeType.Model, # invalid + 'path': '/root/dbt_project.yml', + 'original_file_path': '/root/dbt_project.yml', + 'package_name': 'test', + 'raw_sql': 'GRANT select on dbt_postgres', + 'index': 4 + } + self.assert_fails_validation(node_dict) + + +class TestFreshnessThreshold(ContractTestCase): + ContractType = FreshnessThreshold + + def test_empty(self): + empty = self.ContractType(None, None) + self.assert_symmetric(empty, {}) + self.assertEqual(empty.status(float('Inf')), FreshnessStatus.Pass) + self.assertEqual(empty.status(0), FreshnessStatus.Pass) + + def test_both(self): + threshold = self.ContractType( + warn_after=Time(count=18, period=TimePeriod.hour), + error_after=Time(count=2, period=TimePeriod.day), + ) + dct = { + 'error_after': {'count': 2, 'period': 'day'}, + 'warn_after': {'count': 18, 'period': 'hour'} + } + self.assert_symmetric(threshold, dct) + + error_seconds = timedelta(days=3).total_seconds() + warn_seconds = timedelta(days=1).total_seconds() + pass_seconds = timedelta(hours=3).total_seconds() + self.assertEqual(threshold.status(error_seconds), FreshnessStatus.Error) + self.assertEqual(threshold.status(warn_seconds), FreshnessStatus.Warn) + self.assertEqual(threshold.status(pass_seconds), FreshnessStatus.Pass) + + def test_merged(self): + t1 = self.ContractType( + warn_after=Time(count=36, period=TimePeriod.hour), + error_after=Time(count=2, period=TimePeriod.day), + ) + t2 = self.ContractType( + warn_after=Time(count=18, period=TimePeriod.hour), + ) + threshold = self.ContractType( + warn_after=Time(count=18, period=TimePeriod.hour), + error_after=Time(count=2, period=TimePeriod.day), + ) + self.assertEqual(threshold, t1.merged(t2)) + + error_seconds = timedelta(days=3).total_seconds() + warn_seconds = timedelta(days=1).total_seconds() + pass_seconds = timedelta(hours=3).total_seconds() + self.assertEqual(threshold.status(error_seconds), FreshnessStatus.Error) + self.assertEqual(threshold.status(warn_seconds), FreshnessStatus.Warn) + self.assertEqual(threshold.status(pass_seconds), FreshnessStatus.Pass) + + +class TestQuoting(ContractTestCase): + ContractType = Quoting + + def test_empty(self): + empty = self.ContractType() + self.assert_symmetric(empty, {}) + + def test_partial(self): + a = self.ContractType(None, True, False) + b = self.ContractType(True, False, None) + self.assert_symmetric(a, {'schema': True, 'identifier': False}) + self.assert_symmetric(b, {'database': True, 'schema': False}) + + c = 
a.merged(b) + self.assertEqual(c, self.ContractType(True, False, False)) + self.assert_symmetric( + c, {'database': True, 'schema': False, 'identifier': False} + ) + + +class TestUnparsedSourceDefinition(ContractTestCase): + ContractType = UnparsedSourceDefinition + + def test_defaults(self): + minimum = self.ContractType(name='foo') + self.assert_from_dict(minimum, {'name': 'foo'}) + self.assert_to_dict(minimum, {'name': 'foo', 'description': '', 'quoting': {}, 'freshness': {}, 'tables': [], 'loader': ''}) + + def test_contents(self): + empty = self.ContractType( + name='foo', + description='a description', + quoting=Quoting(database=False), + loader='some_loader', + freshness=FreshnessThreshold(), + tables=[], + ) + dct = { + 'name': 'foo', + 'description': 'a description', + 'quoting': {'database': False}, + 'loader': 'some_loader', + 'freshness': {}, + 'tables': [], + } + self.assert_symmetric(empty, dct) + + def test_table_defaults(self): + table_1 = UnparsedSourceTableDefinition(name='table1') + table_2 = UnparsedSourceTableDefinition( + name='table2', + description='table 2', + quoting=Quoting(database=True), + ) + source = self.ContractType( + name='foo', + tables=[table_1, table_2] + ) + from_dict = { + 'name': 'foo', + 'tables': [ + {'name': 'table1'}, + { + 'name': 'table2', + 'description': 'table 2', + 'quoting': {'database': True}, + }, + ], + } + to_dict = { + 'name': 'foo', + 'description': '', + 'loader': '', + 'quoting': {}, + 'freshness': {}, + 'tables': [ + { + 'name': 'table1', + 'description': '', + 'tests': [], + 'columns': [], + 'quoting': {}, + 'freshness': {}, + }, + { + 'name': 'table2', + 'description': 'table 2', + 'tests': [], + 'columns': [], + 'quoting': {'database': True}, + 'freshness': {}, + }, + ], + } + self.assert_from_dict(source, from_dict) + self.assert_symmetric(source, to_dict) + + +class TestUnparsedDocumentationFile(ContractTestCase): + ContractType = UnparsedDocumentationFile + + def test_ok(self): + doc = self.ContractType( + package_name='test', + root_path='/root', + path='/root/docs', + original_file_path='/root/docs/doc.md', + file_contents='blah blah blah', + ) + doc_dict = { + 'package_name': 'test', + 'root_path': '/root', + 'path': '/root/docs', + 'original_file_path': '/root/docs/doc.md', + 'file_contents': 'blah blah blah', + } + self.assert_symmetric(doc, doc_dict) + self.assertEqual(doc.resource_type, NodeType.Documentation) + self.assert_fails_validation(doc_dict, UnparsedNode) + + def test_extra_field(self): + self.assert_fails_validation({}) + doc_dict = { + 'package_name': 'test', + 'root_path': '/root', + 'path': '/root/docs', + 'original_file_path': '/root/docs/doc.md', + 'file_contents': 'blah blah blah', + 'resource_type': 'docs', + } + self.assert_fails_validation(doc_dict) + + +class TestUnparsedNodeUpdate(ContractTestCase): + ContractType = UnparsedNodeUpdate + + def test_defaults(self): + minimum = self.ContractType(name='foo') + from_dict = {'name': 'foo'} + to_dict = {'name': 'foo', 'columns': [], 'description': '', 'tests': []} + self.assert_from_dict(minimum, from_dict) + self.assert_to_dict(minimum, to_dict) + + def test_contents(self): + update = self.ContractType( + name='foo', + description='a description', + tests=['table_test'], + columns=[ + NamedTested(name='x', description='x description'), + NamedTested( + name='y', + description='y description', + tests=[ + 'unique', + {'accepted_values': {'values': ['blue', 'green']}} + ] + ), + ], + ) + dct = { + 'name': 'foo', + 'description': 'a description', + 
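+ # schema.yml tests may be bare strings ('unique') or single-key dicts mapping a test name to its kwargs (accepted_values below)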
'tests': ['table_test'], + 'columns': [ + {'name': 'x', 'description': 'x description', 'tests': []}, + { + 'name': 'y', + 'description': 'y description', + 'tests': [ + 'unique', + {'accepted_values': {'values': ['blue', 'green']}} + ], + }, + ], + } + self.assert_symmetric(update, dct) + + def test_bad_test_type(self): + dct = { + 'name': 'foo', + 'description': 'a description', + 'tests': ['table_test'], + 'columns': [ + {'name': 'x', 'description': 'x description', 'tests': []}, + { + 'name': 'y', + 'description': 'y description', + 'tests': [ + 100, + {'accepted_values': {'values': ['blue', 'green']}} + ], + }, + ], + } + self.assert_fails_validation(dct) + + dct = { + 'name': 'foo', + 'description': 'a description', + 'tests': ['table_test'], + 'columns': [ + # column missing a name + {'description': 'x description', 'tests': []}, + { + 'name': 'y', + 'description': 'y description', + 'tests': [ + 'unique', + {'accepted_values': {'values': ['blue', 'green']}} + ], + }, + ], + } + self.assert_fails_validation(dct) + + # missing a name + dct = { + 'description': 'a description', + 'tests': ['table_test'], + 'columns': [ + {'name': 'x', 'description': 'x description', 'tests': []}, + { + 'name': 'y', + 'description': 'y description', + 'tests': [ + 'unique', + {'accepted_values': {'values': ['blue', 'green']}} + ], + }, + ], + } + self.assert_fails_validation(dct) diff --git a/test/unit/test_deps.py b/test/unit/test_deps.py index d3e4cbd34df..e235feafbf2 100644 --- a/test/unit/test_deps.py +++ b/test/unit/test_deps.py @@ -2,47 +2,103 @@ from unittest import mock import dbt.exceptions -from dbt.task.deps import GitPackage, LocalPackage, RegistryPackage +from dbt.task.deps import GitPackage, LocalPackage, RegistryPackage, \ + LocalPackageContract, GitPackageContract, RegistryPackageContract, \ + resolve_packages +from dbt.contracts.project import PackageConfig from dbt.semver import VersionSpecifier +from hologram import ValidationError + class TestLocalPackage(unittest.TestCase): def test_init(self): - a = LocalPackage(local='/path/to/package') - a.resolve_version() - self.assertEqual(a.source_type(), 'local') + a_contract = LocalPackageContract.from_dict({'local': '/path/to/package'}) + self.assertEqual(a_contract.local, '/path/to/package') + a = LocalPackage.from_contract(a_contract) self.assertEqual(a.local, '/path/to/package') + a_pinned = a.resolved() + self.assertEqual(a_pinned.local, '/path/to/package') + self.assertEqual(str(a_pinned), '/path/to/package') class TestGitPackage(unittest.TestCase): def test_init(self): - a = GitPackage(git='http://example.com', revision='0.0.1') + a_contract = GitPackageContract.from_dict( + {'git': 'http://example.com', 'revision': '0.0.1'} + ) + self.assertEqual(a_contract.git, 'http://example.com') + self.assertEqual(a_contract.revision, '0.0.1') + self.assertIs(a_contract.warn_unpinned, None) + + a = GitPackage.from_contract(a_contract) self.assertEqual(a.git, 'http://example.com') - self.assertEqual(a.revision, '0.0.1') - self.assertEqual(a.version, ['0.0.1']) - self.assertEqual(a.source_type(), 'git') + self.assertEqual(a.revisions, ['0.0.1']) + self.assertIs(a.warn_unpinned, True) + + a_pinned = a.resolved() + self.assertEqual(a_pinned.name, 'http://example.com') + self.assertEqual(a_pinned.get_version(), '0.0.1') + self.assertEqual(a_pinned.source_type(), 'git') + self.assertIs(a_pinned.warn_unpinned, True) def test_invalid(self): - with self.assertRaises(dbt.exceptions.ValidationException): - GitPackage(git='http://example.com', 
version='0.0.1') + with self.assertRaises(ValidationError): + GitPackageContract.from_dict( + {'git': 'http://example.com', 'version': '0.0.1'} + ) def test_resolve_ok(self): - a = GitPackage(git='http://example.com', revision='0.0.1') - b = GitPackage(git='http://example.com', revision='0.0.1') + a_contract = GitPackageContract.from_dict( + {'git': 'http://example.com', 'revision': '0.0.1'} + ) + b_contract = GitPackageContract.from_dict( + {'git': 'http://example.com', 'revision': '0.0.1', + 'warn-unpinned': False} + ) + a = GitPackage.from_contract(a_contract) + b = GitPackage.from_contract(b_contract) + self.assertTrue(a.warn_unpinned) + self.assertFalse(b.warn_unpinned) c = a.incorporate(b) - self.assertEqual(c.git, 'http://example.com') - self.assertEqual(c.version, ['0.0.1', '0.0.1']) - c.resolve_version() - self.assertEqual(c.version, ['0.0.1']) + + c_pinned = c.resolved() + self.assertEqual(c_pinned.name, 'http://example.com') + self.assertEqual(c_pinned.get_version(), '0.0.1') + self.assertEqual(c_pinned.source_type(), 'git') + self.assertFalse(c_pinned.warn_unpinned) def test_resolve_fail(self): - a = GitPackage(git='http://example.com', revision='0.0.1') - b = GitPackage(git='http://example.com', revision='0.0.2') + a_contract = GitPackageContract.from_dict( + {'git': 'http://example.com', 'revision': '0.0.1'} + ) + b_contract = GitPackageContract.from_dict( + {'git': 'http://example.com', 'revision': '0.0.2'} + ) + a = GitPackage.from_contract(a_contract) + b = GitPackage.from_contract(b_contract) c = a.incorporate(b) self.assertEqual(c.git, 'http://example.com') - self.assertEqual(c.version, ['0.0.1', '0.0.2']) + self.assertEqual(c.revisions, ['0.0.1', '0.0.2']) + with self.assertRaises(dbt.exceptions.DependencyException): - c.resolve_version() + c.resolved() + + def test_default_revision(self): + a_contract = GitPackageContract.from_dict({'git': 'http://example.com'}) + self.assertEqual(a_contract.revision, None) + self.assertIs(a_contract.warn_unpinned, None) + + a = GitPackage.from_contract(a_contract) + self.assertEqual(a.git, 'http://example.com') + self.assertEqual(a.revisions, []) + self.assertIs(a.warn_unpinned, True) + + a_pinned = a.resolved() + self.assertEqual(a_pinned.name, 'http://example.com') + self.assertEqual(a_pinned.get_version(), 'master') + self.assertEqual(a_pinned.source_type(), 'git') + self.assertIs(a_pinned.warn_unpinned, True) class TestHubPackage(unittest.TestCase): @@ -53,31 +109,6 @@ def setUp(self): self.get_available_versions = self.registry.get_available_versions self.package_version = self.registry.package_version - def tearDown(self): - self.patcher.stop() - - def test_init(self): - a = RegistryPackage(package='fishtown-analytics-test/a', - version='0.1.2') - self.assertEqual(a.package, 'fishtown-analytics-test/a') - self.assertEqual( - a.version, - [VersionSpecifier( - build=None, - major='0', - matcher='=', - minor='1', - patch='2', - prerelease=None - )] - ) - self.assertEqual(a.source_type(), 'hub') - - def test_invalid(self): - with self.assertRaises(dbt.exceptions.DependencyException): - RegistryPackage(package='namespace/name', key='invalid') - - def test_resolve_ok(self): self.index_cached.return_value = [ 'fishtown-analytics-test/a', ] @@ -99,105 +130,263 @@ def test_resolve_ok(self): 'newfield': ['another', 'value'], } - a = RegistryPackage( + def tearDown(self): + self.patcher.stop() + + def test_init(self): + a_contract = RegistryPackageContract( + package='fishtown-analytics-test/a', + version='0.1.2', + ) + 
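+ # the contract mirrors the raw packages.yml entry; RegistryPackage.from_contract parses the version string into VersionSpecifier objects, as checked below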
self.assertEqual(a_contract.package, 'fishtown-analytics-test/a') + self.assertEqual(a_contract.version, '0.1.2') + + a = RegistryPackage.from_contract(a_contract) + self.assertEqual(a.package, 'fishtown-analytics-test/a') + self.assertEqual( + a.versions, + [VersionSpecifier( + build=None, + major='0', + matcher='=', + minor='1', + patch='2', + prerelease=None + )] + ) + + a_pinned = a.resolved() + self.assertEqual(a_contract.package, 'fishtown-analytics-test/a') + self.assertEqual(a_contract.version, '0.1.2') + self.assertEqual(a_pinned.source_type(), 'hub') + + def test_invalid(self): + with self.assertRaises(ValidationError): + RegistryPackageContract.from_dict( + {'package': 'namespace/name', 'key': 'invalid'} + ) + + def test_resolve_ok(self): + a_contract = RegistryPackageContract( package='fishtown-analytics-test/a', version='0.1.2' ) - b = RegistryPackage( + b_contract = RegistryPackageContract( package='fishtown-analytics-test/a', version='0.1.2' ) + a = RegistryPackage.from_contract(a_contract) + b = RegistryPackage.from_contract(b_contract) c = a.incorporate(b) + + self.assertEqual(c.package, 'fishtown-analytics-test/a') self.assertEqual( - c.version, + c.versions, [ - VersionSpecifier({ - 'build': None, - 'major': '0', - 'matcher': '=', - 'minor': '1', - 'patch': '2', - 'prerelease': None, - }), - VersionSpecifier({ - 'build': None, - 'major': '0', - 'matcher': '=', - 'minor': '1', - 'patch': '2', - 'prerelease': None, - }) + VersionSpecifier( + build=None, + major='0', + matcher='=', + minor='1', + patch='2', + prerelease=None, + ), + VersionSpecifier( + build=None, + major='0', + matcher='=', + minor='1', + patch='2', + prerelease=None, + ), ] ) - c.resolve_version() - self.assertEqual(c.package, 'fishtown-analytics-test/a') - self.assertEqual( - c.version, - [VersionSpecifier({ - 'build': None, - 'major': '0', - 'matcher': '=', - 'minor': '1', - 'patch': '2', - 'prerelease': None, - })] - ) - self.assertEqual(c.source_type(), 'hub') + + c_pinned = c.resolved() + self.assertEqual(c_pinned.package, 'fishtown-analytics-test/a') + self.assertEqual(c_pinned.version, '0.1.2') + self.assertEqual(c_pinned.source_type(), 'hub') def test_resolve_missing_package(self): - self.index_cached.return_value = [ - 'fishtown-analytics-test/b', - ] - a = RegistryPackage( - package='fishtown-analytics-test/a', + a = RegistryPackage.from_contract(RegistryPackageContract( + package='fishtown-analytics-test/b', version='0.1.2' - ) + )) with self.assertRaises(dbt.exceptions.DependencyException) as exc: - a.resolve_version() + a.resolved() - msg = 'Package fishtown-analytics-test/a was not found in the package index' + msg = 'Package fishtown-analytics-test/b was not found in the package index' self.assertEqual(msg, str(exc.exception)) def test_resolve_missing_version(self): - self.index_cached.return_value = [ - 'fishtown-analytics-test/a', - ] - self.get_available_versions.return_value = [ - '0.1.3', '0.1.4' - ] - a = RegistryPackage( + a = RegistryPackage.from_contract(RegistryPackageContract( package='fishtown-analytics-test/a', - version='0.1.2' - ) + version='0.1.4' + )) + with self.assertRaises(dbt.exceptions.DependencyException) as exc: - a.resolve_version() + a.resolved() msg = ( "Could not find a matching version for package " - "fishtown-analytics-test/a\n Requested range: =0.1.2, =0.1.2\n " - "Available versions: ['0.1.3', '0.1.4']" + "fishtown-analytics-test/a\n Requested range: =0.1.4, =0.1.4\n " + "Available versions: ['0.1.2', '0.1.3']" ) self.assertEqual(msg, str(exc.exception)) def 
test_resolve_conflict(self): - self.index_cached.return_value = [ - 'fishtown-analytics-test/a', - ] - self.get_available_versions.return_value = [ - '0.1.2', '0.1.3' - ] - a = RegistryPackage( + a_contract = RegistryPackageContract( package='fishtown-analytics-test/a', version='0.1.2' ) - b = RegistryPackage( + b_contract = RegistryPackageContract( package='fishtown-analytics-test/a', version='0.1.3' ) + a = RegistryPackage.from_contract(a_contract) + b = RegistryPackage.from_contract(b_contract) c = a.incorporate(b) + with self.assertRaises(dbt.exceptions.DependencyException) as exc: - c.resolve_version() + c.resolved() msg = ( "Version error for package fishtown-analytics-test/a: Could not " "find a satisfactory version from options: ['=0.1.2', '=0.1.3']" ) self.assertEqual(msg, str(exc.exception)) + + def test_resolve_ranges(self): + a_contract = RegistryPackageContract( + package='fishtown-analytics-test/a', + version='0.1.2' + ) + b_contract = RegistryPackageContract( + package='fishtown-analytics-test/a', + version='<0.1.4' + ) + a = RegistryPackage.from_contract(a_contract) + b = RegistryPackage.from_contract(b_contract) + c = a.incorporate(b) + + self.assertEqual(c.package, 'fishtown-analytics-test/a') + self.assertEqual( + c.versions, + [ + VersionSpecifier( + build=None, + major='0', + matcher='=', + minor='1', + patch='2', + prerelease=None, + ), + VersionSpecifier( + build=None, + major='0', + matcher='<', + minor='1', + patch='4', + prerelease=None, + ), + ] + ) + + c_pinned = c.resolved() + self.assertEqual(c_pinned.package, 'fishtown-analytics-test/a') + self.assertEqual(c_pinned.version, '0.1.2') + self.assertEqual(c_pinned.source_type(), 'hub') + + +class MockRegistry: + def __init__(self, packages): + self.packages = packages + + def index_cached(self, registry_base_url=None): + return sorted(self.packages) + + def get_available_versions(self, name): + try: + pkg = self.packages[name] + except KeyError: + return [] + return list(pkg) + + def package_version(self, name, version): + try: + return self.packages[name][version] + except KeyError: + return None + + +class TestPackageSpec(unittest.TestCase): + def setUp(self): + self.patcher = mock.patch('dbt.task.deps.registry') + self.registry = self.patcher.start() + self.mock_registry = MockRegistry(packages={ + 'fishtown-analytics-test/a': { + '0.1.2': { + 'id': 'fishtown-analytics-test/a/0.1.2', + 'name': 'a', + 'version': '0.1.2', + 'packages': [], + '_source': { + 'blahblah': 'asdfas', + }, + 'downloads': { + 'tarball': 'https://example.com/invalid-url!', + 'extra': 'field', + }, + 'newfield': ['another', 'value'], + }, + '0.1.3': { + 'id': 'fishtown-analytics-test/a/0.1.3', + 'name': 'a', + 'version': '0.1.3', + 'packages': [], + '_source': { + 'blahblah': 'asdfas', + }, + 'downloads': { + 'tarball': 'https://example.com/invalid-url!', + 'extra': 'field', + }, + 'newfield': ['another', 'value'], + } + }, + 'fishtown-analytics-test/b': { + '0.2.1': { + 'id': 'fishtown-analytics-test/b/0.2.1', + 'name': 'b', + 'version': '0.2.1', + 'packages': [{'package': 'fishtown-analytics-test/a', 'version': '>=0.1.3'}], + '_source': { + 'blahblah': 'asdfas', + }, + 'downloads': { + 'tarball': 'https://example.com/invalid-url!', + 'extra': 'field', + }, + 'newfield': ['another', 'value'], + }, + } + }) + + self.registry.index_cached.side_effect = self.mock_registry.index_cached + self.registry.get_available_versions.side_effect = self.mock_registry.get_available_versions + self.registry.package_version.side_effect = 
self.mock_registry.package_version + + def tearDown(self): + self.patcher.stop() + + def test_dependency_resolution(self): + package_config = PackageConfig.from_dict({ + 'packages': [ + {'package': 'fishtown-analytics-test/a', 'version': '>0.1.2'}, + {'package': 'fishtown-analytics-test/b', 'version': '0.2.1'}, + ], + }) + resolved = resolve_packages(package_config.packages, None) + self.assertEqual(len(resolved), 2) + self.assertEqual(resolved[0].name, 'fishtown-analytics-test/a') + self.assertEqual(resolved[0].version, '0.1.3') + self.assertEqual(resolved[1].name, 'fishtown-analytics-test/b') + self.assertEqual(resolved[1].version, '0.2.1') diff --git a/test/unit/test_docs_blocks.py b/test/unit/test_docs_blocks.py index a28eb1e5fcf..213d745d958 100644 --- a/test/unit/test_docs_blocks.py +++ b/test/unit/test_docs_blocks.py @@ -119,7 +119,6 @@ def test_load_file(self, system): def test_parse(self): docfile = UnparsedDocumentationFile( root_path=self.root_path, - resource_type=NodeType.Documentation, path='test_file.md', original_file_path=self.testfile_path, package_name='some_package', diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index c7e0006206d..2436756739e 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -63,7 +63,9 @@ def mock_write_gpickle(graph, outfile): } self.mock_load_projects = self.load_projects_patcher.start() - self.mock_load_projects.return_value = [] + def _load_projects(config, paths): + yield config.project_name, config + self.mock_load_projects.side_effect = _load_projects self.mock_models = [] self.mock_content = {} @@ -187,12 +189,9 @@ def test__model_materializations(self): "model_four": "table" } - nodes = linker.graph.node - for model, expected in expected_materialization.items(): key = 'model.test_models_compile.{}'.format(model) - actual = manifest.nodes[key].get('config', {}) \ - .get('materialized') + actual = manifest.nodes[key].config.materialized self.assertEqual(actual, expected) def test__model_incremental(self): @@ -221,10 +220,7 @@ def test__model_incremental(self): self.assertEqual(list(linker.nodes()), [node]) self.assertEqual(list(linker.edges()), []) - self.assertEqual( - manifest.nodes[node].get('config', {}).get('materialized'), - 'incremental' - ) + self.assertEqual(manifest.nodes[node].config.materialized, 'incremental') def test__dependency_list(self): self.use_models({ diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py index f8cddfb5955..ed47a5f990f 100644 --- a/test/unit/test_manifest.py +++ b/test/unit/test_manifest.py @@ -2,23 +2,38 @@ from unittest import mock import copy +from datetime import datetime import dbt.flags from dbt import tracking -from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.graph.parsed import ParsedNode +from dbt.contracts.graph.manifest import Manifest, ManifestMetadata +from dbt.contracts.graph.parsed import ParsedNode, DependsOn, NodeConfig from dbt.contracts.graph.compiled import CompiledNode -from dbt.utils import timestring +from dbt.node_types import NodeType import freezegun +REQUIRED_PARSED_NODE_KEYS = frozenset({ + 'alias', 'tags', 'config', 'unique_id', 'refs', 'sources', + 'depends_on', 'database', 'schema', 'name', 'resource_type', + 'package_name', 'root_path', 'path', 'original_file_path', 'raw_sql', + 'docrefs', 'description', 'columns', 'fqn', 'build_path', 'patch_path', + 'index', +}) + +REQUIRED_COMPILED_NODE_KEYS = frozenset(REQUIRED_PARSED_NODE_KEYS | { + 'compiled', 'extra_ctes_injected', 'extra_ctes', 'compiled_sql', + 
'injected_sql', 'wrapped_sql' +}) + + class ManifestTest(unittest.TestCase): def setUp(self): dbt.flags.STRICT_MODE = True self.maxDiff = None - self.model_config = { + self.model_config = NodeConfig.from_dict({ 'enabled': True, 'materialized': 'view', 'persist_docs': {}, @@ -28,7 +43,7 @@ def setUp(self): 'quoting': {}, 'column_types': {}, 'tags': [], - } + }) self.nested_nodes = { 'model.snowplow.events': ParsedNode( @@ -36,17 +51,13 @@ def setUp(self): database='dbt', schema='analytics', alias='events', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.snowplow.events', fqn=['snowplow', 'events'], - empty=False, package_name='snowplow', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='events.sql', @@ -59,17 +70,13 @@ def setUp(self): database='dbt', schema='analytics', alias='events', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.events', fqn=['root', 'events'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='events.sql', @@ -82,17 +89,13 @@ def setUp(self): database='dbt', schema='analytics', alias='dep', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.dep', fqn=['root', 'dep'], - empty=False, package_name='root', refs=[['events']], sources=[], - depends_on={ - 'nodes': ['model.root.events'], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.events']), config=self.model_config, tags=[], path='multi.sql', @@ -105,17 +108,13 @@ def setUp(self): database='dbt', schema='analytics', alias='nested', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.nested', fqn=['root', 'nested'], - empty=False, package_name='root', refs=[['events']], sources=[], - depends_on={ - 'nodes': ['model.root.dep'], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.dep']), config=self.model_config, tags=[], path='multi.sql', @@ -128,17 +127,13 @@ def setUp(self): database='dbt', schema='analytics', alias='sibling', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.sibling', fqn=['root', 'sibling'], - empty=False, package_name='root', refs=[['events']], sources=[], - depends_on={ - 'nodes': ['model.root.events'], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.events']), config=self.model_config, tags=[], path='multi.sql', @@ -151,17 +146,13 @@ def setUp(self): database='dbt', schema='analytics', alias='multi', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.multi', fqn=['root', 'multi'], - empty=False, package_name='root', refs=[['events']], sources=[], - depends_on={ - 'nodes': ['model.root.nested', 'model.root.sibling'], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.nested', 'model.root.sibling']), config=self.model_config, tags=[], path='multi.sql', @@ -170,13 +161,15 @@ def setUp(self): raw_sql='does not matter' ), } + for node in self.nested_nodes.values(): + node.validate(node.to_dict()) @freezegun.freeze_time('2018-02-14T09:15:13Z') def test__no_nodes(self): manifest = Manifest(nodes={}, macros={}, docs={}, - generated_at=timestring(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[]) self.assertEqual( - manifest.serialize(), + manifest.writable_manifest().to_dict(), { 'nodes': {}, 'macros': {}, @@ -184,11 +177,7 @@ def test__no_nodes(self): 'child_map': {}, 
'generated_at': '2018-02-14T09:15:13Z', 'docs': {}, - 'metadata': { - 'project_id': None, - 'user_id': None, - 'send_anonymous_usage_stats': None, - }, + 'metadata': {}, 'disabled': [], } ) @@ -197,8 +186,8 @@ def test__no_nodes(self): def test__nested_nodes(self): nodes = copy.copy(self.nested_nodes) manifest = Manifest(nodes=nodes, macros={}, docs={}, - generated_at=timestring(), disabled=[]) - serialized = manifest.serialize() + generated_at=datetime.utcnow(), disabled=[]) + serialized = manifest.writable_manifest().to_dict() self.assertEqual(serialized['generated_at'], '2018-02-14T09:15:13Z') self.assertEqual(serialized['docs'], {}) self.assertEqual(serialized['disabled'], []) @@ -261,15 +250,14 @@ def test__nested_nodes(self): def test__to_flat_graph(self): nodes = copy.copy(self.nested_nodes) manifest = Manifest(nodes=nodes, macros={}, docs={}, - generated_at=timestring(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[]) flat_graph = manifest.to_flat_graph() flat_nodes = flat_graph['nodes'] self.assertEqual(set(flat_graph), set(['nodes', 'macros'])) self.assertEqual(flat_graph['macros'], {}) self.assertEqual(set(flat_nodes), set(self.nested_nodes)) - expected_keys = set(ParsedNode.SCHEMA['required']) for node in flat_nodes.values(): - self.assertEqual(set(node), expected_keys) + self.assertEqual(frozenset(node), REQUIRED_PARSED_NODE_KEYS) @mock.patch.object(tracking, 'active_user') def test_get_metadata(self, mock_user): @@ -280,11 +268,11 @@ def test_get_metadata(self, mock_user): config.hashed_name.return_value = '098f6bcd4621d373cade4e832627b4f6' self.assertEqual( Manifest.get_metadata(config), - { - 'project_id': '098f6bcd4621d373cade4e832627b4f6', - 'user_id': 'cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf', - 'send_anonymous_usage_stats': False, - } + ManifestMetadata( + project_id='098f6bcd4621d373cade4e832627b4f6', + user_id='cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf', + send_anonymous_usage_stats=False, + ) ) @mock.patch.object(tracking, 'active_user') @@ -296,7 +284,7 @@ def test_no_nodes_with_metadata(self, mock_user): # md5 of 'test' config.hashed_name.return_value = '098f6bcd4621d373cade4e832627b4f6' manifest = Manifest(nodes={}, macros={}, docs={}, - generated_at=timestring(), disabled=[], + generated_at=datetime.utcnow(), disabled=[], config=config) metadata = { 'project_id': '098f6bcd4621d373cade4e832627b4f6', @@ -304,7 +292,7 @@ def test_no_nodes_with_metadata(self, mock_user): 'send_anonymous_usage_stats': False, } self.assertEqual( - manifest.serialize(), + manifest.writable_manifest().to_dict(), { 'nodes': {}, 'macros': {}, @@ -323,7 +311,7 @@ def test_no_nodes_with_metadata(self, mock_user): def test_get_resource_fqns_empty(self): manifest = Manifest(nodes={}, macros={}, docs={}, - generated_at=timestring(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[]) self.assertEqual(manifest.get_resource_fqns(), {}) def test_get_resource_fqns(self): @@ -336,14 +324,10 @@ def test_get_resource_fqns(self): resource_type='seed', unique_id='seed.root.seed', fqn=['root', 'seed'], - empty=False, package_name='root', refs=[['events']], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='seed.csv', @@ -352,7 +336,7 @@ def test_get_resource_fqns(self): raw_sql='-- csv --' ) manifest = Manifest(nodes=nodes, macros={}, docs={}, - generated_at=timestring(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[]) expect = { 'models': frozenset([ ('snowplow', 'events'), @@ -374,7 +358,7 @@ def 
setUp(self): self.maxDiff = None - self.model_config = { + self.model_config = NodeConfig.from_dict({ 'enabled': True, 'materialized': 'view', 'persist_docs': {}, @@ -384,7 +368,7 @@ def setUp(self): 'quoting': {}, 'column_types': {}, 'tags': [], - } + }) self.nested_nodes = { 'model.snowplow.events': CompiledNode( @@ -392,17 +376,13 @@ def setUp(self): database='dbt', schema='analytics', alias='events', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.snowplow.events', fqn=['snowplow', 'events'], - empty=False, package_name='snowplow', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='events.sql', @@ -420,17 +400,13 @@ def setUp(self): database='dbt', schema='analytics', alias='events', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.events', fqn=['root', 'events'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='events.sql', @@ -448,17 +424,13 @@ def setUp(self): database='dbt', schema='analytics', alias='dep', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.dep', fqn=['root', 'dep'], - empty=False, package_name='root', refs=[['events']], sources=[], - depends_on={ - 'nodes': ['model.root.events'], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.events']), config=self.model_config, tags=[], path='multi.sql', @@ -471,17 +443,13 @@ def setUp(self): database='dbt', schema='analytics', alias='nested', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.nested', fqn=['root', 'nested'], - empty=False, package_name='root', refs=[['events']], sources=[], - depends_on={ - 'nodes': ['model.root.dep'], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.dep']), config=self.model_config, tags=[], path='multi.sql', @@ -494,17 +462,13 @@ def setUp(self): database='dbt', schema='analytics', alias='sibling', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.sibling', fqn=['root', 'sibling'], - empty=False, package_name='root', refs=[['events']], sources=[], - depends_on={ - 'nodes': ['model.root.events'], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.events']), config=self.model_config, tags=[], path='multi.sql', @@ -517,17 +481,13 @@ def setUp(self): database='dbt', schema='analytics', alias='multi', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.multi', fqn=['root', 'multi'], - empty=False, package_name='root', refs=[['events']], sources=[], - depends_on={ - 'nodes': ['model.root.nested', 'model.root.sibling'], - 'macros': [] - }, + depends_on=DependsOn(nodes=['model.root.nested', 'model.root.sibling']), config=self.model_config, tags=[], path='multi.sql', @@ -540,9 +500,9 @@ def setUp(self): @freezegun.freeze_time('2018-02-14T09:15:13Z') def test__no_nodes(self): manifest = Manifest(nodes={}, macros={}, docs={}, - generated_at=timestring(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[]) self.assertEqual( - manifest.serialize(), + manifest.writable_manifest().to_dict(), { 'nodes': {}, 'macros': {}, @@ -550,11 +510,7 @@ def test__no_nodes(self): 'child_map': {}, 'generated_at': '2018-02-14T09:15:13Z', 'docs': {}, - 'metadata': { - 'project_id': None, - 'user_id': None, - 'send_anonymous_usage_stats': None, - }, + 'metadata': {}, 'disabled': [], } ) @@ -563,8 +519,8 @@ 
def test__no_nodes(self): def test__nested_nodes(self): nodes = copy.copy(self.nested_nodes) manifest = Manifest(nodes=nodes, macros={}, docs={}, - generated_at=timestring(), disabled=[]) - serialized = manifest.serialize() + generated_at=datetime.utcnow(), disabled=[]) + serialized = manifest.writable_manifest().to_dict() self.assertEqual(serialized['generated_at'], '2018-02-14T09:15:13Z') self.assertEqual(serialized['disabled'], []) parent_map = serialized['parent_map'] @@ -626,19 +582,17 @@ def test__nested_nodes(self): def test__to_flat_graph(self): nodes = copy.copy(self.nested_nodes) manifest = Manifest(nodes=nodes, macros={}, docs={}, - generated_at=timestring(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[]) flat_graph = manifest.to_flat_graph() flat_nodes = flat_graph['nodes'] self.assertEqual(set(flat_graph), set(['nodes', 'macros'])) self.assertEqual(flat_graph['macros'], {}) self.assertEqual(set(flat_nodes), set(self.nested_nodes)) - parsed_keys = set(ParsedNode.SCHEMA['required']) - compiled_keys = set(CompiledNode.SCHEMA['required']) compiled_count = 0 for node in flat_nodes.values(): if node.get('compiled'): - self.assertEqual(set(node), compiled_keys) + self.assertEqual(frozenset(node), REQUIRED_COMPILED_NODE_KEYS) compiled_count += 1 else: - self.assertEqual(set(node), parsed_keys) + self.assertEqual(frozenset(node), REQUIRED_PARSED_NODE_KEYS) self.assertEqual(compiled_count, 2) diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index 00af5d76eff..8640a4a4bd1 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -1,19 +1,23 @@ import unittest from unittest import mock +from datetime import datetime import os import yaml import dbt.flags import dbt.parser -from dbt.parser import ModelParser, MacroParser, DataTestParser, SchemaParser, ParserUtils +from dbt.parser import ModelParser, MacroParser, DataTestParser, \ + SchemaParser, ParserUtils from dbt.parser.source_config import SourceConfig -from dbt.utils import timestring, deep_merge from dbt.node_types import NodeType from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedNode, ParsedMacro, \ - ParsedNodePatch, ParsedSourceDefinition + ParsedNodePatch, ParsedSourceDefinition, NodeConfig, DependsOn, \ + ColumnInfo, ParsedTestNode, TestConfig +from dbt.contracts.graph.unparsed import FreshnessThreshold, Quoting, Time, \ + TimePeriod from .utils import config_from_parts_or_dicts @@ -80,6 +84,7 @@ def setUp(self): def tearDown(self): self.patcher.stop() + class SourceConfigTest(BaseParserTest): def test__source_config_single_call(self): cfg = SourceConfig(self.root_project_config, self.root_project_config, @@ -162,11 +167,11 @@ def setUp(self): super().setUp() self.maxDiff = None - self.macro_manifest = Manifest(macros={}, nodes={}, docs={}, - generated_at=timestring(), disabled=[]) + generated_at=datetime.utcnow(), + disabled=[]) - self.model_config = { + self.model_config = NodeConfig.from_dict({ 'enabled': True, 'materialized': 'view', 'persist_docs': {}, @@ -176,10 +181,21 @@ def setUp(self): 'quoting': {}, 'column_types': {}, 'tags': [], - } + }) - self.test_config = deep_merge(self.model_config, {'severity': 'ERROR'}) - self.warn_test_config = deep_merge(self.model_config, {'severity': 'WARN'}) + self.test_config = TestConfig.from_dict({ + 'enabled': True, + 'materialized': 'view', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [], + 'vars': {}, + 'quoting': {}, + 'column_types': {}, + 'tags': [], + 'severity': 'ERROR', + }) + 
self.warn_test_config = self.test_config.replace(severity='WARN') self.disabled_config = { 'enabled': False, @@ -204,51 +220,38 @@ def setUp(self): path='test_one.yml', original_file_path='test_one.yml', columns={ - 'id': { - 'name': 'id', - 'description': 'user ID', - }, + 'id': ColumnInfo(name='id', description='user ID'), }, docrefs=[], - freshness={ - 'warn_after': { - 'count': 7, - 'period': 'hour' - }, - 'error_after': { - 'count': 20, - 'period': 'hour' - }, - }, + freshness=FreshnessThreshold( + warn_after=Time(count=7, period=TimePeriod.hour), + error_after=Time(count=20, period=TimePeriod.hour) + ), loaded_at_field='something', database='test', schema='foo', identifier='bar', - resource_type='source', - quoting={ - 'schema': True, - 'identifier': False, - }, + resource_type=NodeType.Source, + quoting=Quoting(schema=True, identifier=False), fqn=['root', 'my_source', 'my_table'] ) self._expected_source_tests = [ - ParsedNode( + ParsedTestNode( alias='source_accepted_values_my_source_my_table_id__a__b', name='source_accepted_values_my_source_my_table_id__a__b', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.root.source_accepted_values_my_source_my_table_id__a__b', fqn=['root', 'schema_test', 'source_accepted_values_my_source_my_table_id__a__b'], - empty=False, package_name='root', original_file_path='test_one.yml', root_path=get_os_path('/usr/src/app'), refs=[], sources=[['my_source', 'my_table']], - depends_on={'nodes': [], 'macros': []}, + depends_on=DependsOn(), config=self.test_config, path=get_os_path( 'schema_test/source_accepted_values_my_source_my_table_id__a__b.sql'), @@ -256,22 +259,21 @@ def setUp(self): raw_sql="{{ config(severity='ERROR') }}{{ test_accepted_values(model=source('my_source', 'my_table'), column_name='id', values=['a', 'b']) }}", description='', columns={}, - column_name='id' + column_name='id', ), - ParsedNode( + ParsedTestNode( alias='source_not_null_my_source_my_table_id', name='source_not_null_my_source_my_table_id', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.root.source_not_null_my_source_my_table_id', fqn=['root', 'schema_test', 'source_not_null_my_source_my_table_id'], - empty=False, package_name='root', root_path=get_os_path('/usr/src/app'), refs=[], sources=[['my_source', 'my_table']], - depends_on={'nodes': [], 'macros': []}, + depends_on=DependsOn(), config=self.test_config, original_file_path='test_one.yml', path=get_os_path('schema_test/source_not_null_my_source_my_table_id.sql'), @@ -279,68 +281,65 @@ def setUp(self): raw_sql="{{ config(severity='ERROR') }}{{ test_not_null(model=source('my_source', 'my_table'), column_name='id') }}", description='', columns={}, - column_name='id' + column_name='id', ), - ParsedNode( + ParsedTestNode( alias='source_relationships_my_source_my_table_id__id__ref_model_two_', name='source_relationships_my_source_my_table_id__id__ref_model_two_', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.root.source_relationships_my_source_my_table_id__id__ref_model_two_', # noqa fqn=['root', 'schema_test', 'source_relationships_my_source_my_table_id__id__ref_model_two_'], - empty=False, package_name='root', original_file_path='test_one.yml', root_path=get_os_path('/usr/src/app'), refs=[['model_two']], sources=[['my_source', 'my_table']], - depends_on={'nodes': [], 'macros': []}, + depends_on=DependsOn(), config=self.test_config, 
path=get_os_path('schema_test/source_relationships_my_source_my_table_id__id__ref_model_two_.sql'), # noqa tags=['schema'], raw_sql="{{ config(severity='ERROR') }}{{ test_relationships(model=source('my_source', 'my_table'), column_name='id', from='id', to=ref('model_two')) }}", description='', columns={}, - column_name='id' + column_name='id', ), - ParsedNode( + ParsedTestNode( alias='source_some_test_my_source_my_table_value', name='source_some_test_my_source_my_table_value', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.snowplow.source_some_test_my_source_my_table_value', fqn=['snowplow', 'schema_test', 'source_some_test_my_source_my_table_value'], - empty=False, package_name='snowplow', original_file_path='test_one.yml', root_path=get_os_path('/usr/src/app'), refs=[], sources=[['my_source', 'my_table']], - depends_on={'nodes': [], 'macros': []}, + depends_on=DependsOn(), config=self.warn_test_config, path=get_os_path('schema_test/source_some_test_my_source_my_table_value.sql'), tags=['schema'], raw_sql="{{ config(severity='WARN') }}{{ snowplow.test_some_test(model=source('my_source', 'my_table'), key='value') }}", description='', - columns={} + columns={}, ), - ParsedNode( + ParsedTestNode( alias='source_unique_my_source_my_table_id', name='source_unique_my_source_my_table_id', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.root.source_unique_my_source_my_table_id', fqn=['root', 'schema_test', 'source_unique_my_source_my_table_id'], - empty=False, package_name='root', root_path=get_os_path('/usr/src/app'), refs=[], sources=[['my_source', 'my_table']], - depends_on={'nodes': [], 'macros': []}, + depends_on=DependsOn(), config=self.warn_test_config, original_file_path='test_one.yml', path=get_os_path('schema_test/source_unique_my_source_my_table_id.sql'), @@ -348,27 +347,26 @@ def setUp(self): raw_sql="{{ config(severity='WARN') }}{{ test_unique(model=source('my_source', 'my_table'), column_name='id') }}", description='', columns={}, - column_name='id' + column_name='id', ), ] self._expected_model_tests = [ - ParsedNode( + ParsedTestNode( alias='accepted_values_model_one_id__a__b', name='accepted_values_model_one_id__a__b', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.root.accepted_values_model_one_id__a__b', fqn=['root', 'schema_test', 'accepted_values_model_one_id__a__b'], - empty=False, package_name='root', original_file_path='test_one.yml', root_path=get_os_path('/usr/src/app'), refs=[['model_one']], sources=[], - depends_on={'nodes': [], 'macros': []}, + depends_on=DependsOn(), config=self.test_config, path=get_os_path( 'schema_test/accepted_values_model_one_id__a__b.sql'), @@ -376,22 +374,21 @@ def setUp(self): raw_sql="{{ config(severity='ERROR') }}{{ test_accepted_values(model=ref('model_one'), column_name='id', values=['a', 'b']) }}", description='', columns={}, - column_name='id' + column_name='id', ), - ParsedNode( + ParsedTestNode( alias='not_null_model_one_id', name='not_null_model_one_id', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.root.not_null_model_one_id', fqn=['root', 'schema_test', 'not_null_model_one_id'], - empty=False, package_name='root', root_path=get_os_path('/usr/src/app'), refs=[['model_one']], sources=[], - depends_on={'nodes': [], 'macros': []}, + depends_on=DependsOn(), config=self.test_config, 
original_file_path='test_one.yml', path=get_os_path('schema_test/not_null_model_one_id.sql'), @@ -399,68 +396,65 @@ def setUp(self): raw_sql="{{ config(severity='ERROR') }}{{ test_not_null(model=ref('model_one'), column_name='id') }}", description='', columns={}, - column_name='id' + column_name='id', ), - ParsedNode( + ParsedTestNode( alias='relationships_model_one_id__id__ref_model_two_', name='relationships_model_one_id__id__ref_model_two_', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.root.relationships_model_one_id__id__ref_model_two_', # noqa fqn=['root', 'schema_test', 'relationships_model_one_id__id__ref_model_two_'], - empty=False, package_name='root', original_file_path='test_one.yml', root_path=get_os_path('/usr/src/app'), refs=[['model_one'], ['model_two']], sources=[], - depends_on={'nodes': [], 'macros': []}, + depends_on=DependsOn(), config=self.test_config, path=get_os_path('schema_test/relationships_model_one_id__id__ref_model_two_.sql'), # noqa tags=['schema'], raw_sql="{{ config(severity='ERROR') }}{{ test_relationships(model=ref('model_one'), column_name='id', from='id', to=ref('model_two')) }}", description='', columns={}, - column_name='id' + column_name='id', ), - ParsedNode( + ParsedTestNode( alias='some_test_model_one_value', name='some_test_model_one_value', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.snowplow.some_test_model_one_value', fqn=['snowplow', 'schema_test', 'some_test_model_one_value'], - empty=False, package_name='snowplow', original_file_path='test_one.yml', root_path=get_os_path('/usr/src/app'), refs=[['model_one']], sources=[], - depends_on={'nodes': [], 'macros': []}, + depends_on=DependsOn(), config=self.warn_test_config, path=get_os_path('schema_test/some_test_model_one_value.sql'), tags=['schema'], raw_sql="{{ config(severity='WARN') }}{{ snowplow.test_some_test(model=ref('model_one'), key='value') }}", description='', - columns={} + columns={}, ), - ParsedNode( + ParsedTestNode( alias='unique_model_one_id', name='unique_model_one_id', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.root.unique_model_one_id', fqn=['root', 'schema_test', 'unique_model_one_id'], - empty=False, package_name='root', root_path=get_os_path('/usr/src/app'), refs=[['model_one']], sources=[], - depends_on={'nodes': [], 'macros': []}, + depends_on=DependsOn(), config=self.warn_test_config, original_file_path='test_one.yml', path=get_os_path('schema_test/unique_model_one_id.sql'), @@ -468,7 +462,7 @@ def setUp(self): raw_sql="{{ config(severity='WARN') }}{{ test_unique(model=ref('model_one'), column_name='id') }}", description='', columns={}, - column_name='id' + column_name='id', ), ] @@ -477,10 +471,7 @@ def setUp(self): description='blah blah', original_file_path='test_one.yml', columns={ - 'id': { - 'name': 'id', - 'description': 'user ID', - }, + 'id': ColumnInfo(name='id', description='user ID'), }, docrefs=[], ) @@ -946,9 +937,9 @@ def setUp(self): super().setUp() self.macro_manifest = Manifest(macros={}, nodes={}, docs={}, - generated_at=timestring(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[]) - self.model_config = { + self.model_config = NodeConfig.from_dict({ 'enabled': True, 'materialized': 'view', 'persist_docs': {}, @@ -958,10 +949,10 @@ def setUp(self): 'quoting': {}, 'column_types': {}, 'tags': [], - } - self.test_config = deep_merge(self.model_config, 
{'severity': 'ERROR'}) + }) + self.test_config = self.model_config.replace(severity='ERROR') - self.disabled_config = { + self.disabled_config = NodeConfig.from_dict({ 'enabled': False, 'materialized': 'view', 'persist_docs': {}, @@ -971,7 +962,7 @@ def setUp(self): 'quoting': {}, 'column_types': {}, 'tags': [], - } + }) def test__single_model(self): models = [{ @@ -997,19 +988,15 @@ def test__single_model(self): name='model_one', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.model_one', fqn=['root', 'model_one'], - empty=False, package_name='root', original_file_path='model_one.sql', root_path=get_os_path('/usr/src/app'), refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='model_one.sql', @@ -1044,10 +1031,7 @@ def test__single_model__nested_configuration(self): } } - ephemeral_config = self.model_config.copy() - ephemeral_config.update({ - 'materialized': 'ephemeral' - }) + ephemeral_config = self.model_config.replace(materialized='ephemeral') parser = ModelParser( self.root_project_config, @@ -1062,19 +1046,15 @@ def test__single_model__nested_configuration(self): name='model_one', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.model_one', fqn=['root', 'nested', 'path', 'model_one'], - empty=False, package_name='root', original_file_path='nested/path/model_one.sql', root_path=get_os_path('/usr/src/app'), refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=ephemeral_config, tags=[], path=get_os_path('nested/path/model_one.sql'), @@ -1113,17 +1093,13 @@ def test__empty_model(self): name='model_one', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.model_one', fqn=['root', 'model_one'], - empty=True, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [], - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='model_one.sql', @@ -1171,24 +1147,19 @@ def test__simple_dependency(self): name='base', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.base', fqn=['root', 'base'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='base.sql', original_file_path='base.sql', root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'base').get('raw_sql'), + raw_sql=self.find_input_by_name(models, 'base').get('raw_sql'), description='', columns={} @@ -1198,24 +1169,19 @@ def test__simple_dependency(self): name='events_tx', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.events_tx', fqn=['root', 'events_tx'], - empty=False, package_name='root', refs=[['base']], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='events_tx.sql', original_file_path='events_tx.sql', root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'events_tx').get('raw_sql'), + raw_sql=self.find_input_by_name(models, 'events_tx').get('raw_sql'), description='', columns={} ) @@ -1284,17 +1250,13 @@ def test__multiple_dependencies(self): name='events', 
database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.events', fqn=['root', 'events'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='events.sql', @@ -1310,17 +1272,13 @@ def test__multiple_dependencies(self): name='sessions', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.sessions', fqn=['root', 'sessions'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='sessions.sql', @@ -1336,17 +1294,13 @@ def test__multiple_dependencies(self): name='events_tx', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.events_tx', fqn=['root', 'events_tx'], - empty=False, package_name='root', refs=[['events']], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='events_tx.sql', @@ -1362,17 +1316,13 @@ def test__multiple_dependencies(self): name='sessions_tx', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.sessions_tx', fqn=['root', 'sessions_tx'], - empty=False, package_name='root', refs=[['sessions']], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='sessions_tx.sql', @@ -1388,17 +1338,13 @@ def test__multiple_dependencies(self): name='multi', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.multi', fqn=['root', 'multi'], - empty=False, package_name='root', refs=[['sessions_tx'], ['events_tx']], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='multi.sql', @@ -1476,17 +1422,13 @@ def test__multiple_dependencies__packages(self): name='events', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.snowplow.events', fqn=['snowplow', 'events'], - empty=False, package_name='snowplow', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='events.sql', @@ -1502,17 +1444,13 @@ def test__multiple_dependencies__packages(self): name='sessions', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.snowplow.sessions', fqn=['snowplow', 'sessions'], - empty=False, package_name='snowplow', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='sessions.sql', @@ -1528,17 +1466,13 @@ def test__multiple_dependencies__packages(self): name='events_tx', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.snowplow.events_tx', fqn=['snowplow', 'events_tx'], - empty=False, package_name='snowplow', refs=[['events']], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='events_tx.sql', @@ -1554,17 +1488,13 @@ def test__multiple_dependencies__packages(self): name='sessions_tx', database='test', schema='analytics', - 
resource_type='model', + resource_type=NodeType.Model, unique_id='model.snowplow.sessions_tx', fqn=['snowplow', 'sessions_tx'], - empty=False, package_name='snowplow', refs=[['sessions']], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='sessions_tx.sql', @@ -1580,18 +1510,14 @@ def test__multiple_dependencies__packages(self): name='multi', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.multi', fqn=['root', 'multi'], - empty=False, package_name='root', refs=[['snowplow', 'sessions_tx'], ['snowplow', 'events_tx']], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='multi.sql', @@ -1607,86 +1533,71 @@ def test__multiple_dependencies__packages(self): ) def test__process_refs__packages(self): - graph = { - 'macros': {}, - 'nodes': { - 'model.snowplow.events': { - 'name': 'events', - 'alias': 'events', - 'database': 'test', - 'schema': 'analytics', - 'resource_type': 'model', - 'unique_id': 'model.snowplow.events', - 'fqn': ['snowplow', 'events'], - 'empty': False, - 'package_name': 'snowplow', - 'refs': [], - 'sources': [], - 'depends_on': { - 'nodes': [], - 'macros': [] - }, - 'config': self.disabled_config, - 'tags': [], - 'path': 'events.sql', - 'original_file_path': 'events.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'does not matter' - }, - 'model.root.events': { - 'name': 'events', - 'alias': 'events', - 'database': 'test', - 'schema': 'analytics', - 'resource_type': 'model', - 'unique_id': 'model.root.events', - 'fqn': ['root', 'events'], - 'empty': False, - 'package_name': 'root', - 'refs': [], - 'sources': [], - 'depends_on': { - 'nodes': [], - 'macros': [] - }, - 'config': self.model_config, - 'tags': [], - 'path': 'events.sql', - 'original_file_path': 'events.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'does not matter' - }, - 'model.root.dep': { - 'name': 'dep', - 'alias': 'dep', - 'database': 'test', - 'schema': 'analytics', - 'resource_type': 'model', - 'unique_id': 'model.root.dep', - 'fqn': ['root', 'dep'], - 'empty': False, - 'package_name': 'root', - 'refs': [['events']], - 'sources': [], - 'depends_on': { - 'nodes': [], - 'macros': [] - }, - 'config': self.model_config, - 'tags': [], - 'path': 'multi.sql', - 'original_file_path': 'multi.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'does not matter' - } - } + nodes = { + 'model.snowplow.events': ParsedNode( + name='events', + alias='events', + database='test', + schema='analytics', + resource_type=NodeType.Model, + unique_id='model.snowplow.events', + fqn=['snowplow', 'events'], + package_name='snowplow', + refs=[], + sources=[], + depends_on=DependsOn(), + config=self.disabled_config, + tags=[], + path='events.sql', + original_file_path='events.sql', + root_path=get_os_path('/usr/src/app'), + raw_sql='does not matter', + ), + 'model.root.events': ParsedNode( + name='events', + alias='events', + database='test', + schema='analytics', + resource_type=NodeType.Model, + unique_id='model.root.events', + fqn=['root', 'events'], + package_name='root', + refs=[], + sources=[], + depends_on=DependsOn(), + config=self.model_config, + tags=[], + path='events.sql', + original_file_path='events.sql', + root_path=get_os_path('/usr/src/app'), + raw_sql='does not matter', + ), + 'model.root.dep': ParsedNode( + name='dep', + alias='dep', + database='test', + 
schema='analytics', + resource_type=NodeType.Model, + unique_id='model.root.dep', + fqn=['root', 'dep'], + package_name='root', + refs=[['events']], + sources=[], + depends_on=DependsOn(), + config=self.model_config, + tags=[], + path='multi.sql', + original_file_path='multi.sql', + root_path=get_os_path('/usr/src/app'), + raw_sql='does not matter', + ), } manifest = Manifest( - nodes={k: ParsedNode(**v) for (k,v) in graph['nodes'].items()}, - macros={k: ParsedMacro(**v) for (k,v) in graph['macros'].items()}, + nodes=nodes, + macros={}, docs={}, - generated_at=timestring(), + generated_at=datetime.utcnow(), disabled=[] ) @@ -1704,15 +1615,15 @@ def test__process_refs__packages(self): 'resource_type': 'model', 'unique_id': 'model.snowplow.events', 'fqn': ['snowplow', 'events'], - 'empty': False, 'package_name': 'snowplow', + 'docrefs': [], 'refs': [], 'sources': [], 'depends_on': { 'nodes': [], 'macros': [] }, - 'config': self.disabled_config, + 'config': self.disabled_config.to_dict(), 'tags': [], 'path': 'events.sql', 'original_file_path': 'events.sql', @@ -1720,6 +1631,9 @@ def test__process_refs__packages(self): 'raw_sql': 'does not matter', 'columns': {}, 'description': '', + 'build_path': None, + 'index': None, + 'patch_path': None, }, 'model.root.events': { 'name': 'events', @@ -1729,15 +1643,15 @@ def test__process_refs__packages(self): 'resource_type': 'model', 'unique_id': 'model.root.events', 'fqn': ['root', 'events'], - 'empty': False, 'package_name': 'root', + 'docrefs': [], 'refs': [], 'sources': [], 'depends_on': { 'nodes': [], 'macros': [] }, - 'config': self.model_config, + 'config': self.model_config.to_dict(), 'tags': [], 'path': 'events.sql', 'original_file_path': 'events.sql', @@ -1745,6 +1659,9 @@ def test__process_refs__packages(self): 'raw_sql': 'does not matter', 'columns': {}, 'description': '', + 'build_path': None, + 'index': None, + 'patch_path': None, }, 'model.root.dep': { 'name': 'dep', @@ -1754,15 +1671,15 @@ def test__process_refs__packages(self): 'resource_type': 'model', 'unique_id': 'model.root.dep', 'fqn': ['root', 'dep'], - 'empty': False, 'package_name': 'root', + 'docrefs': [], 'refs': [['events']], 'sources': [], 'depends_on': { 'nodes': ['model.root.events'], 'macros': [] }, - 'config': self.model_config, + 'config': self.model_config.to_dict(), 'tags': [], 'path': 'multi.sql', 'original_file_path': 'multi.sql', @@ -1770,6 +1687,9 @@ def test__process_refs__packages(self): 'raw_sql': 'does not matter', 'columns': {}, 'description': '', + 'build_path': None, + 'index': None, + 'patch_path': None, } } } @@ -1787,9 +1707,7 @@ def test__in_model_config(self): "select * from events"), }] - self.model_config.update({ - 'materialized': 'table' - }) + self.model_config = self.model_config.replace(materialized='table') parser = ModelParser( self.root_project_config, @@ -1805,17 +1723,13 @@ def test__in_model_config(self): name='model_one', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.model_one', fqn=['root', 'model_one'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [], - }, + depends_on=DependsOn(), config=self.model_config, tags=[], root_path=get_os_path('/usr/src/app'), @@ -1867,19 +1781,9 @@ def test__root_project_config(self): 'raw_sql': ("select * from events"), }] - self.model_config.update({ - 'materialized': 'table' - }) - - ephemeral_config = self.model_config.copy() - ephemeral_config.update({ - 'materialized': 'ephemeral' 
- }) - - view_config = self.model_config.copy() - view_config.update({ - 'materialized': 'view' - }) + self.model_config = self.model_config.replace(materialized='table') + ephemeral_config = self.model_config.replace(materialized='ephemeral') + view_config = self.model_config.replace(materialized='view') parser = ModelParser( self.root_project_config, @@ -1895,17 +1799,13 @@ def test__root_project_config(self): name='table', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.table', fqn=['root', 'table'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), path='table.sql', original_file_path='table.sql', config=self.model_config, @@ -1921,17 +1821,13 @@ def test__root_project_config(self): name='ephemeral', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.ephemeral', fqn=['root', 'ephemeral'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), path='ephemeral.sql', original_file_path='ephemeral.sql', config=ephemeral_config, @@ -1947,17 +1843,13 @@ def test__root_project_config(self): name='view', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.view', fqn=['root', 'view'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), path='view.sql', original_file_path='view.sql', root_path=get_os_path('/usr/src/app'), @@ -2056,39 +1948,27 @@ def test__other_project_config(self): 'raw_sql': ("select * from events"), }] - self.model_config.update({ - 'materialized': 'table' - }) - - ephemeral_config = self.model_config.copy() - ephemeral_config.update({ - 'materialized': 'ephemeral' - }) - - view_config = self.model_config.copy() - view_config.update({ - 'materialized': 'view' - }) + self.model_config = self.model_config.replace(materialized='table') - disabled_config = self.model_config.copy() - disabled_config.update({ - 'enabled': False, - 'materialized': 'ephemeral' - }) - - - sort_config = self.model_config.copy() - sort_config.update({ - 'enabled': False, - 'materialized': 'view', - 'sort': 'timestamp', - }) - - multi_sort_config = self.model_config.copy() - multi_sort_config.update({ - 'materialized': 'table', - 'sort': ['timestamp', 'id'] - }) + ephemeral_config = self.model_config.replace( + materialized='ephemeral' + ) + view_config = self.model_config.replace( + materialized='view' + ) + disabled_config = self.model_config.replace( + materialized='ephemeral', + enabled=False, + ) + sort_config = self.model_config.replace( + materialized='view', + enabled=False, + sort='timestamp', + ) + multi_sort_config = self.model_config.replace( + materialized='table', + sort=['timestamp', 'id'], + ) parser = ModelParser( self.root_project_config, @@ -2104,17 +1984,13 @@ def test__other_project_config(self): name='table', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.table', fqn=['root', 'table'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), path='table.sql', original_file_path='table.sql', root_path=get_os_path('/usr/src/app'), @@ -2130,17 +2006,13 @@ def test__other_project_config(self): name='ephemeral', 
database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.ephemeral', fqn=['root', 'ephemeral'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), path='ephemeral.sql', original_file_path='ephemeral.sql', root_path=get_os_path('/usr/src/app'), @@ -2156,17 +2028,13 @@ def test__other_project_config(self): name='view', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.view', fqn=['root', 'view'], - empty=False, package_name='root', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), path='view.sql', original_file_path='view.sql', root_path=get_os_path('/usr/src/app'), @@ -2182,17 +2050,13 @@ def test__other_project_config(self): name='multi_sort', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.snowplow.multi_sort', fqn=['snowplow', 'views', 'multi_sort'], - empty=False, package_name='snowplow', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), path=get_os_path('views/multi_sort.sql'), original_file_path=get_os_path('views/multi_sort.sql'), root_path=get_os_path('/usr/src/app'), @@ -2207,7 +2071,7 @@ def test__other_project_config(self): disabled=[ ParsedNode( name='disabled', - resource_type='model', + resource_type=NodeType.Model, package_name='snowplow', path='disabled.sql', original_file_path='disabled.sql', @@ -2217,13 +2081,9 @@ def test__other_project_config(self): schema='analytics', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=disabled_config, tags=[], - empty=False, alias='disabled', unique_id='model.snowplow.disabled', fqn=['snowplow', 'disabled'], @@ -2231,7 +2091,7 @@ def test__other_project_config(self): ), ParsedNode( name='package', - resource_type='model', + resource_type=NodeType.Model, package_name='snowplow', path=get_os_path('views/package.sql'), original_file_path=get_os_path('views/package.sql'), @@ -2241,13 +2101,9 @@ def test__other_project_config(self): schema='analytics', refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=sort_config, tags=[], - empty=False, alias='package', unique_id='model.snowplow.package', fqn=['snowplow', 'views', 'package'], @@ -2281,17 +2137,13 @@ def test__simple_data_test(self): name='no_events', database='test', schema='analytics', - resource_type='test', + resource_type=NodeType.Test, unique_id='test.root.no_events', fqn=['root', 'no_events'], - empty=False, package_name='root', refs=[['base']], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.test_config, path='no_events.sql', original_file_path='no_events.sql', @@ -2312,7 +2164,7 @@ def test__simple_macro(self): {{a}} + {{b}} {% endmacro %} """ - parser = MacroParser(None, None) + parser = MacroParser(None, {}) result = parser.parse_macro_file( macro_file_path='simple_macro.sql', macro_file_contents=macro_file_contents, @@ -2326,7 +2178,7 @@ def test__simple_macro(self): self.assertEqual( result, { - 'macro.root.simple': ParsedMacro(**{ + 'macro.root.simple': ParsedMacro.from_dict({ 'name': 'simple', 'resource_type': 'macro', 'unique_id': 'macro.root.simple', @@ -2349,7 +2201,7 @@ def test__simple_macro_used_in_model(self): {{a}} + {{b}} {% endmacro %} """ - parser = 
MacroParser(None, None) + parser = MacroParser(None, {}) result = parser.parse_macro_file( macro_file_path='simple_macro.sql', macro_file_contents=macro_file_contents, @@ -2362,7 +2214,7 @@ def test__simple_macro_used_in_model(self): self.assertEqual( result, { - 'macro.root.simple': ParsedMacro(**{ + 'macro.root.simple': ParsedMacro.from_dict({ 'name': 'simple', 'resource_type': 'macro', 'unique_id': 'macro.root.simple', @@ -2403,19 +2255,15 @@ def test__simple_macro_used_in_model(self): name='model_one', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.model_one', fqn=['root', 'model_one'], - empty=False, package_name='root', original_file_path='model_one.sql', root_path=get_os_path('/usr/src/app'), refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='model_one.sql', @@ -2453,18 +2301,14 @@ def test__macro_no_explicit_project_used_in_model(self): name='model_one', database='test', schema='analytics', - resource_type='model', + resource_type=NodeType.Model, unique_id='model.root.model_one', fqn=['root', 'model_one'], - empty=False, package_name='root', root_path=get_os_path('/usr/src/app'), refs=[], sources=[], - depends_on={ - 'nodes': [], - 'macros': [] - }, + depends_on=DependsOn(), config=self.model_config, tags=[], path='model_one.sql', diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index e15b00a220f..be8c75fb5cc 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -110,9 +110,7 @@ def test_default_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.connections.psycopg2') def test_changed_keepalive(self, psycopg2): - self.config.credentials = self.config.credentials.incorporate( - keepalives_idle=256 - ) + self.config.credentials = self.config.credentials.replace(keepalives_idle=256) connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( @@ -126,9 +124,7 @@ def test_changed_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.connections.psycopg2') def test_search_path(self, psycopg2): - self.config.credentials = self.config.credentials.incorporate( - search_path="test" - ) + self.config.credentials = self.config.credentials.replace(search_path="test") connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( @@ -142,9 +138,7 @@ def test_search_path(self, psycopg2): @mock.patch('dbt.adapters.postgres.connections.psycopg2') def test_schema_with_space(self, psycopg2): - self.config.credentials = self.config.credentials.incorporate( - search_path="test test" - ) + self.config.credentials = self.config.credentials.replace(search_path="test test") connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( @@ -158,9 +152,7 @@ def test_schema_with_space(self, psycopg2): @mock.patch('dbt.adapters.postgres.connections.psycopg2') def test_set_zero_keepalive(self, psycopg2): - self.config.credentials = self.config.credentials.incorporate( - keepalives_idle=0 - ) + self.config.credentials = self.config.credentials.replace(keepalives_idle=0) connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py index dbcd597084d..79a06320b95 100644 --- a/test/unit/test_redshift_adapter.py +++ b/test/unit/test_redshift_adapter.py @@ 
-6,7 +6,7 @@ import dbt.utils from dbt.adapters.redshift import RedshiftAdapter -from dbt.exceptions import ValidationException, FailedToConnectException +from dbt.exceptions import FailedToConnectException from dbt.logger import GLOBAL_LOGGER as logger # noqa from .utils import config_from_parts_or_dicts, mock_connection @@ -70,7 +70,7 @@ def test_explicit_database_conn(self): self.assertEqual(creds, self.config.credentials) def test_explicit_iam_conn(self): - self.config.credentials = self.config.credentials.incorporate( + self.config.credentials = self.config.credentials.replace( method='iam', cluster_id='my_redshift', iam_duration_seconds=1200 @@ -79,24 +79,22 @@ def test_explicit_iam_conn(self): with mock.patch.object(RedshiftAdapter.ConnectionManager, 'fetch_cluster_credentials', new=fetch_cluster_credentials): creds = RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) - expected_creds = self.config.credentials.incorporate(password='tmp_password') + expected_creds = self.config.credentials.replace(password='tmp_password') self.assertEqual(creds, expected_creds) def test_invalid_auth_method(self): # we have to set method this way, otherwise it won't validate - self.config.credentials._contents['method'] = 'badmethod' + self.config.credentials.method = 'badmethod' - with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: + with self.assertRaises(FailedToConnectException) as context: with mock.patch.object(RedshiftAdapter.ConnectionManager, 'fetch_cluster_credentials', new=fetch_cluster_credentials): RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) self.assertTrue('badmethod' in context.exception.msg) def test_invalid_iam_no_cluster_id(self): - self.config.credentials = self.config.credentials.incorporate( - method='iam' - ) - with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: + self.config.credentials = self.config.credentials.replace(method='iam') + with self.assertRaises(FailedToConnectException) as context: with mock.patch.object(RedshiftAdapter.ConnectionManager, 'fetch_cluster_credentials', new=fetch_cluster_credentials): RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) @@ -141,14 +139,12 @@ def test_default_keepalive(self, psycopg2): password='password', port=5439, connect_timeout=10, - keepalives_idle=RedshiftAdapter.ConnectionManager.DEFAULT_TCP_KEEPALIVE + keepalives_idle=240 ) @mock.patch('dbt.adapters.postgres.connections.psycopg2') def test_changed_keepalive(self, psycopg2): - self.config.credentials = self.config.credentials.incorporate( - keepalives_idle=256 - ) + self.config.credentials = self.config.credentials.replace(keepalives_idle=256) connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( @@ -162,9 +158,7 @@ def test_changed_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.connections.psycopg2') def test_search_path(self, psycopg2): - self.config.credentials = self.config.credentials.incorporate( - search_path="test" - ) + self.config.credentials = self.config.credentials.replace(search_path="test") connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( @@ -175,13 +169,11 @@ def test_search_path(self, psycopg2): port=5439, connect_timeout=10, options="-c search_path=test", - keepalives_idle=RedshiftAdapter.ConnectionManager.DEFAULT_TCP_KEEPALIVE) + keepalives_idle=240) @mock.patch('dbt.adapters.postgres.connections.psycopg2') def 
test_search_path_with_space(self, psycopg2): - self.config.credentials = self.config.credentials.incorporate( - search_path="test test" - ) + self.config.credentials = self.config.credentials.replace(search_path="test test") connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( @@ -192,13 +184,11 @@ def test_search_path_with_space(self, psycopg2): port=5439, connect_timeout=10, options="-c search_path=test\ test", - keepalives_idle=RedshiftAdapter.ConnectionManager.DEFAULT_TCP_KEEPALIVE) + keepalives_idle=240) @mock.patch('dbt.adapters.postgres.connections.psycopg2') def test_set_zero_keepalive(self, psycopg2): - self.config.credentials = self.config.credentials.incorporate( - keepalives_idle=0 - ) + self.config.credentials = self.config.credentials.replace(keepalives_idle=0) connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index 40a3e0d0faa..590dd820541 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -164,8 +164,8 @@ def test_client_session_keep_alive_false_by_default(self): ]) def test_client_session_keep_alive_true(self): - self.config.credentials = self.config.credentials.incorporate( - client_session_keep_alive=True) + self.config.credentials = self.config.credentials.replace( + client_session_keep_alive=True) self.adapter = SnowflakeAdapter(self.config) self.adapter.connections.set_connection_name(name='new_connection_with_new_config') @@ -178,8 +178,9 @@ def test_client_session_keep_alive_true(self): ]) def test_user_pass_authentication(self): - self.config.credentials = self.config.credentials.incorporate( - password='test_password') + self.config.credentials = self.config.credentials.replace( + password='test_password', + ) self.adapter = SnowflakeAdapter(self.config) self.adapter.connections.set_connection_name(name='new_connection_with_new_config') @@ -192,8 +193,10 @@ def test_user_pass_authentication(self): ]) def test_authenticator_user_pass_authentication(self): - self.config.credentials = self.config.credentials.incorporate( - password='test_password', authenticator='test_sso_url') + self.config.credentials = self.config.credentials.replace( + password='test_password', + authenticator='test_sso_url', + ) self.adapter = SnowflakeAdapter(self.config) self.adapter.connections.set_connection_name(name='new_connection_with_new_config') @@ -207,8 +210,9 @@ def test_authenticator_user_pass_authentication(self): ]) def test_authenticator_externalbrowser_authentication(self): - self.config.credentials = self.config.credentials.incorporate( - authenticator='externalbrowser') + self.config.credentials = self.config.credentials.replace( + authenticator='externalbrowser' + ) self.adapter = SnowflakeAdapter(self.config) self.adapter.connections.set_connection_name(name='new_connection_with_new_config') @@ -221,11 +225,12 @@ def test_authenticator_externalbrowser_authentication(self): private_key=None) ]) - @mock.patch('dbt.adapters.snowflake.SnowflakeConnectionManager._get_private_key', return_value='test_key') + @mock.patch('dbt.adapters.snowflake.SnowflakeCredentials._get_private_key', return_value='test_key') def test_authenticator_private_key_authentication(self, mock_get_private_key): - self.config.credentials = self.config.credentials.incorporate( + self.config.credentials = self.config.credentials.replace( private_key_path='/tmp/test_key.p8', - 
private_key_passphrase='p@ssphr@se')
+            private_key_passphrase='p@ssphr@se',
+        )
         self.adapter = SnowflakeAdapter(self.config)
         self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
diff --git a/test/unit/utils.py b/test/unit/utils.py
index 1e3fd487a72..ffc0529f790 100644
--- a/test/unit/utils.py
+++ b/test/unit/utils.py
@@ -4,6 +4,9 @@
 issues.
 """
 from unittest import mock
+from unittest import TestCase
+
+from hologram import ValidationError
 
 
 class Obj:
@@ -46,3 +49,30 @@ def inject_adapter(value):
     key = value.type()
     factory._ADAPTERS[key] = value
     factory.ADAPTER_TYPES[key] = type(value)
+
+
+class ContractTestCase(TestCase):
+    ContractType = None
+
+    def setUp(self):
+        self.maxDiff = None
+        super().setUp()
+
+    def assert_to_dict(self, obj, dct):
+        self.assertEqual(obj.to_dict(), dct)
+
+    def assert_from_dict(self, obj, dct, cls=None):
+        if cls is None:
+            cls = self.ContractType
+        self.assertEqual(cls.from_dict(dct), obj)
+
+    def assert_symmetric(self, obj, dct, cls=None):
+        self.assert_to_dict(obj, dct)
+        self.assert_from_dict(obj, dct, cls)
+
+    def assert_fails_validation(self, dct, cls=None):
+        if cls is None:
+            cls = self.ContractType
+
+        with self.assertRaises(ValidationError):
+            cls.from_dict(dct)
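
For illustration only: a minimal sketch of how the new ContractTestCase helper in test/unit/utils.py might be exercised by a contract test. It assumes hologram's JsonSchemaMixin provides the to_dict()/from_dict() and ValidationError behavior relied on throughout this diff; ExampleContract, TestExampleContract, and the relative import of ContractTestCase are hypothetical names, not part of the patch.

    from dataclasses import dataclass

    from hologram import JsonSchemaMixin

    from .utils import ContractTestCase


    @dataclass
    class ExampleContract(JsonSchemaMixin):
        # hypothetical contract used only to exercise the helper
        name: str
        enabled: bool = True


    class TestExampleContract(ContractTestCase):
        ContractType = ExampleContract

        def test_round_trip(self):
            obj = ExampleContract(name='events')
            dct = {'name': 'events', 'enabled': True}
            # to_dict()/from_dict() should round-trip symmetrically
            self.assert_symmetric(obj, dct)

        def test_missing_required_field(self):
            # 'name' is required, so schema validation should fail
            self.assert_fails_validation({'enabled': True})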