Skip to content

Commit

Permalink
Add task tags parameter
Browse files Browse the repository at this point in the history
Create base class for parameter types
 - include arbitrary "task_tags" parameter
 - include "timeout" so we can extract/document full/valid json schemas
   - this includes some extra-fiddly behavior around numbers
Add output to ps and poll results (including errors)
Fix a number of type annotation issues along the way
  • Loading branch information
Jacob Beck committed Oct 11, 2019
1 parent 2815b33 commit fb9747b
Show file tree
Hide file tree
Showing 12 changed files with 429 additions and 221 deletions.
35 changes: 0 additions & 35 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dbt.contracts.graph.parsed import ParsedSourceDefinition
from dbt.contracts.util import Writable, Replaceable
from dbt.logger import (
LogMessage,
TimingProcessor,
JsonOnly,
GLOBAL_LOGGER as logger,
Expand Down Expand Up @@ -207,35 +206,6 @@ class FreshnessRunOutput(JsonSchemaMixin, Writable):
sources: Dict[str, SourceFreshnessRunResult]


@dataclass
class RemoteCompileResult(JsonSchemaMixin):
raw_sql: str
compiled_sql: str
node: CompileResultNode
timing: List[TimingInfo]
logs: List[LogMessage]

@property
def error(self):
return None


@dataclass
class RemoteExecutionResult(ExecutionResult):
logs: List[LogMessage]


@dataclass
class ResultTable(JsonSchemaMixin):
column_names: List[str]
rows: List[Any]


@dataclass
class RemoteRunResult(RemoteCompileResult):
table: ResultTable


Primitive = Union[bool, str, float, None]

CatalogKey = NamedTuple(
Expand Down Expand Up @@ -298,8 +268,3 @@ class CatalogResults(JsonSchemaMixin, Writable):
nodes: Dict[str, CatalogTable]
generated_at: datetime
_compile_results: Optional[Any] = None


@dataclass
class RemoteCatalogResults(CatalogResults):
logs: List[LogMessage] = field(default_factory=list)
92 changes: 92 additions & 0 deletions core/dbt/contracts/rpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from dataclasses import dataclass, field
from numbers import Real
from typing import Optional, Union, List, Any, Dict

from hologram import JsonSchemaMixin

from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.results import (
TimingInfo,
CatalogResults,
ExecutionResult,
)
from dbt.logger import LogMessage

# Inputs


@dataclass
class RPCParameters(JsonSchemaMixin):
timeout: Optional[Real]
task_tags: Optional[Dict[str, Any]]


@dataclass
class RPCExecParameters(RPCParameters):
name: str
sql: str
macros: Optional[str]


@dataclass
class RPCCompileParameters(RPCParameters):
models: Union[None, str, List[str]] = None
exclude: Union[None, str, List[str]] = None


@dataclass
class RPCTestParameters(RPCCompileParameters):
data: bool = False
schema: bool = False


@dataclass
class RPCSeedParameters(RPCParameters):
show: bool = False


@dataclass
class RPCDocsGenerateParameters(RPCParameters):
compile: bool = True


@dataclass
class RPCCliParameters(RPCParameters):
cli: str


# Outputs


@dataclass
class RemoteCatalogResults(CatalogResults):
logs: List[LogMessage] = field(default_factory=list)


@dataclass
class RemoteCompileResult(JsonSchemaMixin):
raw_sql: str
compiled_sql: str
node: CompileResultNode
timing: List[TimingInfo]
logs: List[LogMessage]

@property
def error(self):
return None


@dataclass
class RemoteExecutionResult(ExecutionResult):
logs: List[LogMessage]


@dataclass
class ResultTable(JsonSchemaMixin):
column_names: List[str]
rows: List[Any]


@dataclass
class RemoteRunResult(RemoteCompileResult):
table: ResultTable
8 changes: 8 additions & 0 deletions core/dbt/helper_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# never name this package "types", or mypy will crash in ugly ways
from datetime import timedelta
from numbers import Real
from typing import NewType

from hologram import (
Expand Down Expand Up @@ -37,7 +38,14 @@ def json_schema(self) -> JsonDict:
return {'type': 'number'}


class RealEncoder(FieldEncoder):
@property
def json_schema(self):
return {'type': 'number'}


JsonSchemaMixin.register_field_encoders({
Port: PortEncoder(),
timedelta: TimeDeltaFieldEncoder(),
Real: RealEncoder(),
})
32 changes: 25 additions & 7 deletions core/dbt/rpc/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(
message: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
logs: Optional[List[Dict[str, Any]]] = None,
tags: Optional[Dict[str, Any]] = None
) -> None:
if code is None:
code = -32000
Expand All @@ -23,6 +24,7 @@ def __init__(
super().__init__(code=code, message=message, data=data)
if logs is not None:
self.logs = logs
self.error.data['tags'] = tags

def __str__(self):
return (
Expand All @@ -40,9 +42,25 @@ def logs(self, value):
return
self.error.data['logs'] = value

@property
def tags(self):
return self.error.data.get('tags')

@tags.setter
def tags(self, value):
if value is None:
return
self.error.data['tags'] = value

@classmethod
def from_error(cls, err):
return cls(err.code, err.message, err.data, err.data.get('logs'))
return cls(
code=err.code,
message=err.message,
data=err.data,
logs=err.data.get('logs'),
tags=err.data.get('tags'),
)


def invalid_params(data):
Expand All @@ -53,17 +71,17 @@ def invalid_params(data):
)


def server_error(err, logs=None):
def server_error(err, logs=None, tags=None):
exc = dbt.exceptions.Exception(str(err))
return dbt_error(exc, logs)
return dbt_error(exc, logs, tags)


def timeout_error(timeout_value, logs=None):
def timeout_error(timeout_value, logs=None, tags=None):
exc = dbt.exceptions.RPCTimeoutException(timeout_value)
return dbt_error(exc, logs)
return dbt_error(exc, logs, tags)


def dbt_error(exc, logs=None):
def dbt_error(exc, logs=None, tags=None):
exc = RPCException(code=exc.CODE, message=exc.MESSAGE, data=exc.data(),
logs=logs)
logs=logs, tags=tags)
return exc
2 changes: 1 addition & 1 deletion core/dbt/rpc/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from queue import Empty
from typing import Optional, Any, Union

from dbt.contracts.results import (
from dbt.contracts.rpc import (
RemoteCompileResult, RemoteExecutionResult, RemoteCatalogResults
)
from dbt.exceptions import InternalException
Expand Down
75 changes: 49 additions & 26 deletions core/dbt/rpc/node_runners.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
from abc import abstractmethod
from typing import Generic, TypeVar

import dbt.exceptions
from dbt.compilation import compile_node
from dbt.contracts.results import (
from dbt.contracts.rpc import (
RemoteCompileResult, RemoteRunResult, ResultTable,
)
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.node_runners import CompileRunner
from dbt.rpc.error import dbt_error, RPCException, server_error


class RPCCompileRunner(CompileRunner):
RPCSQLResult = TypeVar('RPCSQLResult', bound=RemoteCompileResult)


class GenericRPCRunner(CompileRunner, Generic[RPCSQLResult]):
def __init__(self, config, adapter, node, node_index, num_nodes):
super().__init__(config, adapter, node, node_index, num_nodes)
CompileRunner.__init__(
self, config, adapter, node, node_index, num_nodes
)

def handle_exception(self, e, ctx):
logger.debug('Got an exception: {}'.format(e), exc_info=True)
Expand All @@ -33,14 +41,13 @@ def compile(self, manifest):
return compile_node(self.adapter, self.config, self.node, manifest, {},
write=False)

def execute(self, compiled_node, manifest):
return RemoteCompileResult(
raw_sql=compiled_node.raw_sql,
compiled_sql=compiled_node.injected_sql,
node=compiled_node,
timing=[], # this will get added later
logs=[],
)
@abstractmethod
def execute(self, compiled_node, manifest) -> RPCSQLResult:
pass

@abstractmethod
def from_run_result(self, result, start_time, timing_info) -> RPCSQLResult:
pass

def error_result(self, node, error, start_time, timing_info):
raise error
Expand All @@ -50,34 +57,38 @@ def ephemeral_result(self, node, start_time, timing_info):
'cannot execute ephemeral nodes remotely!'
)

def from_run_result(self, result, start_time, timing_info):

class RPCCompileRunner(GenericRPCRunner[RemoteCompileResult]):
def execute(self, compiled_node, manifest) -> RemoteCompileResult:
return RemoteCompileResult(
raw_sql=result.raw_sql,
compiled_sql=result.compiled_sql,
node=result.node,
timing=timing_info,
raw_sql=compiled_node.raw_sql,
compiled_sql=compiled_node.injected_sql,
node=compiled_node,
timing=[], # this will get added later
logs=[],
)


class RPCExecuteRunner(RPCCompileRunner):
def from_run_result(self, result, start_time, timing_info):
return RemoteRunResult(
def from_run_result(
self, result, start_time, timing_info
) -> RemoteCompileResult:
return RemoteCompileResult(
raw_sql=result.raw_sql,
compiled_sql=result.compiled_sql,
node=result.node,
table=result.table,
timing=timing_info,
logs=[],
)

def execute(self, compiled_node, manifest):
status, table = self.adapter.execute(compiled_node.injected_sql,
fetch=True)

class RPCExecuteRunner(GenericRPCRunner[RemoteRunResult]):
def execute(self, compiled_node, manifest) -> RemoteRunResult:
_, execute_result = self.adapter.execute(
compiled_node.injected_sql, fetch=True
)

table = ResultTable(
column_names=list(table.column_names),
rows=[list(row) for row in table],
column_names=list(execute_result.column_names),
rows=[list(row) for row in execute_result],
)

return RemoteRunResult(
Expand All @@ -88,3 +99,15 @@ def execute(self, compiled_node, manifest):
timing=[],
logs=[],
)

def from_run_result(
self, result, start_time, timing_info
) -> RemoteRunResult:
return RemoteRunResult(
raw_sql=result.raw_sql,
compiled_sql=result.compiled_sql,
node=result.node,
table=result.table,
timing=timing_info,
logs=[],
)
Loading

0 comments on commit fb9747b

Please sign in to comment.