Skip to content

Commit

Permalink
create a decorator for stubbing out methods at parse time
Browse files Browse the repository at this point in the history
Includes some unit tests
  • Loading branch information
Jacob Beck committed Apr 24, 2019
1 parent 32f74b6 commit 25189ed
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 38 deletions.
38 changes: 23 additions & 15 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from dbt.schema import Column
from dbt.utils import filter_null_values

from dbt.adapters.base.meta import AdapterMeta, available, available_deprecated

from dbt.adapters.base.meta import AdapterMeta, available
from dbt.adapters.base import BaseRelation
from dbt.adapters.cache import RelationsCache

Expand Down Expand Up @@ -221,7 +222,7 @@ def connection_named(self, name):
finally:
self.release_connection()

@available
@available.parse(lambda *a, **k: ('', dbt.clients.agate_helper()))
def execute(self, sql, auto_begin=False, fetch=False):
"""Execute the given SQL. This is a thin wrapper around
ConnectionManager.execute.
Expand Down Expand Up @@ -404,7 +405,7 @@ def check_schema_exists(self, database, schema):
# Abstract methods about relations
###
@abc.abstractmethod
@available
@available.parse_none
def drop_relation(self, relation):
"""Drop the given relation.
Expand All @@ -417,7 +418,7 @@ def drop_relation(self, relation):
)

@abc.abstractmethod
@available
@available.parse_none
def truncate_relation(self, relation):
"""Truncate the given relation.
Expand All @@ -428,7 +429,7 @@ def truncate_relation(self, relation):
)

@abc.abstractmethod
@available
@available.parse_none
def rename_relation(self, from_relation, to_relation):
"""Rename the relation from from_relation to to_relation.
Expand All @@ -442,7 +443,7 @@ def rename_relation(self, from_relation, to_relation):
)

