diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index a0dbdd80e7c..1d6b83c7028 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -17,7 +17,8 @@ from dbt.schema import Column from dbt.utils import filter_null_values -from dbt.adapters.base.meta import AdapterMeta, available, available_deprecated + +from dbt.adapters.base.meta import AdapterMeta, available from dbt.adapters.base import BaseRelation from dbt.adapters.cache import RelationsCache @@ -221,7 +222,7 @@ def connection_named(self, name): finally: self.release_connection() - @available + @available.parse(lambda *a, **k: ('', dbt.clients.agate_helper())) def execute(self, sql, auto_begin=False, fetch=False): """Execute the given SQL. This is a thin wrapper around ConnectionManager.execute. @@ -404,7 +405,7 @@ def check_schema_exists(self, database, schema): # Abstract methods about relations ### @abc.abstractmethod - @available + @available.parse_none def drop_relation(self, relation): """Drop the given relation. @@ -417,7 +418,7 @@ def drop_relation(self, relation): ) @abc.abstractmethod - @available + @available.parse_none def truncate_relation(self, relation): """Truncate the given relation. @@ -428,7 +429,7 @@ def truncate_relation(self, relation): ) @abc.abstractmethod - @available + @available.parse_none def rename_relation(self, from_relation, to_relation): """Rename the relation from from_relation to to_relation. @@ -442,7 +443,7 @@ def rename_relation(self, from_relation, to_relation): ) @abc.abstractmethod - @available + @available.parse_list def get_columns_in_relation(self, relation): """Get a list of the columns in the given Relation. @@ -454,7 +455,7 @@ def get_columns_in_relation(self, relation): '`get_columns_in_relation` is not implemented for this adapter!' ) - @available_deprecated('get_columns_in_relation') + @available.deprecated('get_columns_in_relation', lambda *a, **k: []) def get_columns_in_table(self, schema, identifier): """DEPRECATED: Get a list of the columns in the given table.""" relation = self.Relation.create( @@ -488,7 +489,7 @@ def list_relations_without_caching(self, information_schema, schema): relations from. :param str schema: The name of the schema to list relations from. :return: The relations in schema - :retype: List[self.Relation] + :rtype: List[self.Relation] """ raise dbt.exceptions.NotImplementedException( '`list_relations_without_caching` is not implemented for this ' @@ -498,10 +499,17 @@ def list_relations_without_caching(self, information_schema, schema): ### # Provided methods about relations ### - @available + @available.parse_list def get_missing_columns(self, from_relation, to_relation): - """Returns dict of {column:type} for columns in from_table that are - missing from to_relation + """Returns a list of Columns in from_relation that are missing from + to_relation. + + :param Relation from_relation: The relation that might have extra + columns + :param Relation to_relation: The realtion that might have columns + missing + :return: The columns in from_relation that are missing from to_relation + :rtype: List[self.Relation] """ if not isinstance(from_relation, self.Relation): dbt.exceptions.invalid_type_error( @@ -534,7 +542,7 @@ def get_missing_columns(self, from_relation, to_relation): if col_name in missing_columns ] - @available + @available.parse_none def expand_target_column_types(self, temp_table, to_relation): if not isinstance(to_relation, self.Relation): dbt.exceptions.invalid_type_error( @@ -600,7 +608,7 @@ def _make_match(self, relations_list, database, schema, identifier): return matches - @available + @available.parse_none def get_relation(self, database, schema, identifier): relations_list = self.list_relations(database, schema) @@ -622,7 +630,7 @@ def get_relation(self, database, schema, identifier): return None - @available_deprecated('get_relation') + @available.deprecated('get_relation', lambda *a, **k: False) def already_exists(self, schema, name): """DEPRECATED: Return if a model already exists in the database""" database = self.config.credentials.database @@ -634,7 +642,7 @@ def already_exists(self, schema, name): # although some adapters may override them ### @abc.abstractmethod - @available + @available.parse_none def create_schema(self, database, schema): """Create the given schema if it does not exist. diff --git a/core/dbt/adapters/base/meta.py b/core/dbt/adapters/base/meta.py index 14201c93563..888c4c87a33 100644 --- a/core/dbt/adapters/base/meta.py +++ b/core/dbt/adapters/base/meta.py @@ -3,26 +3,42 @@ from dbt.deprecations import warn, renamed_method +def _always_none(*args, **kwargs): + return None + + +def _always_list(*args, **kwargs): + return None + + def available(func): - """A decorator to indicate that a method on the adapter will be exposed to - the database wrapper, and the model name will be injected into the - arguments. + """A decorator to indicate that a method on the adapter will be + exposed to the database wrapper, and will be available at parse and run + time. """ func._is_available_ = True return func -def available_deprecated(supported_name): +def available_deprecated(supported_name, parse_replacement=None): """A decorator that marks a function as available, but also prints a deprecation warning. Use like @available_deprecated('my_new_method') - def my_old_method(self, arg, model_name=None): + def my_old_method(self, arg): args = compatability_shim(arg) - return self.my_new_method(*args, model_name=None) + return self.my_new_method(*args) + + @available_deprecated('my_new_slow_method', lambda *a, **k: (0, '')) + def my_old_slow_method(self, arg): + args = compatibility_shim(arg) + return self.my_new_slow_method(*args) To make `adapter.my_old_method` available but also print out a warning on use directing users to `my_new_method`. + + The optional parse_replacement, if provided, will provide a parse-time + replacement for the actual method (see `available_parse`). """ def wrapper(func): func_name = func.__name__ @@ -32,10 +48,43 @@ def wrapper(func): def inner(*args, **kwargs): warn('adapter:{}'.format(func_name)) return func(*args, **kwargs) + + if parse_replacement: + available = available_parse(parse_replacement) return available(inner) return wrapper +def available_parse(parse_replacement): + """A decorator factory to indicate that a method on the adapter will be + exposed to the database wrapper, and will be stubbed out at parse time with + the given function. + + @available_parse() + def my_method(self, a, b): + if something: + return None + return big_expensive_db_query() + + @available_parse(lambda *args, **args: {}) + def my_other_method(self, a, b): + x = {} + x.update(big_expensive_db_query()) + return x + """ + def inner(func): + func._parse_replacement_ = parse_replacement + available(func) + return func + return inner + + +available.deprecated = available_deprecated +available.parse = available_parse +available.parse_none = available_parse(lambda *a, **k: None) +available.parse_list = available_parse(lambda *a, **k: []) + + class AdapterMeta(abc.ABCMeta): def __new__(mcls, name, bases, namespace, **kwargs): cls = super(AdapterMeta, mcls).__new__(mcls, name, bases, namespace, @@ -47,15 +96,22 @@ def __new__(mcls, name, bases, namespace, **kwargs): # injected into the arguments. All methods in here are exposed to the # context. available = set() + replacements = {} # collect base class data first for base in bases: available.update(getattr(base, '_available_', set())) + replacements.update(getattr(base, '_parse_replacements_', set())) # override with local data if it exists for name, value in namespace.items(): if getattr(value, '_is_available_', False): available.add(name) + parse_replacement = getattr(value, '_parse_replacement_', None) + if parse_replacement is not None: + replacements[name] = parse_replacement cls._available_ = frozenset(available) + # should this be a namedtuple so it will be immutable like _available_? + cls._parse_replacements_ = replacements return cls diff --git a/core/dbt/adapters/sql/impl.py b/core/dbt/adapters/sql/impl.py index 245b812def1..2f9058a22c6 100644 --- a/core/dbt/adapters/sql/impl.py +++ b/core/dbt/adapters/sql/impl.py @@ -35,7 +35,7 @@ class SQLAdapter(BaseAdapter): - list_relations_without_caching - get_columns_in_relation """ - @available + @available.parse(lambda *a, **k: (None, None)) def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False): """Add a query to the current transaction. A thin wrapper around diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py index ca4f71ff209..80175d4b8df 100644 --- a/core/dbt/context/common.py +++ b/core/dbt/context/common.py @@ -47,7 +47,7 @@ def create(self, *args, **kwargs): return self.relation_type.create(*args, **kwargs) -class DatabaseWrapper(object): +class BaseDatabaseWrapper(object): """ Wrapper for runtime database interaction. Applies the runtime quote policy via a relation proxy. @@ -57,14 +57,7 @@ def __init__(self, adapter): self.Relation = RelationProxy(adapter) def __getattr__(self, name): - if name in self.adapter._available_: - return getattr(self.adapter, name) - else: - raise AttributeError( - "'{}' object has no attribute '{}'".format( - self.__class__.__name__, name - ) - ) + raise NotImplementedError('subclasses need to implement this') @property def config(self): @@ -371,7 +364,7 @@ def generate_base(model, model_dict, config, manifest, source_config, pre_hooks = None post_hooks = None - db_wrapper = DatabaseWrapper(adapter) + db_wrapper = provider.DatabaseWrapper(adapter) context = dbt.utils.merge(context, { "adapter": db_wrapper, diff --git a/core/dbt/context/parser.py b/core/dbt/context/parser.py index 3d2a8da5d78..9cb01a3d345 100644 --- a/core/dbt/context/parser.py +++ b/core/dbt/context/parser.py @@ -1,6 +1,7 @@ import dbt.exceptions import dbt.context.common +from dbt.clients.agate_helper import empty_table from dbt.adapters.factory import get_adapter @@ -97,6 +98,26 @@ def get(self, name, validator=None, default=None): return '' +class DatabaseWrapper(dbt.context.common.BaseDatabaseWrapper): + """The parser subclass of the database wrapper applies any explicit + parse-time overrides. + """ + def __getattr__(self, name): + override = (name in self.adapter._available_ and + name in self.adapter._parse_replacements_) + + if override: + return self.adapter._parse_replacements_[name] + elif name in self.adapter._available_: + return getattr(self.adapter, name) + else: + raise AttributeError( + "'{}' object has no attribute '{}'".format( + self.__class__.__name__, name + ) + ) + + def generate(model, runtime_config, manifest, source_config): # during parsing, we don't have a connection, but we might need one, so we # have to acquire it. diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py index 2fc7b32cddb..adbaa968676 100644 --- a/core/dbt/context/runtime.py +++ b/core/dbt/context/runtime.py @@ -118,6 +118,21 @@ def get(self, name, validator=None, default=None): return to_return +class DatabaseWrapper(dbt.context.common.BaseDatabaseWrapper): + """The runtime database wrapper exposes everything the adapter marks + available. + """ + def __getattr__(self, name): + if name in self.adapter._available_: + return getattr(self.adapter, name) + else: + raise AttributeError( + "'{}' object has no attribute '{}'".format( + self.__class__.__name__, name + ) + ) + + def generate(model, runtime_config, manifest): return dbt.context.common.generate( model, runtime_config, manifest, None, dbt.context.runtime) diff --git a/plugins/bigquery/dbt/adapters/bigquery/impl.py b/plugins/bigquery/dbt/adapters/bigquery/impl.py index bb4c45f7f59..610a3da4b3c 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/impl.py +++ b/plugins/bigquery/dbt/adapters/bigquery/impl.py @@ -39,6 +39,16 @@ def column_to_bq_schema(col): **kwargs) +def _stub_relation(*args, **kwargs): + return BigQueryRelation.create( + database='', + schema='', + identifier='', + quote_policy={}, + type=BigQueryRelation.Table + ) + + class BigQueryAdapter(BaseAdapter): RELATION_TYPES = { @@ -310,7 +320,7 @@ def add_query(self, sql, auto_begin=True, bindings=None, ### # Special bigquery adapter methods ### - @available + @available.parse_none def make_date_partitioned_table(self, relation): return self.connections.create_date_partitioned_table( database=relation.database, @@ -318,7 +328,7 @@ def make_date_partitioned_table(self, relation): table_name=relation.identifier ) - @available + @available.parse(lambda *a, **k: '') def execute_model(self, model, materialization, sql_override=None, decorator=None): @@ -340,9 +350,8 @@ def execute_model(self, model, materialization, sql_override=None, return res - @available + @available.parse(_stub_relation) def create_temporary_table(self, sql, **kwargs): - # BQ queries always return a temp table with their results query_job, _ = self.connections.raw_execute(sql) bq_table = query_job.destination @@ -357,7 +366,7 @@ def create_temporary_table(self, sql, **kwargs): }, type=BigQueryRelation.Table) - @available + @available.parse_none def alter_table_add_columns(self, relation, columns): logger.debug('Adding columns ({}) to table {}".'.format( @@ -377,7 +386,7 @@ def alter_table_add_columns(self, relation, columns): new_table = google.cloud.bigquery.Table(table_ref, schema=new_schema) client.update_table(new_table, ['schema']) - @available + @available.parse_none def load_dataframe(self, database, schema, table_name, agate_table, column_override): bq_schema = self._agate_to_schema(agate_table, column_override) diff --git a/test/unit/mock_adapter.py b/test/unit/mock_adapter.py new file mode 100644 index 00000000000..f80ed8347ca --- /dev/null +++ b/test/unit/mock_adapter.py @@ -0,0 +1,81 @@ +import mock + +from dbt.adapters.base import BaseAdapter +from contextlib import contextmanager + + +def adapter_factory(): + class MockAdapter(BaseAdapter): + ConnectionManager = mock.MagicMock(TYPE='mock') + responder = mock.MagicMock() + # some convenient defaults + responder.quote.side_effect = lambda identifier: '"{}"'.format(identifier) + responder.date_function.side_effect = lambda: 'unitdate()' + responder.is_cancelable.side_effect = lambda: False + + @contextmanager + def exception_handler(self, *args, **kwargs): + self.responder.exception_handler(*args, **kwargs) + yield + + def execute(self, *args, **kwargs): + return self.responder.execute(*args, **kwargs) + + def drop_relation(self, *args, **kwargs): + return self.responder.drop_relation(*args, **kwargs) + + def truncate_relation(self, *args, **kwargs): + return self.responder.truncate_relation(*args, **kwargs) + + def rename_relation(self, *args, **kwargs): + return self.responder.rename_relation(*args, **kwargs) + + def get_columns_in_relation(self, *args, **kwargs): + return self.responder.get_columns_in_relation(*args, **kwargs) + + def expand_column_types(self, *args, **kwargs): + return self.responder.expand_column_types(*args, **kwargs) + + def list_relations_without_caching(self, *args, **kwargs): + return self.responder.list_relations_without_caching(*args, **kwargs) + + def create_schema(self, *args, **kwargs): + return self.responder.create_schema(*args, **kwargs) + + def drop_schema(self, *args, **kwargs): + return self.responder.drop_schema(*args, **kwargs) + + @classmethod + def quote(cls, identifier): + return cls.responder.quote(identifier) + + def convert_text_type(self, *args, **kwargs): + return self.responder.convert_text_type(*args, **kwargs) + + def convert_number_type(self, *args, **kwargs): + return self.responder.convert_number_type(*args, **kwargs) + + def convert_boolean_type(self, *args, **kwargs): + return self.responder.convert_boolean_type(*args, **kwargs) + + def convert_datetime_type(self, *args, **kwargs): + return self.responder.convert_datetime_type(*args, **kwargs) + + def convert_date_type(self, *args, **kwargs): + return self.responder.convert_date_type(*args, **kwargs) + + def convert_time_type(self, *args, **kwargs): + return self.responder.convert_time_type(*args, **kwargs) + + def list_schemas(self, *args, **kwargs): + return self.responder.list_schemas(*args, **kwargs) + + @classmethod + def date_function(cls): + return cls.responder.date_function() + + @classmethod + def is_cancelable(cls): + return cls.responder.is_cancelable() + + return MockAdapter diff --git a/test/unit/test_context.py b/test/unit/test_context.py index c5d9a5c99f4..1d900b70818 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -3,7 +3,10 @@ from dbt.contracts.graph.parsed import ParsedNode from dbt.context.common import Var +from dbt.context import parser, runtime import dbt.exceptions +from test.unit.mock_adapter import adapter_factory + class TestVar(unittest.TestCase): def setUp(self): @@ -54,3 +57,49 @@ def test_var_defined_is_missing(self): var.assert_var_defined('foo', 'bar') with self.assertRaises(dbt.exceptions.CompilationException): var.assert_var_defined('foo', None) + + +class TestParseWrapper(unittest.TestCase): + def setUp(self): + self.mock_config = mock.MagicMock() + adapter_class = adapter_factory() + self.mock_adapter = adapter_class(self.mock_config) + self.wrapper = parser.DatabaseWrapper(self.mock_adapter) + self.responder = self.mock_adapter.responder + + def test_unwrapped_method(self): + self.assertEqual(self.wrapper.quote('test_value'), '"test_value"') + self.responder.quote.assert_called_once_with('test_value') + + def test_wrapped_method(self): + found = self.wrapper.get_relation('database', 'schema', 'identifier') + self.assertEqual(found, None) + self.responder.get_relation.assert_not_called() + + +class TestRuntimeWrapper(unittest.TestCase): + def setUp(self): + self.mock_config = mock.MagicMock() + adapter_class = adapter_factory() + self.mock_adapter = adapter_class(self.mock_config) + self.wrapper = runtime.DatabaseWrapper(self.mock_adapter) + self.responder = self.mock_adapter.responder + + def test_unwrapped_method(self): + # the 'quote' method isn't wrapped, we should get our expected inputs + self.assertEqual(self.wrapper.quote('test_value'), '"test_value"') + self.responder.quote.assert_called_once_with('test_value') + + def test_wrapped_method(self): + rel = mock.MagicMock() + rel.matches.return_value = True + self.responder.list_relations_without_caching.return_value = [rel] + + found = self.wrapper.get_relation('database', 'schema', 'identifier') + + self.assertEqual(found, rel) + # it gets called with an information schema relation as the first arg, + # which is hard to mock. + self.responder.list_relations_without_caching.assert_called_once_with( + mock.ANY, 'schema' + )