diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f4718345d0..81ff2181933 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ ## dbt next (release TBD) +### Features +- Added a `get-manifest` API call. ([#2168](https://github.com/fishtown-analytics/dbt/issues/2168), [#2232](https://github.com/fishtown-analytics/dbt/pull/2232)) + ### Fixes - When a jinja value is undefined, give a helpful error instead of failing with cryptic "cannot pickle ParserMacroCapture" errors ([#2110](https://github.com/fishtown-analytics/dbt/issues/2110), [#2184](https://github.com/fishtown-analytics/dbt/pull/2184)) diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py index 06e41e84c64..473f9619c97 100644 --- a/core/dbt/contracts/rpc.py +++ b/core/dbt/contracts/rpc.py @@ -9,6 +9,7 @@ from hologram.helpers import StrEnum from dbt.contracts.graph.compiled import CompileResultNode +from dbt.contracts.graph.manifest import WritableManifest from dbt.contracts.results import ( TimingInfo, CatalogResults, @@ -143,8 +144,13 @@ class RPCSourceFreshnessParameters(RPCParameters): select: Union[None, str, List[str]] = None +@dataclass +class GetManifestParameters(RPCParameters): + pass + # Outputs + @dataclass class RemoteResult(JsonSchemaMixin): logs: List[LogMessage] @@ -322,6 +328,11 @@ class KillResult(RemoteResult): logs: List[LogMessage] = field(default_factory=list) +@dataclass +class GetManifestResult(RemoteResult): + manifest: Optional[WritableManifest] + + # this is kind of carefuly structured: BlocksManifestTasks is implied by # RequiresConfigReloadBefore and RequiresManifestReloadAfter class RemoteMethodFlags(enum.Flag): @@ -526,8 +537,34 @@ class PollInProgressResult(PollResult): pass +@dataclass +class PollGetManifestResult(GetManifestResult, PollResult): + state: TaskHandlerState = field( + metadata=restrict_to(TaskHandlerState.Success, + TaskHandlerState.Failed), + ) + + @classmethod + def from_result( + cls: Type['PollGetManifestResult'], + base: GetManifestResult, 
+ tags: TaskTags, + timing: TaskTiming, + logs: List[LogMessage], + ) -> 'PollGetManifestResult': + return cls( + manifest=base.manifest, + logs=logs, + tags=tags, + state=timing.state, + start=timing.start, + end=timing.end, + elapsed=timing.elapsed, + ) + # Manifest parsing types + class ManifestStatus(StrEnum): Init = 'init' Compiling = 'compiling' diff --git a/core/dbt/rpc/builtins.py b/core/dbt/rpc/builtins.py index 25b696fc0e0..c5cca284bbd 100644 --- a/core/dbt/rpc/builtins.py +++ b/core/dbt/rpc/builtins.py @@ -10,6 +10,7 @@ LastParse, GCParameters, GCResult, + GetManifestResult, KillParameters, KillResult, KillResultStatus, @@ -27,6 +28,7 @@ PollInProgressResult, PollKilledResult, PollExecuteCompleteResult, + PollGetManifestResult, PollRunCompleteResult, PollCompileCompleteResult, PollCatalogCompleteResult, @@ -144,6 +146,7 @@ def poll_complete( PollCatalogCompleteResult, PollRemoteEmptyCompleteResult, PollRunOperationCompleteResult, + PollGetManifestResult ]] if isinstance(result, RemoteExecutionResult): @@ -159,6 +162,8 @@ def poll_complete( cls = PollRemoteEmptyCompleteResult elif isinstance(result, RemoteRunOperationResult): cls = PollRunOperationCompleteResult + elif isinstance(result, GetManifestResult): + cls = PollGetManifestResult else: raise dbt.exceptions.InternalException( 'got invalid result in poll_complete: {}'.format(result) diff --git a/core/dbt/task/rpc/project_commands.py b/core/dbt/task/rpc/project_commands.py index d88f46b5371..465c228b4e2 100644 --- a/core/dbt/task/rpc/project_commands.py +++ b/core/dbt/task/rpc/project_commands.py @@ -1,8 +1,10 @@ from datetime import datetime from typing import List, Optional, Union - +from dbt.contracts.graph.manifest import WritableManifest from dbt.contracts.rpc import ( + GetManifestParameters, + GetManifestResult, RPCCompileParameters, RPCDocsGenerateParameters, RPCRunOperationParameters, @@ -188,3 +190,30 @@ def set_args(self, params: RPCSourceFreshnessParameters) -> None: if params.threads is 
not None: self.args.threads = params.threads self.args.output = None + + +# Runs the compile task but returns the manifest rather than the task results. +class GetManifest( + RemoteManifestMethod[GetManifestParameters, GetManifestResult] +): + METHOD_NAME = 'get-manifest' + + def set_args(self, params: GetManifestParameters) -> None: + self.args.models = None + self.args.exclude = None + + def handle_request(self) -> GetManifestResult: + task = RemoteCompileProjectTask(self.args, self.config, self.manifest) + task.handle_request() + + manifest: Optional[WritableManifest] = None + if task.manifest is not None: + manifest = task.manifest.writable_manifest() + + return GetManifestResult( + logs=[], + manifest=manifest, + ) + + def interpret_results(self, results): + return results.manifest is not None diff --git a/test/rpc/test_base.py b/test/rpc/test_base.py index 14694f155a3..e94f86d49a6 100644 --- a/test/rpc/test_base.py +++ b/test/rpc/test_base.py @@ -885,3 +885,32 @@ def test_rpc_vars( results = querier.async_wait_for_result(querier.cli_args('run --vars "{param: 100}"')) assert len(results['results']) == 1 assert results['results'][0]['node']['compiled_sql'] == 'select 100 as id' + + +def test_get_manifest( + project_root, profiles_root, postgres_profile, unique_schema +): + project = ProjectDefinition( + models={ + 'my_model.sql': 'select 1 as id', + }, + ) + querier_ctx = get_querier( + project_def=project, + project_dir=project_root, + profiles_dir=profiles_root, + schema=unique_schema, + test_kwargs={}, + ) + + with querier_ctx as querier: + results = querier.async_wait_for_result(querier.cli_args('run')) + assert len(results['results']) == 1 + assert results['results'][0]['node']['compiled_sql'] == 'select 1 as id' + result = querier.async_wait_for_result(querier.get_manifest()) + assert 'manifest' in result + manifest = result['manifest'] + assert manifest['nodes']['model.test.my_model']['raw_sql'] == 'select 1 as id' + assert 'manifest' in result + manifest = result['manifest'] + assert 
manifest['nodes']['model.test.my_model']['compiled_sql'] == 'select 1 as id' diff --git a/test/rpc/util.py b/test/rpc/util.py index 43e7a3ff87d..c684238f95c 100644 --- a/test/rpc/util.py +++ b/test/rpc/util.py @@ -369,6 +369,11 @@ def run_sql( method='run_sql', params=params, request_id=request_id ) + def get_manifest(self, request_id=1): + return self.request( + method='get-manifest', params={}, request_id=request_id + ) + def is_result(self, data: Dict[str, Any], id=None) -> Dict[str, Any]: if id is not None: assert data['id'] == id