Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add run-operation + snapshot to the RPC server (#1875) #1878

Merged
merged 3 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions core/dbt/contracts/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class RPCCompileParameters(RPCParameters):
exclude: Union[None, str, List[str]] = None


@dataclass
class RPCSnapshotParameters(RPCParameters):
select: Union[None, str, List[str]] = None
exclude: Union[None, str, List[str]] = None


@dataclass
class RPCTestParameters(RPCCompileParameters):
data: bool = False
Expand Down Expand Up @@ -116,9 +122,15 @@ class GCParameters(RPCParameters):
will be applied to the task manager before GC starts. By default the
existing gc settings remain.
"""
task_ids: Optional[List[TaskID]]
before: Optional[datetime]
settings: Optional[GCSettings]
task_ids: Optional[List[TaskID]] = None
before: Optional[datetime] = None
settings: Optional[GCSettings] = None


@dataclass
class RPCRunOperationParameters(RPCParameters):
macro: str
args: Dict[str, Any] = field(default_factory=dict)


# Outputs
Expand Down Expand Up @@ -161,6 +173,11 @@ class ResultTable(JsonSchemaMixin):
rows: List[Any]


@dataclass
class RemoteRunOperationResult(RemoteResult):
success: bool


@dataclass
class RemoteRunResult(RemoteCompileResult):
table: ResultTable
Expand Down Expand Up @@ -431,6 +448,31 @@ def from_result(
)


@dataclass
class PollRunOperationCompleteResult(RemoteRunOperationResult, PollResult):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
)

@classmethod
def from_result(
cls: Type['PollRunOperationCompleteResult'],
base: RemoteRunOperationResult,
tags: TaskTags,
timing: TaskTiming,
) -> 'PollRunOperationCompleteResult':
return cls(
success=base.success,
logs=base.logs,
tags=tags,
state=timing.state,
start=timing.start,
end=timing.end,
elapsed=timing.elapsed,
)


@dataclass
class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):
state: TaskHandlerState = field(
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def _build_snapshot_subparser(subparsers, base_subparser):
'''
)
sub.set_defaults(cls=snapshot_task.SnapshotTask, which='snapshot',
rpc_method=None)
rpc_method='snapshot')
return sub


Expand Down Expand Up @@ -707,7 +707,7 @@ def _build_run_operation_subparser(subparsers, base_subparser):
'''
)
sub.set_defaults(cls=run_operation_task.RunOperationTask,
which='run-operation', rpc_method=None)
which='run-operation', rpc_method='run-operation')
return sub


Expand Down
5 changes: 5 additions & 0 deletions core/dbt/rpc/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RemoteCompileResult,
RemoteCatalogResults,
RemoteEmptyResult,
RemoteRunOperationResult,
PollParameters,
PollResult,
PollInProgressResult,
Expand All @@ -30,6 +31,7 @@
PollCompileCompleteResult,
PollCatalogCompleteResult,
PollRemoteEmptyCompleteResult,
PollRunOperationCompleteResult,
TaskHandlerState,
TaskTiming,
)
Expand Down Expand Up @@ -141,6 +143,7 @@ def poll_complete(
PollCompileCompleteResult,
PollCatalogCompleteResult,
PollRemoteEmptyCompleteResult,
PollRunOperationCompleteResult,
]]

if isinstance(result, RemoteExecutionResult):
Expand All @@ -154,6 +157,8 @@ def poll_complete(
cls = PollCatalogCompleteResult
elif isinstance(result, RemoteEmptyResult):
cls = PollRemoteEmptyCompleteResult
elif isinstance(result, RemoteRunOperationResult):
cls = PollRunOperationCompleteResult
else:
raise dbt.exceptions.InternalException(
'got invalid result in poll_complete: {}'.format(result)
Expand Down
43 changes: 43 additions & 0 deletions core/dbt/task/rpc/project_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@
from dbt.contracts.rpc import (
RPCCompileParameters,
RPCDocsGenerateParameters,
RPCRunOperationParameters,
RPCSeedParameters,
RPCTestParameters,
RemoteCatalogResults,
RemoteExecutionResult,
RemoteRunOperationResult,
RPCSnapshotParameters,
)
from dbt.rpc.method import (
Parameters,
)
from dbt.task.compile import CompileTask
from dbt.task.generate import GenerateTask
from dbt.task.run import RunTask
from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask

from .base import RPCTask
Expand Down Expand Up @@ -97,3 +102,41 @@ def get_catalog_results(
_compile_results=compile_results,
logs=[],
)


class RemoteRunOperationTask(
RPCTask[RPCRunOperationParameters],
HasCLI[RPCRunOperationParameters, RemoteRunOperationResult],
RunOperationTask,
):
METHOD_NAME = 'run-operation'

def set_args(self, params: RPCRunOperationParameters) -> None:
self.args.macro = params.macro
self.args.args = params.args

def _get_kwargs(self):
if isinstance(self.args.args, dict):
return self.args.args
else:
return RunOperationTask._get_kwargs(self)

def _runtime_initialize(self):
return RunOperationTask._runtime_initialize(self)

def handle_request(self) -> RemoteRunOperationResult:
success, _ = RunOperationTask.run(self)
result = RemoteRunOperationResult(logs=[], success=success)
return result

def interpret_results(self, results):
return results.success


class RemoteSnapshotTask(RPCCommandTask[RPCSnapshotParameters], SnapshotTask):
METHOD_NAME = 'snapshot'

def set_args(self, params: RPCSnapshotParameters) -> None:
# select has an argparse `dest` value of `models`.
self.args.models = self._listify(params.select)
self.args.exclude = self._listify(params.exclude)
4 changes: 4 additions & 0 deletions core/dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,10 @@ def default(self, obj):
return float(obj)
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
return obj.isoformat()
if hasattr(obj, 'to_dict'):
# if we have a to_dict we should try to serialize the result of
# that!
obj = obj.to_dict()
return super().default(obj)


Expand Down
3 changes: 3 additions & 0 deletions plugins/snowflake/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
install_requires=[
'dbt-core=={}'.format(package_version),
'snowflake-connector-python>=1.6.12,<2.1',
'azure-storage-blob~=2.1',
'azure-storage-common~=2.1',

],
zip_safe=False,
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle
import os


class TestRpcExecuteReturnsResults(DBTIntegrationTest):

@property
Expand All @@ -26,7 +27,7 @@ def test_pickle(self, agate_table):

pickle.dumps(table)

def test_file(self, filename):
def do_test_file(self, filename):
file_path = os.path.join("sql", filename)
with open(file_path) as fh:
query = fh.read()
Expand All @@ -39,12 +40,12 @@ def test_file(self, filename):

@use_profile('bigquery')
def test__bigquery_fetch_and_serialize(self):
self.test_file('bigquery.sql')
self.do_test_file('bigquery.sql')

@use_profile('snowflake')
def test__snowflake_fetch_and_serialize(self):
self.test_file('snowflake.sql')
self.do_test_file('snowflake.sql')

@use_profile('redshift')
def test__redshift_fetch_and_serialize(self):
self.test_file('redshift.sql')
self.do_test_file('redshift.sql')
54 changes: 54 additions & 0 deletions test/rpc/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import os
import pytest
import random
import time
from typing import Dict, Any

import yaml


@pytest.fixture
def unique_schema() -> str:
return "test{}{:04}".format(int(time.time()), random.randint(0, 9999))


@pytest.fixture
def profiles_root(tmpdir):
return tmpdir.mkdir('profile')


@pytest.fixture
def project_root(tmpdir):
return tmpdir.mkdir('project')


@pytest.fixture
def postgres_profile_data(unique_schema):
return {
'config': {
'send_anonymous_usage_stats': False
},
'test': {
'outputs': {
'default': {
'type': 'postgres',
'threads': 4,
'host': 'database',
'port': 5432,
'user': 'root',
'pass': 'password',
'dbname': 'dbt',
'schema': unique_schema,
},
},
'target': 'default'
}
}


@pytest.fixture
def postgres_profile(profiles_root, postgres_profile_data) -> Dict[str, Any]:
path = os.path.join(profiles_root, 'profiles.yml')
with open(path, 'w') as fp:
fp.write(yaml.safe_dump(postgres_profile_data))
return postgres_profile_data
Loading