Only create schemas for selected nodes (#1239) #1258

Merged
merged 3 commits into from
Jan 28, 2019
2 changes: 1 addition & 1 deletion core/dbt/adapters/sql/impl.py
@@ -241,4 +241,4 @@ def check_schema_exists(self, database, schema, model_name=None):
kwargs={'database': database, 'schema': schema},
connection_name=model_name
)
return results[0] > 0
return results[0][0] > 0
Contributor:
what happened here?

Contributor Author:
the existing check_schema_exists was wrong and did not work.
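For readers following along, a minimal sketch of why the old line never worked. This is illustrative Python only, not dbt's adapter code; it assumes the fetched result behaves like a list of rows (one row, one column for a count query):

# Illustrative only: the count query comes back as rows, e.g. one row, one column.
results = [(1,)]

# Old code: `results[0] > 0` compared the entire first *row* to an integer,
# which is not a meaningful test (on Python 3 it raises TypeError), so the
# macro's answer was never actually inspected.

# Fix: index into the row to pull out the scalar count before comparing.
schema_exists = results[0][0] > 0
print(schema_exists)  # True -> the schema was found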

2 changes: 1 addition & 1 deletion core/dbt/include/global_project/macros/adapters/common.sql
@@ -200,7 +200,7 @@


{% macro check_schema_exists(database, schema) -%}
{{ return(adapter_macro('check_schema_exists', database)) }}
{{ return(adapter_macro('check_schema_exists', database, schema)) }}
{% endmacro %}

{% macro default__check_schema_exists(database, schema) -%}
14 changes: 8 additions & 6 deletions core/dbt/node_runners.py
@@ -205,9 +205,11 @@ def do_skip(self, cause=None):
self.skip_cause = cause

@classmethod
def get_model_schemas(cls, manifest):
def get_model_schemas(cls, manifest, selected_uids):
schemas = set()
for node in manifest.nodes.values():
if node.unique_id not in selected_uids:
continue
if cls.is_refable(node) and not cls.is_ephemeral(node):
Contributor:
Is this going to cause a problem with Sources? I imagine sources are refable (I don't know if you actually made that change in the corresponding PR), but we definitely wouldn't want to touch source schemas at all! Might be worth being more explicit here? I'd only care about schemas for:

  • models
  • seeds
  • archives (to come in Wilt Chamberlain)

No change required necessarily, but wanted to surface it because I saw it :)

Contributor Author:
Sources are not refable, so no.

schemas.add((node.database, node.schema))

@@ -218,7 +220,7 @@ def before_hooks(self, config, adapter, manifest):
pass

@classmethod
def before_run(self, config, adapter, manifest):
def before_run(self, config, adapter, manifest, selected_uids):
pass

@classmethod
@@ -339,8 +341,8 @@ def safe_run_hooks(cls, config, adapter, manifest, hook_type,
raise

@classmethod
def create_schemas(cls, config, adapter, manifest):
required_schemas = cls.get_model_schemas(manifest)
def create_schemas(cls, config, adapter, manifest, selected_uids):
required_schemas = cls.get_model_schemas(manifest, selected_uids)

# Snowflake needs to issue a "use {schema}" query, where schema
# is the one defined in the profile. Create this schema if it
@@ -364,10 +366,10 @@ def populate_adapter_cache(cls, config, adapter, manifest):
adapter.set_relations_cache(manifest)

@classmethod
def before_run(cls, config, adapter, manifest):
def before_run(cls, config, adapter, manifest, selected_uids):
cls.populate_adapter_cache(config, adapter, manifest)
cls.safe_run_hooks(config, adapter, manifest, RunHookType.Start, {})
cls.create_schemas(config, adapter, manifest)
cls.create_schemas(config, adapter, manifest, selected_uids)

@classmethod
def print_results_line(cls, results, execution_time):
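Picking up the refable/ephemeral exchange above, here is a minimal, self-contained sketch of the new filtering behavior. The Node tuple and the resource-type check are simplified, hypothetical stand-ins for dbt's real manifest nodes and its is_refable/is_ephemeral checks:

from collections import namedtuple

# Simplified stand-in for a manifest node; dbt's real nodes carry much more.
Node = namedtuple('Node', ['unique_id', 'database', 'schema', 'resource_type'])

def get_model_schemas(nodes, selected_uids):
    """Collect (database, schema) pairs only for selected, refable, non-ephemeral nodes."""
    schemas = set()
    for node in nodes:
        if node.unique_id not in selected_uids:
            continue  # skip nodes that were not selected for this run
        if node.resource_type in ('model', 'seed'):  # stand-in for the refable/ephemeral check
            schemas.add((node.database, node.schema))
    return schemas

nodes = [
    Node('model.proj.users', 'analytics', 'dbt_alice', 'model'),
    Node('model.proj.events', 'analytics', 'dbt_alice_and_then', 'model'),
    Node('source.proj.raw_events', 'raw', 'events', 'source'),
]

# Only the selected model's schema is created; the unselected model's custom
# schema and the source's schema are left untouched.
print(get_model_schemas(nodes, {'model.proj.users'}))
# {('analytics', 'dbt_alice')}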
6 changes: 3 additions & 3 deletions core/dbt/runner.py
@@ -171,8 +171,6 @@ def execute_nodes(self):
dbt.ui.printer.print_timestamped_line(concurrency_line)
dbt.ui.printer.print_timestamped_line("")

schemas = list(self.Runner.get_model_schemas(self.manifest))

pool = ThreadPool(num_threads)
try:
self.run_queue(pool)
@@ -303,10 +301,12 @@ def run(self):
else:
logger.info("")

selected_uids = frozenset(n.unique_id for n in self._flattened_nodes)
try:
self.Runner.before_hooks(self.config, adapter, self.manifest)
started = time.time()
self.Runner.before_run(self.config, adapter, self.manifest)
self.Runner.before_run(self.config, adapter, self.manifest,
selected_uids)
res = self.execute_nodes()
self.Runner.after_run(self.config, adapter, res, self.manifest)
elapsed = time.time() - started
6 changes: 5 additions & 1 deletion plugins/postgres/dbt/include/postgres/macros/adapters.sql
@@ -71,7 +71,10 @@
{% if database -%}
{{ adapter.verify_database(database) }}
{%- endif -%}
"select distinct nspname from pg_namespace"
{% call statement('list_schemas', fetch_result=True, auto_begin=False) %}
select distinct nspname from pg_namespace
{% endcall %}
{{ return(load_result('list_schemas').table) }}
{% endmacro %}

{% macro postgres__check_schema_exists(database, schema) -%}
@@ -81,4 +84,5 @@
{% call statement('check_schema_exists', fetch_result=True, auto_begin=False) %}
select count(*) from pg_namespace where nspname = '{{ schema }}'
{% endcall %}
{{ return(load_result('check_schema_exists').table) }}
{% endmacro %}
@@ -62,7 +62,7 @@
{% macro snowflake__check_schema_exists(database, schema) -%}
{% call statement('check_schema_exists', fetch_result=True) -%}
select count(*)
from {{ information_schema_name(database) }}
from {{ information_schema_name(database) }}.schemata
where upper(schema_name) = upper('{{ schema }}')
and upper(catalog_name) = upper('{{ database }}')
{%- endcall %}
@@ -0,0 +1,5 @@
{{
config(schema='_and_then')
}}

select * from {{ this.schema }}.seed
@@ -11,6 +11,22 @@ def schema(self):
def models(self):
return "test/integration/007_graph_selection_tests/models"

def assert_correct_schemas(self):
exists = self.adapter.check_schema_exists(
self.default_database,
self.unique_schema(),
'__test'
)
self.assertTrue(exists)

schema = self.unique_schema()+'_and_then'
exists = self.adapter.check_schema_exists(
self.default_database,
schema,
'__test'
)
self.assertFalse(exists)

@attr(type='postgres')
def test__postgres__specific_model(self):
self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql")
@@ -23,6 +39,7 @@ def test__postgres__specific_model(self):
self.assertFalse('users_rollup' in created_models)
self.assertFalse('base_users' in created_models)
self.assertFalse('emails' in created_models)
self.assert_correct_schemas()

@attr(type='postgres')
def test__postgres__tags(self):
@@ -36,6 +53,7 @@ def test__postgres__tags(self):
self.assertFalse('emails' in created_models)
self.assertTrue('users' in created_models)
self.assertTrue('users_rollup' in created_models)
self.assert_correct_schemas()

@attr(type='postgres')
def test__postgres__tags_and_children(self):
@@ -49,6 +67,7 @@ def test__postgres__tags_and_children(self):
self.assertFalse('emails' in created_models)
self.assertTrue('users_rollup' in created_models)
self.assertTrue('users' in created_models)
self.assert_correct_schemas()

@attr(type='snowflake')
def test__snowflake__specific_model(self):
@@ -62,7 +81,7 @@ def test__snowflake__specific_model(self):
self.assertFalse('USERS_ROLLUP' in created_models)
self.assertFalse('BASE_USERS' in created_models)
self.assertFalse('EMAILS' in created_models)

self.assert_correct_schemas()

@attr(type='postgres')
def test__postgres__specific_model_and_children(self):
@@ -76,6 +95,7 @@ def test__postgres__specific_model_and_children(self):
created_models = self.get_models_in_schema()
self.assertFalse('base_users' in created_models)
self.assertFalse('emails' in created_models)
self.assert_correct_schemas()

@attr(type='snowflake')
def test__snowflake__specific_model_and_children(self):
@@ -105,6 +125,7 @@ def test__postgres__specific_model_and_parents(self):
created_models = self.get_models_in_schema()
self.assertFalse('base_users' in created_models)
self.assertFalse('emails' in created_models)
self.assert_correct_schemas()

@attr(type='snowflake')
def test__snowflake__specific_model_and_parents(self):
@@ -137,6 +158,7 @@ def test__postgres__specific_model_with_exclusion(self):
self.assertFalse('base_users' in created_models)
self.assertFalse('users_rollup' in created_models)
self.assertFalse('emails' in created_models)
self.assert_correct_schemas()

@attr(type='snowflake')
def test__snowflake__specific_model_with_exclusion(self):
@@ -164,3 +186,4 @@ def test__postgres__locally_qualified_name(self):
self.assertNotIn('emails', created_models)
self.assertIn('subdir', created_models)
self.assertIn('nested_users', created_models)
self.assert_correct_schemas()
@@ -25,7 +25,7 @@ def packages_config(self):
def run_schema_and_assert(self, include, exclude, expected_tests):
self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql")
self.run_dbt(["deps"])
results = self.run_dbt()
results = self.run_dbt(['run', '--exclude', 'never_selected'])
self.assertEqual(len(results), 7)

args = FakeArgs()
2 changes: 1 addition & 1 deletion test/unit/test_snowflake_adapter.py
@@ -236,7 +236,7 @@ def test_authenticator_private_key_authentication(self, mock_get_private_key):
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_databse',
client_session_keep_alive=False, database='test_database',
role=None, schema='public', user='test_user',
warehouse='test_warehouse', private_key='test_key')
])