Skip to content

Commit

Permalink
Merge pull request #1429 from fishtown-analytics/fix/vars-in-disabled…
Browse files Browse the repository at this point in the history
…-models

Fix: missing vars in disabled models fail compilation (#434)
  • Loading branch information
beckjake authored Apr 30, 2019
2 parents aa4f771 + d57f4c5 commit ad2f228
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 18 deletions.
41 changes: 23 additions & 18 deletions core/dbt/context/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,30 +239,33 @@ def __init__(self, model, context, overrides):
def pretty_dict(self, data):
return json.dumps(data, sort_keys=True, indent=4)

def get_missing_var(self, var_name):
pretty_vars = self.pretty_dict(self.local_vars)
msg = self.UndefinedVarError.format(
var_name, self.model_name, pretty_vars
)
dbt.exceptions.raise_compiler_error(msg, self.model)

def assert_var_defined(self, var_name, default):
if var_name not in self.local_vars and default is self._VAR_NOTSET:
pretty_vars = self.pretty_dict(self.local_vars)
dbt.exceptions.raise_compiler_error(
self.UndefinedVarError.format(
var_name, self.model_name, pretty_vars
),
self.model
)

def __call__(self, var_name, default=_VAR_NOTSET):
self.assert_var_defined(var_name, default)

if var_name not in self.local_vars:
return default
return self.get_missing_var(var_name)

def get_rendered_var(self, var_name):
raw = self.local_vars[var_name]

# if bool/int/float/etc are passed in, don't compile anything
if not isinstance(raw, basestring):
return raw

return dbt.clients.jinja.get_rendered(raw, self.context)

def __call__(self, var_name, default=_VAR_NOTSET):
if var_name in self.local_vars:
return self.get_rendered_var(var_name)
elif default is not self._VAR_NOTSET:
return default
else:
return self.get_missing_var(var_name)


def write(node, target_path, subdirectory):
def fn(payload):
Expand Down Expand Up @@ -395,7 +398,8 @@ def generate_base(model, model_dict, config, manifest, source_config,
return context


def modify_generated_context(context, model, model_dict, config, manifest):
def modify_generated_context(context, model, model_dict, config, manifest,
provider):
cli_var_overrides = config.cli_vars

context = _add_tracking(context)
Expand All @@ -408,7 +412,8 @@ def modify_generated_context(context, model, model_dict, config, manifest):

context["write"] = write(model_dict, config.target_path, 'run')
context["render"] = render(context, model_dict)
context["var"] = Var(model, context=context, overrides=cli_var_overrides)
context["var"] = provider.Var(model, context=context,
overrides=cli_var_overrides)
context['context'] = context

return context
Expand All @@ -427,7 +432,7 @@ def generate_execute_macro(model, config, manifest, provider):
provider)

return modify_generated_context(context, model, model_dict, config,
manifest)
manifest, provider)


def generate_model(model, config, manifest, source_config, provider):
Expand All @@ -448,7 +453,7 @@ def generate_model(model, config, manifest, source_config, provider):
})

return modify_generated_context(context, model, model_dict, config,
manifest)
manifest, provider)


def generate(model, config, manifest, source_config=None, provider=None):
Expand Down
6 changes: 6 additions & 0 deletions core/dbt/context/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def get(self, name, validator=None, default=None):
return ''


class Var(dbt.context.common.Var):
def get_missing_var(self, var_name):
# in the parser, just always return None.
return None


def generate(model, runtime_config, manifest, source_config):
# during parsing, we don't have a connection, but we might need one, so we
# have to acquire it.
Expand Down
4 changes: 4 additions & 0 deletions core/dbt/context/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def get(self, name, validator=None, default=None):
return to_return


class Var(dbt.context.common.Var):
pass


def generate(model, runtime_config, manifest):
return dbt.context.common.generate(
model, runtime_config, manifest, None, dbt.context.runtime)
Expand Down
20 changes: 20 additions & 0 deletions test/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

from dbt.contracts.graph.parsed import ParsedNode
from dbt.context.common import Var
from dbt.context.parser import Var as ParserVar
import dbt.exceptions


class TestVar(unittest.TestCase):
def setUp(self):
self.model = ParsedNode(
Expand Down Expand Up @@ -59,3 +61,21 @@ def test_var_not_defined(self):
self.assertEqual(var('foo', 'bar'), 'bar')
with self.assertRaises(dbt.exceptions.CompilationException):
var('foo')

def test_parser_var_default_something(self):
var = ParserVar(self.model, self.context, overrides={'foo': 'baz'})
self.assertEqual(var('foo'), 'baz')
self.assertEqual(var('foo', 'bar'), 'baz')

def test_parser_var_default_none(self):
var = ParserVar(self.model, self.context, overrides={'foo': None})
self.assertEqual(var('foo'), None)
self.assertEqual(var('foo', 'bar'), None)

def test_parser_var_not_defined(self):
# at parse-time, we should not raise if we encounter a missing var
# that way disabled models don't get parse errors
var = ParserVar(self.model, self.context, overrides={})

self.assertEqual(var('foo', 'bar'), 'bar')
self.assertEqual(var('foo'), None)

0 comments on commit ad2f228

Please sign in to comment.