Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ephemeral nodes (#1368) #1373

Merged
merged 3 commits into from
Mar 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,20 +482,18 @@ def __init__(self, meta, sources):
class RemoteCompileResult(APIObject):
SCHEMA = REMOTE_COMPILE_RESULT_CONTRACT

def __init__(self, raw_sql, compiled_sql, timing=None, **kwargs):
def __init__(self, raw_sql, compiled_sql, node, timing=None, **kwargs):
if timing is None:
timing = []
# this should not show up in the serialized output.
self.node = node
super(RemoteCompileResult, self).__init__(
raw_sql=raw_sql,
compiled_sql=compiled_sql,
timing=timing,
**kwargs
)

@property
def node(self):
return None

@property
def error(self):
return None
Expand Down Expand Up @@ -525,12 +523,13 @@ def error(self):
class RemoteRunResult(RemoteCompileResult):
SCHEMA = REMOTE_RUN_RESULT_CONTRACT

def __init__(self, raw_sql, compiled_sql, timing=None, table=None):
def __init__(self, raw_sql, compiled_sql, node, timing=None, table=None):
if table is None:
table = []
super(RemoteRunResult, self).__init__(
raw_sql=raw_sql,
compiled_sql=compiled_sql,
timing=timing,
table=table
table=table,
node=node
)
6 changes: 5 additions & 1 deletion core/dbt/node_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,8 @@ def compile(self, manifest):
def execute(self, compiled_node, manifest):
return RemoteCompileResult(
raw_sql=compiled_node.raw_sql,
compiled_sql=compiled_node.injected_sql
compiled_sql=compiled_node.injected_sql,
node=compiled_node
)

def error_result(self, node, error, start_time, timing_info):
Expand All @@ -561,6 +562,7 @@ def from_run_result(self, result, start_time, timing_info):
return RemoteCompileResult(
raw_sql=result.raw_sql,
compiled_sql=result.compiled_sql,
node=result.node,
timing=timing
)

Expand All @@ -571,6 +573,7 @@ def from_run_result(self, result, start_time, timing_info):
return RemoteRunResult(
raw_sql=result.raw_sql,
compiled_sql=result.compiled_sql,
node=result.node,
table=result.table,
timing=timing
)
Expand All @@ -586,5 +589,6 @@ def execute(self, compiled_node, manifest):
return RemoteRunResult(
raw_sql=compiled_node.raw_sql,
compiled_sql=compiled_node.injected_sql,
node=compiled_node,
table=table
)
19 changes: 9 additions & 10 deletions core/dbt/task/compile.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os

from dbt.adapters.factory import get_adapter
from dbt.clients.jinja import extract_toplevel_blocks
from dbt.compilation import compile_manifest
from dbt.loader import load_all_projects, GraphLoader
from dbt.loader import load_all_projects
from dbt.node_runners import CompileRunner, RPCCompileRunner
from dbt.node_types import NodeType
from dbt.parser.analysis import RPCCallParser
Expand Down Expand Up @@ -37,12 +36,9 @@ def task_end_messages(self, results):
class RemoteCompileTask(CompileTask, RemoteCallable):
METHOD_NAME = 'compile'

def __init__(self, args, config):
def __init__(self, args, config, manifest):
super(RemoteCompileTask, self).__init__(args, config)
self._base_manifest = GraphLoader.load_all(
config,
internal_manifest=get_adapter(config).check_internal_manifest()
)
self._base_manifest = manifest

def get_runner_type(self):
return RPCCompileRunner
Expand Down Expand Up @@ -70,7 +66,7 @@ def _extract_request_data(self, data):
sql = ''.join(data_chunks)
return sql, macros

def handle_request(self, name, sql):
def _get_exec_node(self, name, sql, macros):
request_path = os.path.join(self.config.target_path, 'rpc', name)
all_projects = load_all_projects(self.config)
macro_overrides = {}
Expand Down Expand Up @@ -103,7 +99,6 @@ def handle_request(self, name, sql):
}

unique_id, node = rpc_parser.parse_sql_node(node_dict)

self.manifest = ParserUtils.add_new_refs(
manifest=self._base_manifest,
current_project=self.config,
Expand All @@ -113,11 +108,15 @@ def handle_request(self, name, sql):

# don't write our new, weird manifest!
self.linker = compile_manifest(self.config, self.manifest, write=False)
return node

def handle_request(self, name, sql, macros=None):
node = self._get_exec_node(name, sql, macros)

selected_uids = [node.unique_id]
self.runtime_cleanup(selected_uids)
self.job_queue = self.linker.as_graph_queue(self.manifest,
selected_uids)

result = self.get_runner(node).safe_run(self.manifest)

return result.serialize()
2 changes: 1 addition & 1 deletion core/dbt/task/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, args, config, tasks=None):
self.dispatcher = Dispatcher()
tasks = tasks or [RemoteCompileTask, RemoteRunTask]
for cls in tasks:
self.register(cls(args, config))
self.register(cls(args, config, self.manifest))

def register(self, task):
self.dispatcher.add_method(RequestTaskHandler.factory(task),
Expand Down
3 changes: 3 additions & 0 deletions test/integration/042_sources_test/models/ephemeral_model.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{{ config(materialized='ephemeral') }}

select 1 as id
30 changes: 30 additions & 0 deletions test/integration/042_sources_test/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,13 @@ def start(self):
raise Exception('server never appeared!')


_select_from_ephemeral = '''with __dbt__CTE__ephemeral_model as (


select 1 as id
)select * from __dbt__CTE__ephemeral_model'''


@unittest.skipIf(os.name == 'nt', 'Windows not supported for now')
class TestRPCServer(BaseSourcesTest):
def setUp(self):
Expand Down Expand Up @@ -481,6 +488,17 @@ def test_compile(self):
compiled_sql='select 2 as id'
)

ephemeral = self.query(
'compile',
'select * from {{ ref("ephemeral_model") }}',
name='foo'
).json()
self.assertSuccessfulCompilationResult(
ephemeral,
'select * from {{ ref("ephemeral_model") }}',
compiled_sql=_select_from_ephemeral
)

@use_profile('postgres')
def test_run(self):
# seed + run dbt to make models before using them!
Expand Down Expand Up @@ -592,6 +610,18 @@ def test_run(self):
table={'column_names': ['id'], 'rows': [[1.0]]}
)

ephemeral = self.query(
'run',
'select * from {{ ref("ephemeral_model") }}',
name='foo'
).json()
self.assertSuccessfulRunResult(
ephemeral,
raw_sql='select * from {{ ref("ephemeral_model") }}',
compiled_sql=_select_from_ephemeral,
table={'column_names': ['id'], 'rows': [[1.0]]}
)

@use_profile('postgres')
def test_invalid_requests(self):
data = self.query(
Expand Down