Skip to content

Commit

Permalink
Merge pull request #1373 from fishtown-analytics/fix/rpc-ephemeral-clean
Browse files Browse the repository at this point in the history
Support ephemeral nodes (#1368)
  • Loading branch information
beckjake authored Mar 27, 2019
2 parents fcb97bf + 9373a45 commit bea2d4f
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 19 deletions.
13 changes: 6 additions & 7 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,20 +482,18 @@ def __init__(self, meta, sources):
class RemoteCompileResult(APIObject):
SCHEMA = REMOTE_COMPILE_RESULT_CONTRACT

def __init__(self, raw_sql, compiled_sql, timing=None, **kwargs):
def __init__(self, raw_sql, compiled_sql, node, timing=None, **kwargs):
if timing is None:
timing = []
# this should not show up in the serialized output.
self.node = node
super(RemoteCompileResult, self).__init__(
raw_sql=raw_sql,
compiled_sql=compiled_sql,
timing=timing,
**kwargs
)

@property
def node(self):
return None

@property
def error(self):
return None
Expand Down Expand Up @@ -525,12 +523,13 @@ def error(self):
class RemoteRunResult(RemoteCompileResult):
SCHEMA = REMOTE_RUN_RESULT_CONTRACT

def __init__(self, raw_sql, compiled_sql, timing=None, table=None):
def __init__(self, raw_sql, compiled_sql, node, timing=None, table=None):
if table is None:
table = []
super(RemoteRunResult, self).__init__(
raw_sql=raw_sql,
compiled_sql=compiled_sql,
timing=timing,
table=table
table=table,
node=node
)
6 changes: 5 additions & 1 deletion core/dbt/node_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,8 @@ def compile(self, manifest):
def execute(self, compiled_node, manifest):
return RemoteCompileResult(
raw_sql=compiled_node.raw_sql,
compiled_sql=compiled_node.injected_sql
compiled_sql=compiled_node.injected_sql,
node=compiled_node
)

def error_result(self, node, error, start_time, timing_info):
Expand All @@ -561,6 +562,7 @@ def from_run_result(self, result, start_time, timing_info):
return RemoteCompileResult(
raw_sql=result.raw_sql,
compiled_sql=result.compiled_sql,
node=result.node,
timing=timing
)

Expand All @@ -571,6 +573,7 @@ def from_run_result(self, result, start_time, timing_info):
return RemoteRunResult(
raw_sql=result.raw_sql,
compiled_sql=result.compiled_sql,
node=result.node,
table=result.table,
timing=timing
)
Expand All @@ -586,5 +589,6 @@ def execute(self, compiled_node, manifest):
return RemoteRunResult(
raw_sql=compiled_node.raw_sql,
compiled_sql=compiled_node.injected_sql,
node=compiled_node,
table=table
)
19 changes: 9 additions & 10 deletions core/dbt/task/compile.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os

from dbt.adapters.factory import get_adapter
from dbt.clients.jinja import extract_toplevel_blocks
from dbt.compilation import compile_manifest
from dbt.loader import load_all_projects, GraphLoader
from dbt.loader import load_all_projects
from dbt.node_runners import CompileRunner, RPCCompileRunner
from dbt.node_types import NodeType
from dbt.parser.analysis import RPCCallParser
Expand Down Expand Up @@ -37,12 +36,9 @@ def task_end_messages(self, results):
class RemoteCompileTask(CompileTask, RemoteCallable):
METHOD_NAME = 'compile'

def __init__(self, args, config):
def __init__(self, args, config, manifest):
super(RemoteCompileTask, self).__init__(args, config)
self._base_manifest = GraphLoader.load_all(
config,
internal_manifest=get_adapter(config).check_internal_manifest()
)
self._base_manifest = manifest

def get_runner_type(self):
return RPCCompileRunner
Expand Down Expand Up @@ -70,7 +66,7 @@ def _extract_request_data(self, data):
sql = ''.join(data_chunks)
return sql, macros

def handle_request(self, name, sql):
def _get_exec_node(self, name, sql, macros):
request_path = os.path.join(self.config.target_path, 'rpc', name)
all_projects = load_all_projects(self.config)
macro_overrides = {}
Expand Down Expand Up @@ -103,7 +99,6 @@ def handle_request(self, name, sql):
}

unique_id, node = rpc_parser.parse_sql_node(node_dict)

self.manifest = ParserUtils.add_new_refs(
manifest=self._base_manifest,
current_project=self.config,
Expand All @@ -113,11 +108,15 @@ def handle_request(self, name, sql):

# don't write our new, weird manifest!
self.linker = compile_manifest(self.config, self.manifest, write=False)
return node

def handle_request(self, name, sql, macros=None):
node = self._get_exec_node(name, sql, macros)

selected_uids = [node.unique_id]
self.runtime_cleanup(selected_uids)
self.job_queue = self.linker.as_graph_queue(self.manifest,
selected_uids)

result = self.get_runner(node).safe_run(self.manifest)

return result.serialize()
2 changes: 1 addition & 1 deletion core/dbt/task/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, args, config, tasks=None):
self.dispatcher = Dispatcher()
tasks = tasks or [RemoteCompileTask, RemoteRunTask]
for cls in tasks:
self.register(cls(args, config))
self.register(cls(args, config, self.manifest))

def register(self, task):
self.dispatcher.add_method(RequestTaskHandler.factory(task),
Expand Down
3 changes: 3 additions & 0 deletions test/integration/042_sources_test/models/ephemeral_model.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{{ config(materialized='ephemeral') }}

select 1 as id
30 changes: 30 additions & 0 deletions test/integration/042_sources_test/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,13 @@ def start(self):
raise Exception('server never appeared!')


_select_from_ephemeral = '''with __dbt__CTE__ephemeral_model as (
select 1 as id
)select * from __dbt__CTE__ephemeral_model'''


@unittest.skipIf(os.name == 'nt', 'Windows not supported for now')
class TestRPCServer(BaseSourcesTest):
def setUp(self):
Expand Down Expand Up @@ -481,6 +488,17 @@ def test_compile(self):
compiled_sql='select 2 as id'
)

ephemeral = self.query(
'compile',
'select * from {{ ref("ephemeral_model") }}',
name='foo'
).json()
self.assertSuccessfulCompilationResult(
ephemeral,
'select * from {{ ref("ephemeral_model") }}',
compiled_sql=_select_from_ephemeral
)

@use_profile('postgres')
def test_run(self):
# seed + run dbt to make models before using them!
Expand Down Expand Up @@ -592,6 +610,18 @@ def test_run(self):
table={'column_names': ['id'], 'rows': [[1.0]]}
)

ephemeral = self.query(
'run',
'select * from {{ ref("ephemeral_model") }}',
name='foo'
).json()
self.assertSuccessfulRunResult(
ephemeral,
raw_sql='select * from {{ ref("ephemeral_model") }}',
compiled_sql=_select_from_ephemeral,
table={'column_names': ['id'], 'rows': [[1.0]]}
)

@use_profile('postgres')
def test_invalid_requests(self):
data = self.query(
Expand Down

0 comments on commit bea2d4f

Please sign in to comment.