diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py index 8241695de2f..1270acadfb4 100644 --- a/core/dbt/contracts/rpc.py +++ b/core/dbt/contracts/rpc.py @@ -15,6 +15,8 @@ CatalogArtifact, CatalogResults, ExecutionResult, + FreshnessExecutionResultArtifact, + FreshnessResult, RunOperationResult, RunOperationResultsArtifact, RunResult, @@ -282,6 +284,28 @@ def write(self, path: str): writable.write(path) +@dataclass +@schema_version('remote-freshness-result', 1) +class RemoteFreshnessResult(FreshnessResult, RemoteResult): + + @classmethod + def from_local_result( + cls, + base: FreshnessResult, + logs: List[LogMessage], + ) -> 'RemoteFreshnessResult': + return cls( + metadata=base.metadata, + results=base.results, + elapsed_time=base.elapsed_time, + logs=logs, + ) + + def write(self, path: str): + writable = FreshnessExecutionResultArtifact.from_result(base=self) + writable.write(path) + + @dataclass @schema_version('remote-run-result', 1) class RemoteRunResult(RemoteCompileResultMixin): @@ -292,6 +316,7 @@ class RemoteRunResult(RemoteCompileResultMixin): RPCResult = Union[ RemoteCompileResult, RemoteExecutionResult, + RemoteFreshnessResult, RemoteCatalogResults, RemoteDepsResult, RemoteRunOperationResult, @@ -300,7 +325,6 @@ class RemoteRunResult(RemoteCompileResultMixin): # GC types - class GCResultState(StrEnum): Deleted = 'deleted' # successful GC Missing = 'missing' # nothing to GC @@ -682,6 +706,35 @@ def from_result( elapsed=timing.elapsed, ) + +@dataclass +@schema_version('poll-remote-freshness-result', 1) +class PollFreshnessResult(RemoteFreshnessResult, PollResult): + state: TaskHandlerState = field( + metadata=restrict_to(TaskHandlerState.Success, + TaskHandlerState.Failed), + ) + + @classmethod + def from_result( + cls: Type['PollFreshnessResult'], + base: RemoteFreshnessResult, + tags: TaskTags, + timing: TaskTiming, + logs: List[LogMessage], + ) -> 'PollFreshnessResult': + return cls( + logs=logs, + tags=tags, + state=timing.state, + start=timing.start, + end=timing.end, + elapsed=timing.elapsed, + metadata=base.metadata, + results=base.results, + elapsed_time=base.elapsed_time, + ) + # Manifest parsing types diff --git a/core/dbt/rpc/builtins.py b/core/dbt/rpc/builtins.py index 66cb171ff41..7e900ca5357 100644 --- a/core/dbt/rpc/builtins.py +++ b/core/dbt/rpc/builtins.py @@ -18,6 +18,7 @@ TaskRow, PSResult, RemoteExecutionResult, + RemoteFreshnessResult, RemoteRunResult, RemoteCompileResult, RemoteCatalogResults, @@ -32,6 +33,7 @@ PollRunCompleteResult, PollCompileCompleteResult, PollCatalogCompleteResult, + PollFreshnessResult, PollRemoteEmptyCompleteResult, PollRunOperationCompleteResult, TaskHandlerState, @@ -146,7 +148,8 @@ def poll_complete( PollCatalogCompleteResult, PollRemoteEmptyCompleteResult, PollRunOperationCompleteResult, - PollGetManifestResult + PollGetManifestResult, + PollFreshnessResult, ]] if isinstance(result, RemoteExecutionResult): @@ -164,6 +167,8 @@ def poll_complete( cls = PollRunOperationCompleteResult elif isinstance(result, GetManifestResult): cls = PollGetManifestResult + elif isinstance(result, RemoteFreshnessResult): + cls = PollFreshnessResult else: raise dbt.exceptions.InternalException( 'got invalid result in poll_complete: {}'.format(result) diff --git a/core/dbt/task/rpc/base.py b/core/dbt/task/rpc/base.py index 306e6fa6763..d31b4310905 100644 --- a/core/dbt/task/rpc/base.py +++ b/core/dbt/task/rpc/base.py @@ -1,9 +1,24 @@ -from dbt.contracts.results import RunResultsArtifact -from dbt.contracts.rpc import RemoteExecutionResult +from dbt.contracts.results import ( + RunResult, + RunOperationResult, + FreshnessResult, +) +from dbt.contracts.rpc import ( + RemoteExecutionResult, + RemoteFreshnessResult, + RemoteRunOperationResult, +) from dbt.task.runnable import GraphRunnableTask from dbt.rpc.method import RemoteManifestMethod, Parameters +RESULT_TYPE_MAP = { + RunResult: RemoteExecutionResult, + RunOperationResult: RemoteRunOperationResult, + FreshnessResult: RemoteFreshnessResult, +} + + class RPCTask( GraphRunnableTask, RemoteManifestMethod[Parameters, RemoteExecutionResult] @@ -21,10 +36,7 @@ def load_manifest(self): def get_result( self, results, elapsed_time, generated_at ) -> RemoteExecutionResult: - base = RunResultsArtifact.from_node_results( - results=results, - elapsed_time=elapsed_time, - generated_at=generated_at, - ) - rpc_result = RemoteExecutionResult.from_local_result(base, logs=[]) + base = super().get_result(results, elapsed_time, generated_at) + cls = RESULT_TYPE_MAP.get(type(base), RemoteExecutionResult) + rpc_result = cls.from_local_result(base, logs=[]) return rpc_result