Commit

Merge pull request #1878 from fishtown-analytics/feature/run-operation-snapshot-rpc

add run-operation + snapshot to the RPC server (#1875)
beckjake authored Nov 1, 2019
2 parents e022e73 + 658be46 commit f985902
Showing 10 changed files with 442 additions and 116 deletions.
48 changes: 45 additions & 3 deletions core/dbt/contracts/rpc.py
@@ -45,6 +45,12 @@ class RPCCompileParameters(RPCParameters):
exclude: Union[None, str, List[str]] = None


@dataclass
class RPCSnapshotParameters(RPCParameters):
select: Union[None, str, List[str]] = None
exclude: Union[None, str, List[str]] = None


@dataclass
class RPCTestParameters(RPCCompileParameters):
data: bool = False
@@ -116,9 +122,15 @@ class GCParameters(RPCParameters):
will be applied to the task manager before GC starts. By default the
existing gc settings remain.
"""
task_ids: Optional[List[TaskID]]
before: Optional[datetime]
settings: Optional[GCSettings]
task_ids: Optional[List[TaskID]] = None
before: Optional[datetime] = None
settings: Optional[GCSettings] = None


@dataclass
class RPCRunOperationParameters(RPCParameters):
macro: str
args: Dict[str, Any] = field(default_factory=dict)


# Outputs
@@ -161,6 +173,11 @@ class ResultTable(JsonSchemaMixin):
rows: List[Any]


@dataclass
class RemoteRunOperationResult(RemoteResult):
success: bool


@dataclass
class RemoteRunResult(RemoteCompileResult):
table: ResultTable
@@ -431,6 +448,31 @@ def from_result(
)


@dataclass
class PollRunOperationCompleteResult(RemoteRunOperationResult, PollResult):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
)

@classmethod
def from_result(
cls: Type['PollRunOperationCompleteResult'],
base: RemoteRunOperationResult,
tags: TaskTags,
timing: TaskTiming,
) -> 'PollRunOperationCompleteResult':
return cls(
success=base.success,
logs=base.logs,
tags=tags,
state=timing.state,
start=timing.start,
end=timing.end,
elapsed=timing.elapsed,
)


@dataclass
class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):
state: TaskHandlerState = field(
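
The two new parameter contracts map one-to-one onto JSON-RPC request params. A rough client-side sketch (not part of this commit), assuming a dbt RPC server on the default localhost:8580 /jsonrpc endpoint; the port, the endpoint path, and the grant_select macro name are illustrative assumptions:

import uuid
import requests


def rpc_request(method, **params):
    # JSON-RPC 2.0 envelope; the params mirror the RPCRunOperationParameters /
    # RPCSnapshotParameters dataclasses added above.
    payload = {
        'jsonrpc': '2.0',
        'method': method,
        'id': str(uuid.uuid4()),
        'params': params,
    }
    return requests.post('http://localhost:8580/jsonrpc', json=payload).json()


# run-operation: a macro name plus an optional dict of macro kwargs
run_op = rpc_request('run-operation', macro='grant_select', args={'role': 'reporter'})

# snapshot: the usual node-selection arguments (select / exclude)
snap = rpc_request('snapshot', select='my_snapshot')
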
4 changes: 2 additions & 2 deletions core/dbt/main.py
@@ -386,7 +386,7 @@ def _build_snapshot_subparser(subparsers, base_subparser):
'''
)
sub.set_defaults(cls=snapshot_task.SnapshotTask, which='snapshot',
rpc_method=None)
rpc_method='snapshot')
return sub


@@ -707,7 +707,7 @@ def _build_run_operation_subparser(subparsers, base_subparser):
'''
)
sub.set_defaults(cls=run_operation_task.RunOperationTask,
which='run-operation', rpc_method=None)
which='run-operation', rpc_method='run-operation')
return sub


5 changes: 5 additions & 0 deletions core/dbt/rpc/builtins.py
@@ -21,6 +21,7 @@
RemoteCompileResult,
RemoteCatalogResults,
RemoteEmptyResult,
RemoteRunOperationResult,
PollParameters,
PollResult,
PollInProgressResult,
@@ -30,6 +31,7 @@
PollCompileCompleteResult,
PollCatalogCompleteResult,
PollRemoteEmptyCompleteResult,
PollRunOperationCompleteResult,
TaskHandlerState,
TaskTiming,
)
@@ -141,6 +143,7 @@ def poll_complete(
PollCompileCompleteResult,
PollCatalogCompleteResult,
PollRemoteEmptyCompleteResult,
PollRunOperationCompleteResult,
]]

if isinstance(result, RemoteExecutionResult):
@@ -154,6 +157,8 @@
cls = PollCatalogCompleteResult
elif isinstance(result, RemoteEmptyResult):
cls = PollRemoteEmptyCompleteResult
elif isinstance(result, RemoteRunOperationResult):
cls = PollRunOperationCompleteResult
else:
raise dbt.exceptions.InternalException(
'got invalid result in poll_complete: {}'.format(result)
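
With this dispatch, a completed run-operation task polls back as a PollRunOperationCompleteResult. A hedged sketch of the client-side polling loop this serves; the poll method name, its parameter names, and the lowercase state strings are assumptions read off the contracts in this diff, not verified against the server:

import time
import requests


def poll_until_complete(request_token, url='http://localhost:8580/jsonrpc', interval=1.0):
    # Repeatedly call the built-in poll method until the task leaves its
    # in-progress states; for run-operation the completed payload is shaped
    # like PollRunOperationCompleteResult (state, timing, logs, success).
    payload = {
        'jsonrpc': '2.0',
        'method': 'poll',
        'id': 'poll-example',
        'params': {'request_token': request_token, 'logs': False},
    }
    while True:
        result = requests.post(url, json=payload).json().get('result', {})
        if result.get('state') in ('success', 'failed', 'error'):
            return result
        time.sleep(interval)
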
43 changes: 43 additions & 0 deletions core/dbt/task/rpc/project_commands.py
@@ -5,18 +5,23 @@
from dbt.contracts.rpc import (
RPCCompileParameters,
RPCDocsGenerateParameters,
RPCRunOperationParameters,
RPCSeedParameters,
RPCTestParameters,
RemoteCatalogResults,
RemoteExecutionResult,
RemoteRunOperationResult,
RPCSnapshotParameters,
)
from dbt.rpc.method import (
Parameters,
)
from dbt.task.compile import CompileTask
from dbt.task.generate import GenerateTask
from dbt.task.run import RunTask
from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask

from .base import RPCTask
@@ -97,3 +102,41 @@ def get_catalog_results(
_compile_results=compile_results,
logs=[],
)


class RemoteRunOperationTask(
RPCTask[RPCRunOperationParameters],
HasCLI[RPCRunOperationParameters, RemoteRunOperationResult],
RunOperationTask,
):
METHOD_NAME = 'run-operation'

def set_args(self, params: RPCRunOperationParameters) -> None:
self.args.macro = params.macro
self.args.args = params.args

def _get_kwargs(self):
if isinstance(self.args.args, dict):
return self.args.args
else:
return RunOperationTask._get_kwargs(self)

def _runtime_initialize(self):
return RunOperationTask._runtime_initialize(self)

def handle_request(self) -> RemoteRunOperationResult:
success, _ = RunOperationTask.run(self)
result = RemoteRunOperationResult(logs=[], success=success)
return result

def interpret_results(self, results):
return results.success


class RemoteSnapshotTask(RPCCommandTask[RPCSnapshotParameters], SnapshotTask):
METHOD_NAME = 'snapshot'

def set_args(self, params: RPCSnapshotParameters) -> None:
# select has an argparse `dest` value of `models`.
self.args.models = self._listify(params.select)
self.args.exclude = self._listify(params.exclude)
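
RemoteRunOperationTask overrides _get_kwargs because the macro arguments arrive in different shapes: over the CLI, --args is a YAML string that still needs parsing, while over RPC the args parameter is already a dict. A rough standalone illustration of that branch, not the dbt implementation itself:

import yaml


def get_macro_kwargs(args):
    # RPC path: the contract already delivered a structured dict
    if isinstance(args, dict):
        return args
    # CLI path: --args arrives as a YAML (or JSON) string and must be parsed
    return yaml.safe_load(args) or {}


get_macro_kwargs({'role': 'reporter'})   # -> {'role': 'reporter'}
get_macro_kwargs('{role: reporter}')     # -> {'role': 'reporter'}
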
4 changes: 4 additions & 0 deletions core/dbt/utils.py
@@ -455,6 +455,10 @@ def default(self, obj):
return float(obj)
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
return obj.isoformat()
if hasattr(obj, 'to_dict'):
# if we have a to_dict we should try to serialize the result of
# that!
obj = obj.to_dict()
return super().default(obj)


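
The encoder change lets result objects that expose to_dict() (the JsonSchemaMixin dataclasses above provide one) pass through JSON serialization without registering each type individually. A minimal standalone sketch of that fallback pattern, returning the converted value so the encoder can re-encode it; the class names here are made up:

import json
import datetime


class ToDictJSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
            return obj.isoformat()
        if hasattr(obj, 'to_dict'):
            # anything exposing to_dict() collapses to plain dicts/lists,
            # which json then encodes recursively
            return obj.to_dict()
        return super().default(obj)


class ExampleResult:
    def to_dict(self):
        return {'success': True, 'logs': []}


json.dumps(ExampleResult(), cls=ToDictJSONEncoder)   # -> '{"success": true, "logs": []}'
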
3 changes: 3 additions & 0 deletions plugins/snowflake/setup.py
@@ -31,6 +31,9 @@
install_requires=[
'dbt-core=={}'.format(package_version),
'snowflake-connector-python>=1.6.12,<2.1',
'azure-storage-blob~=2.1',
'azure-storage-common~=2.1',

],
zip_safe=False,
)
@@ -2,6 +2,7 @@
import pickle
import os


class TestRpcExecuteReturnsResults(DBTIntegrationTest):

@property
@@ -26,7 +27,7 @@ def test_pickle(self, agate_table):

pickle.dumps(table)

def test_file(self, filename):
def do_test_file(self, filename):
file_path = os.path.join("sql", filename)
with open(file_path) as fh:
query = fh.read()
@@ -39,12 +40,12 @@

@use_profile('bigquery')
def test__bigquery_fetch_and_serialize(self):
self.test_file('bigquery.sql')
self.do_test_file('bigquery.sql')

@use_profile('snowflake')
def test__snowflake_fetch_and_serialize(self):
self.test_file('snowflake.sql')
self.do_test_file('snowflake.sql')

@use_profile('redshift')
def test__redshift_fetch_and_serialize(self):
self.test_file('redshift.sql')
self.do_test_file('redshift.sql')
54 changes: 54 additions & 0 deletions test/rpc/conftest.py
@@ -0,0 +1,54 @@
import os
import pytest
import random
import time
from typing import Dict, Any

import yaml


@pytest.fixture
def unique_schema() -> str:
return "test{}{:04}".format(int(time.time()), random.randint(0, 9999))


@pytest.fixture
def profiles_root(tmpdir):
return tmpdir.mkdir('profile')


@pytest.fixture
def project_root(tmpdir):
return tmpdir.mkdir('project')


@pytest.fixture
def postgres_profile_data(unique_schema):
return {
'config': {
'send_anonymous_usage_stats': False
},
'test': {
'outputs': {
'default': {
'type': 'postgres',
'threads': 4,
'host': 'database',
'port': 5432,
'user': 'root',
'pass': 'password',
'dbname': 'dbt',
'schema': unique_schema,
},
},
'target': 'default'
}
}


@pytest.fixture
def postgres_profile(profiles_root, postgres_profile_data) -> Dict[str, Any]:
path = os.path.join(profiles_root, 'profiles.yml')
with open(path, 'w') as fp:
fp.write(yaml.safe_dump(postgres_profile_data))
return postgres_profile_data
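
A short sketch of how a test under test/rpc/ might lean on these fixtures; the test body is illustrative only and not part of this commit:

import os


def test_postgres_profile_written(postgres_profile, profiles_root, unique_schema):
    # the postgres_profile fixture has already written profiles.yml into the
    # temporary profile directory and returned the data it wrote
    assert os.path.exists(os.path.join(profiles_root, 'profiles.yml'))
    target = postgres_profile['test']['outputs']['default']
    assert target['schema'] == unique_schema
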