@abc.abstractmethod
@available
@available.parse_list
def get_columns_in_relation(self, relation):
"""Get a list of the columns in the given Relation.
Expand All @@ -454,7 +455,7 @@ def get_columns_in_relation(self, relation):
'`get_columns_in_relation` is not implemented for this adapter!'
)

@available_deprecated('get_columns_in_relation')
@available.deprecated('get_columns_in_relation', lambda *a, **k: [])
def get_columns_in_table(self, schema, identifier):
"""DEPRECATED: Get a list of the columns in the given table."""
relation = self.Relation.create(
Expand Down Expand Up @@ -488,7 +489,7 @@ def list_relations_without_caching(self, information_schema, schema):
relations from.
:param str schema: The name of the schema to list relations from.
:return: The relations in schema
:retype: List[self.Relation]
:rtype: List[self.Relation]
"""
raise dbt.exceptions.NotImplementedException(
'`list_relations_without_caching` is not implemented for this '
Expand All @@ -498,10 +499,17 @@ def list_relations_without_caching(self, information_schema, schema):
###
# Provided methods about relations
###
@available
@available.parse_list
def get_missing_columns(self, from_relation, to_relation):
"""Returns dict of {column:type} for columns in from_table that are
missing from to_relation
"""Returns a list of Columns in from_relation that are missing from
to_relation.
:param Relation from_relation: The relation that might have extra
columns
:param Relation to_relation: The realtion that might have columns
missing
:return: The columns in from_relation that are missing from to_relation
:rtype: List[self.Relation]
"""
if not isinstance(from_relation, self.Relation):
dbt.exceptions.invalid_type_error(
Expand Down Expand Up @@ -534,7 +542,7 @@ def get_missing_columns(self, from_relation, to_relation):
if col_name in missing_columns
]

@available
@available.parse_none
def expand_target_column_types(self, temp_table, to_relation):
if not isinstance(to_relation, self.Relation):
dbt.exceptions.invalid_type_error(
Expand Down Expand Up @@ -600,7 +608,7 @@ def _make_match(self, relations_list, database, schema, identifier):

return matches

@available
@available.parse_none
def get_relation(self, database, schema, identifier):
relations_list = self.list_relations(database, schema)

Expand All @@ -622,7 +630,7 @@ def get_relation(self, database, schema, identifier):

return None

@available_deprecated('get_relation')
@available.deprecated('get_relation', lambda *a, **k: False)
def already_exists(self, schema, name):
"""DEPRECATED: Return if a model already exists in the database"""
database = self.config.credentials.database
Expand All @@ -634,7 +642,7 @@ def already_exists(self, schema, name):
# although some adapters may override them
###
@abc.abstractmethod
@available
@available.parse_none
def create_schema(self, database, schema):
"""Create the given schema if it does not exist.
Expand Down
68 changes: 62 additions & 6 deletions core/dbt/adapters/base/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,42 @@
from dbt.deprecations import warn, renamed_method


def _always_none(*args, **kwargs):
return None


def _always_list(*args, **kwargs):
return None


def available(func):
"""A decorator to indicate that a method on the adapter will be exposed to
the database wrapper, and the model name will be injected into the
arguments.
"""A decorator to indicate that a method on the adapter will be
exposed to the database wrapper, and will be available at parse and run
time.
"""
func._is_available_ = True
return func


def available_deprecated(supported_name):
def available_deprecated(supported_name, parse_replacement=None):
"""A decorator that marks a function as available, but also prints a
deprecation warning. Use like
@available_deprecated('my_new_method')
def my_old_method(self, arg, model_name=None):
def my_old_method(self, arg):
args = compatability_shim(arg)
return self.my_new_method(*args, model_name=None)
return self.my_new_method(*args)
@available_deprecated('my_new_slow_method', lambda *a, **k: (0, ''))
def my_old_slow_method(self, arg):
args = compatibility_shim(arg)
return self.my_new_slow_method(*args)
To make `adapter.my_old_method` available but also print out a warning on
use directing users to `my_new_method`.
The optional parse_replacement, if provided, will provide a parse-time
replacement for the actual method (see `available_parse`).
"""
def wrapper(func):
func_name = func.__name__
Expand All @@ -32,10 +48,43 @@ def wrapper(func):
def inner(*args, **kwargs):
warn('adapter:{}'.format(func_name))
return func(*args, **kwargs)

if parse_replacement:
available = available_parse(parse_replacement)
return available(inner)
return wrapper


def available_parse(parse_replacement):
"""A decorator factory to indicate that a method on the adapter will be
exposed to the database wrapper, and will be stubbed out at parse time with
the given function.
@available_parse()
def my_method(self, a, b):
if something:
return None
return big_expensive_db_query()
@available_parse(lambda *args, **args: {})
def my_other_method(self, a, b):
x = {}
x.update(big_expensive_db_query())
return x
"""
def inner(func):
func._parse_replacement_ = parse_replacement
available(func)
return func
return inner


available.deprecated = available_deprecated
available.parse = available_parse
available.parse_none = available_parse(lambda *a, **k: None)
available.parse_list = available_parse(lambda *a, **k: [])


class AdapterMeta(abc.ABCMeta):
def __new__(mcls, name, bases, namespace, **kwargs):
cls = super(AdapterMeta, mcls).__new__(mcls, name, bases, namespace,
Expand All @@ -47,15 +96,22 @@ def __new__(mcls, name, bases, namespace, **kwargs):
# injected into the arguments. All methods in here are exposed to the
# context.
available = set()
replacements = {}

# collect base class data first
for base in bases:
available.update(getattr(base, '_available_', set()))
replacements.update(getattr(base, '_parse_replacements_', set()))

# override with local data if it exists
for name, value in namespace.items():
if getattr(value, '_is_available_', False):
available.add(name)
parse_replacement = getattr(value, '_parse_replacement_', None)
if parse_replacement is not None:
replacements[name] = parse_replacement

cls._available_ = frozenset(available)
# should this be a namedtuple so it will be immutable like _available_?
cls._parse_replacements_ = replacements
return cls
2 changes: 1 addition & 1 deletion core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SQLAdapter(BaseAdapter):
- list_relations_without_caching
- get_columns_in_relation
"""
@available
@available.parse(lambda *a, **k: (None, None))
def add_query(self, sql, auto_begin=True, bindings=None,
abridge_sql_log=False):
"""Add a query to the current transaction. A thin wrapper around
Expand Down
13 changes: 3 additions & 10 deletions core/dbt/context/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def create(self, *args, **kwargs):
return self.relation_type.create(*args, **kwargs)


class DatabaseWrapper(object):
class BaseDatabaseWrapper(object):
"""
Wrapper for runtime database interaction. Applies the runtime quote policy
via a relation proxy.
Expand All @@ -57,14 +57,7 @@ def __init__(self, adapter):
self.Relation = RelationProxy(adapter)

def __getattr__(self, name):
if name in self.adapter._available_:
return getattr(self.adapter, name)
else:
raise AttributeError(
"'{}' object has no attribute '{}'".format(
self.__class__.__name__, name
)
)
raise NotImplementedError('subclasses need to implement this')

@property
def config(self):
Expand Down Expand Up @@ -371,7 +364,7 @@ def generate_base(model, model_dict, config, manifest, source_config,
pre_hooks = None
post_hooks = None

db_wrapper = DatabaseWrapper(adapter)
db_wrapper = provider.DatabaseWrapper(adapter)

context = dbt.utils.merge(context, {
"adapter": db_wrapper,
Expand Down
21 changes: 21 additions & 0 deletions core/dbt/context/parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dbt.exceptions

import dbt.context.common
from dbt.clients.agate_helper import empty_table
from dbt.adapters.factory import get_adapter


Expand Down Expand Up @@ -97,6 +98,26 @@ def get(self, name, validator=None, default=None):
return ''


class DatabaseWrapper(dbt.context.common.BaseDatabaseWrapper):
"""The parser subclass of the database wrapper applies any explicit
parse-time overrides.
"""
def __getattr__(self, name):
override = (name in self.adapter._available_ and
name in self.adapter._parse_replacements_)

if override:
return self.adapter._parse_replacements_[name]
elif name in self.adapter._available_:
return getattr(self.adapter, name)
else:
raise AttributeError(
"'{}' object has no attribute '{}'".format(
self.__class__.__name__, name
)
)


def generate(model, runtime_config, manifest, source_config):
# during parsing, we don't have a connection, but we might need one, so we
# have to acquire it.
Expand Down
15 changes: 15 additions & 0 deletions core/dbt/context/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ def get(self, name, validator=None, default=None):
return to_return


class DatabaseWrapper(dbt.context.common.BaseDatabaseWrapper):
"""The runtime database wrapper exposes everything the adapter marks
available.
"""
def __getattr__(self, name):
if name in self.adapter._available_:
return getattr(self.adapter, name)
else:
raise AttributeError(
"'{}' object has no attribute '{}'".format(
self.__class__.__name__, name
)
)


def generate(model, runtime_config, manifest):
return dbt.context.common.generate(
model, runtime_config, manifest, None, dbt.context.runtime)
Expand Down
Loading

0 comments on commit 25189ed

Please sign in to comment.