diff --git a/core/dbt/task/show.py b/core/dbt/task/show.py index 1b84b85e093..19681b3a0c3 100644 --- a/core/dbt/task/show.py +++ b/core/dbt/task/show.py @@ -2,12 +2,14 @@ import threading import time +from dbt.contracts.graph.nodes import SeedNode from dbt.contracts.results import RunResult, RunStatus from dbt.events.base_types import EventLevel from dbt.events.functions import fire_event from dbt.events.types import ShowNode, Note from dbt.exceptions import DbtRuntimeError from dbt.task.compile import CompileTask, CompileRunner +from dbt.task.seed import SeedRunner class ShowRunner(CompileRunner): @@ -17,6 +19,12 @@ def __init__(self, config, adapter, node, node_index, num_nodes): def execute(self, compiled_node, manifest): start_time = time.time() + + if "sql_header" in compiled_node.unrendered_config: + compiled_node.compiled_code = ( + compiled_node.unrendered_config["sql_header"] + compiled_node.compiled_code + ) + adapter_response, execute_result = self.adapter.execute( compiled_node.compiled_code, fetch=True ) @@ -41,8 +49,11 @@ def _runtime_initialize(self): raise DbtRuntimeError("Either --select or --inline must be passed to show") super()._runtime_initialize() - def get_runner_type(self, _): - return ShowRunner + def get_runner_type(self, node): + if isinstance(node, SeedNode): + return SeedRunner + else: + return ShowRunner def task_end_messages(self, results): is_inline = bool(getattr(self.args, "inline", None)) diff --git a/tests/functional/show/fixtures.py b/tests/functional/show/fixtures.py index 48248195b72..d3d8c57e96c 100644 --- a/tests/functional/show/fixtures.py +++ b/tests/functional/show/fixtures.py @@ -10,6 +10,14 @@ from {{ ref('sample_model') }} """ +models__sql_header = """ +{% call set_sql_header(config) %} +set session time zone 'Asia/Kolkata'; +{%- endcall %} +select current_setting('timezone') as timezone +""" + + schema_yml = """ models: - name: sample_model diff --git a/tests/functional/show/test_show.py b/tests/functional/show/test_show.py index aa5dbb6025c..d2033f9b82a 100644 --- a/tests/functional/show/test_show.py +++ b/tests/functional/show/test_show.py @@ -9,6 +9,7 @@ models__second_model, models__ephemeral_model, schema_yml, + models__sql_header, ) @@ -19,6 +20,7 @@ def models(self): "sample_model.sql": models__sample_model, "second_model.sql": models__second_model, "ephemeral_model.sql": models__ephemeral_model, + "sql_header.sql": models__sql_header, } @pytest.fixture(scope="class") @@ -87,6 +89,15 @@ def test_second_ephemeral_model(self, project): ) assert "col_hundo" in log_output + def test_seed(self, project): + (results, log_output) = run_dbt_and_capture(["show", "--select", "sample_seed"]) + assert "Previewing node 'sample_seed'" in log_output + + def test_sql_header(self, project): + run_dbt(["build"]) + (results, log_output) = run_dbt_and_capture(["show", "--select", "sql_header"]) + assert "Asia/Kolkata" in log_output + class TestShowModelVersions: @pytest.fixture(scope="class")