diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py index 9e2a3edf0e1..eebec17624d 100644 --- a/core/dbt/contracts/results.py +++ b/core/dbt/contracts/results.py @@ -1,7 +1,7 @@ from dbt.contracts.graph.manifest import CompileResultNode from dbt.contracts.graph.unparsed import Time, FreshnessStatus from dbt.contracts.graph.parsed import ParsedSourceDefinition -from dbt.contracts.util import Writable +from dbt.contracts.util import Writable, Replaceable from dbt.logger import LogMessage from hologram.helpers import StrEnum from hologram import JsonSchemaMixin @@ -10,7 +10,7 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import Union, Dict, List, Optional, Any +from typing import Union, Dict, List, Optional, Any, NamedTuple from numbers import Real @@ -227,3 +227,72 @@ class ResultTable(JsonSchemaMixin): @dataclass class RemoteRunResult(RemoteCompileResult): table: ResultTable + + +Primitive = Union[bool, str, float, None] + +CatalogKey = NamedTuple( + 'CatalogKey', + [('database', str), ('schema', str), ('name', str)] +) + + +@dataclass +class StatsItem(JsonSchemaMixin): + id: str + label: str + value: Primitive + description: str + include: bool + + +StatsDict = Dict[str, StatsItem] + + +@dataclass +class ColumnMetadata(JsonSchemaMixin): + type: str + comment: Optional[str] + index: int + name: str + + +ColumnMap = Dict[str, ColumnMetadata] + + +@dataclass +class TableMetadata(JsonSchemaMixin): + type: str + database: str + schema: str + name: str + comment: Optional[str] + owner: Optional[str] + + +@dataclass +class CatalogTable(JsonSchemaMixin, Replaceable): + metadata: TableMetadata + columns: ColumnMap + stats: StatsDict + # the same table with two unique IDs will just be listed two times + unique_id: Optional[str] = None + + def key(self) -> CatalogKey: + return CatalogKey( + self.metadata.database.lower(), + self.metadata.schema.lower(), + self.metadata.name.lower(), + ) + + +@dataclass +class CatalogResults(JsonSchemaMixin, Writable): + nodes: Dict[str, CatalogTable] + generated_at: datetime + _compile_results: Optional[Any] = None + + +@dataclass +class RemoteCatalogResults(CatalogResults): + logs: List[LogMessage] = field(default_factory=list) diff --git a/core/dbt/rpc/logger.py b/core/dbt/rpc/logger.py index 99885011299..1f7b733af6e 100644 --- a/core/dbt/rpc/logger.py +++ b/core/dbt/rpc/logger.py @@ -9,12 +9,18 @@ from queue import Empty from typing import Optional, Any, Union -from dbt.contracts.results import RemoteCompileResult, RemoteExecutionResult +from dbt.contracts.results import ( + RemoteCompileResult, RemoteExecutionResult, RemoteCatalogResults +) from dbt.exceptions import InternalException from dbt.utils import restrict_to -RemoteCallableResult = Union[RemoteCompileResult, RemoteExecutionResult] +RemoteCallableResult = Union[ + RemoteCompileResult, + RemoteExecutionResult, + RemoteCatalogResults, +] class QueueMessageType(StrEnum): diff --git a/core/dbt/rpc/response_manager.py b/core/dbt/rpc/response_manager.py index 31319418695..2d4a4252b57 100644 --- a/core/dbt/rpc/response_manager.py +++ b/core/dbt/rpc/response_manager.py @@ -88,7 +88,7 @@ def _get_responses(cls, requests, dispatcher): # to_dict if hasattr(output, 'result'): if isinstance(output.result, JsonSchemaMixin): - output.result = output.result.to_dict(omit_empty=False) + output.result = output.result.to_dict(omit_none=False) yield output @classmethod diff --git a/core/dbt/rpc/task_manager.py b/core/dbt/rpc/task_manager.py index c1e4dfb0582..79bbe3418d2 100644 --- a/core/dbt/rpc/task_manager.py +++ b/core/dbt/rpc/task_manager.py @@ -18,6 +18,7 @@ RemoteCompileResult, RemoteRunResult, RemoteExecutionResult, + RemoteCatalogResults, ) from dbt.logger import LogMessage from dbt.rpc.error import dbt_error, RPCException @@ -210,6 +211,24 @@ def from_result(cls, status, base): ) +@dataclass +class PollCatalogSuccessResult(PollResult, RemoteCatalogResults): + status: TaskHandlerState = field( + metadata=restrict_to(TaskHandlerState.Success), + default=TaskHandlerState.Success + ) + + @classmethod + def from_result(cls, status, base): + return cls( + status=status, + nodes=base.nodes, + generated_at=base.generated_at, + _compile_results=base._compile_results, + logs=base.logs, + ) + + def poll_success(status, logs, result): if status != TaskHandlerState.Success: raise dbt.exceptions.InternalException( @@ -223,6 +242,8 @@ def poll_success(status, logs, result): return PollRunSuccessResult.from_result(status=status, base=result) elif isinstance(result, RemoteCompileResult): return PollCompileSuccessResult.from_result(status=status, base=result) + elif isinstance(result, RemoteCatalogResults): + return PollCatalogSuccessResult.from_result(status=status, base=result) else: raise dbt.exceptions.InternalException( 'got invalid result in poll_success: {}'.format(result) diff --git a/core/dbt/task/generate.py b/core/dbt/task/generate.py index e98830d0518..d4c6d5eb48a 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -1,15 +1,17 @@ import os import shutil -from dataclasses import dataclass from datetime import datetime -from typing import Union, Dict, List, Optional, Any, NamedTuple +from typing import Dict, List, Any -from hologram import JsonSchemaMixin, ValidationError +from hologram import ValidationError from dbt.adapters.factory import get_adapter from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.util import Writable, Replaceable +from dbt.contracts.results import ( + TableMetadata, CatalogTable, CatalogResults, Primitive, CatalogKey, + StatsItem, StatsDict, ColumnMetadata +) from dbt.include.global_project import DOCS_INDEX_FILE_PATH import dbt.ui.printer import dbt.utils @@ -33,87 +35,31 @@ def get_stripped_prefix(source: Dict[str, Any], prefix: str) -> Dict[str, Any]: } -Primitive = Union[bool, str, float, None] PrimitiveDict = Dict[str, Primitive] -Key = NamedTuple( - 'Key', - [('database', str), ('schema', str), ('name', str)] -) - - -@dataclass -class StatsItem(JsonSchemaMixin): - id: str - label: str - value: Primitive - description: str - include: bool - - -StatsDict = Dict[str, StatsItem] - - -@dataclass -class ColumnMetadata(JsonSchemaMixin): - type: str - comment: Optional[str] - index: int - name: str - - -ColumnMap = Dict[str, ColumnMetadata] - - -@dataclass -class TableMetadata(JsonSchemaMixin): - type: str - database: str - schema: str - name: str - comment: Optional[str] - owner: Optional[str] - +def build_catalog_table(data) -> CatalogTable: + # build the new table's metadata + stats + metadata = TableMetadata.from_dict(get_stripped_prefix(data, 'table_')) + stats = format_stats(get_stripped_prefix(data, 'stats:')) -@dataclass -class Table(JsonSchemaMixin, Replaceable): - metadata: TableMetadata - columns: ColumnMap - stats: StatsDict - # the same table with two unique IDs will just be listed two times - unique_id: Optional[str] = None - - @classmethod - def from_query_result(cls, data) -> 'Table': - # build the new table's metadata + stats - metadata = TableMetadata.from_dict(get_stripped_prefix(data, 'table_')) - stats = format_stats(get_stripped_prefix(data, 'stats:')) - - return cls( - metadata=metadata, - stats=stats, - columns={}, - ) - - def key(self) -> Key: - return Key( - self.metadata.database.lower(), - self.metadata.schema.lower(), - self.metadata.name.lower(), - ) + return CatalogTable( + metadata=metadata, + stats=stats, + columns={}, + ) # keys are database name, schema name, table name -class Catalog(Dict[Key, Table]): +class Catalog(Dict[CatalogKey, CatalogTable]): def __init__(self, columns: List[PrimitiveDict]): super().__init__() for col in columns: self.add_column(col) - def get_table(self, data: PrimitiveDict) -> Table: + def get_table(self, data: PrimitiveDict) -> CatalogTable: try: - key = Key( + key = CatalogKey( str(data['table_database']), str(data['table_schema']), str(data['table_name']), @@ -123,10 +69,11 @@ def get_table(self, data: PrimitiveDict) -> Table: 'Catalog information missing required key {} (got {})' .format(exc, data) ) + table: CatalogTable if key in self: table = self[key] else: - table = Table.from_query_result(data) + table = build_catalog_table(data) self[key] = table return table @@ -140,8 +87,10 @@ def add_column(self, data: PrimitiveDict): column = ColumnMetadata.from_dict(column_data) table.columns[column.name] = column - def make_unique_id_map(self, manifest: Manifest) -> Dict[str, Table]: - nodes: Dict[str, Table] = {} + def make_unique_id_map( + self, manifest: Manifest + ) -> Dict[str, CatalogTable]: + nodes: Dict[str, CatalogTable] = {} manifest_mapping = get_unique_id_mapping(manifest) for table in self.values(): @@ -201,16 +150,16 @@ def format_stats(stats: PrimitiveDict) -> StatsDict: return stats_collector -def mapping_key(node: CompileResultNode) -> Key: - return Key( +def mapping_key(node: CompileResultNode) -> CatalogKey: + return CatalogKey( node.database.lower(), node.schema.lower(), node.identifier.lower() ) -def get_unique_id_mapping(manifest: Manifest) -> Dict[Key, List[str]]: +def get_unique_id_mapping(manifest: Manifest) -> Dict[CatalogKey, List[str]]: # A single relation could have multiple unique IDs pointing to it if a # source were also a node. - ident_map: Dict[Key, List[str]] = {} + ident_map: Dict[CatalogKey, List[str]] = {} for unique_id, node in manifest.nodes.items(): key = mapping_key(node) @@ -221,13 +170,6 @@ def get_unique_id_mapping(manifest: Manifest) -> Dict[Key, List[str]]: return ident_map -@dataclass -class CatalogResults(JsonSchemaMixin, Writable): - nodes: Dict[str, Table] - generated_at: datetime - _compile_results: Optional[Any] = None - - def _coerce_decimal(value): if isinstance(value, dbt.utils.DECIMALS): return float(value) @@ -242,7 +184,7 @@ def _get_manifest(self) -> Manifest: def run(self): compile_results = None if self.args.compile: - compile_results = super().run() + compile_results = CompileTask.run(self) if any(r.error is not None for r in compile_results): dbt.ui.printer.print_timestamped_line( 'compile failed, cannot generate docs' @@ -266,10 +208,10 @@ def run(self): ] catalog = Catalog(catalog_data) - results = CatalogResults( + results = self.get_catalog_results( nodes=catalog.make_unique_id_map(manifest), generated_at=datetime.utcnow(), - _compile_results=compile_results, + compile_results=compile_results, ) path = os.path.join(self.config.target_path, CATALOG_FILENAME) @@ -280,6 +222,15 @@ def run(self): ) return results + def get_catalog_results( + self, nodes, generated_at, compile_results + ) -> CatalogResults: + return CatalogResults( + nodes=nodes, + generated_at=datetime.utcnow(), + _compile_results=compile_results, + ) + def interpret_results(self, results): compile_results = results._compile_results if compile_results is None: diff --git a/core/dbt/task/remote.py b/core/dbt/task/remote.py index 7abc149e3d5..db16026ca37 100644 --- a/core/dbt/task/remote.py +++ b/core/dbt/task/remote.py @@ -9,6 +9,7 @@ from dbt.adapters.factory import get_adapter from dbt.clients.jinja import extract_toplevel_blocks from dbt.compilation import compile_manifest +from dbt.contracts.results import RemoteCatalogResults from dbt.parser.results import ParseResult from dbt.parser.rpc import RPCCallParser, RPCMacroParser from dbt.parser.util import ParserUtils @@ -17,6 +18,7 @@ from dbt.rpc.node_runners import RPCCompileRunner, RPCExecuteRunner from dbt.rpc.task import RemoteCallableResult, RPCTask +from dbt.task.generate import GenerateTask from dbt.task.run import RunTask from dbt.task.seed import SeedTask from dbt.task.test import TestTask @@ -46,6 +48,11 @@ class RPCSeedProjectParameters(JsonSchemaMixin): show: bool = False +@dataclass +class RPCDocsGenerateProjectParameters(JsonSchemaMixin): + compile: bool = True + + class _RPCExecTask(RPCTask): def __init__(self, args, config, manifest): super().__init__(args, config) @@ -259,3 +266,36 @@ def handle_request( results = self.run() return results + + +class RemoteDocsGenerateProjectTask(RPCTask, GenerateTask): + METHOD_NAME = 'docs.generate' + + def __init__(self, args, config, manifest): + super().__init__(args, config) + self.manifest = manifest.deepcopy(config=config) + + def load_manifest(self): + # we started out with a manifest! + pass + + def handle_request( + self, params: RPCDocsGenerateProjectParameters, + ) -> RemoteCallableResult: + self.args.models = None + self.args.exclude = None + self.args.compile = params.compile + + results = self.run() + assert isinstance(results, RemoteCatalogResults) + return results + + def get_catalog_results( + self, nodes, generated_at, compile_results + ) -> RemoteCatalogResults: + return RemoteCatalogResults( + nodes=nodes, + generated_at=datetime.utcnow(), + _compile_results=compile_results, + logs=[], + ) diff --git a/core/dbt/task/rpc_server.py b/core/dbt/task/rpc_server.py index dc296af768b..9e9f3d331fe 100644 --- a/core/dbt/task/rpc_server.py +++ b/core/dbt/task/rpc_server.py @@ -22,6 +22,7 @@ RemoteRunTask, RemoteRunProjectTask, RemoteSeedProjectTask, RemoteTestProjectTask, + RemoteDocsGenerateProjectTask, ) from dbt.utils import ForgivingJSONEncoder, env_set_truthy from dbt import rpc @@ -149,7 +150,8 @@ def _default_tasks(): return [ RemoteCompileTask, RemoteCompileProjectTask, RemoteRunTask, RemoteRunProjectTask, - RemoteSeedProjectTask, RemoteTestProjectTask + RemoteSeedProjectTask, RemoteTestProjectTask, + RemoteDocsGenerateProjectTask ] def single_threaded(self): diff --git a/test/integration/048_rpc_test/test_rpc.py b/test/integration/048_rpc_test/test_rpc.py index f6f67e2fd43..08d48dc1f24 100644 --- a/test/integration/048_rpc_test/test_rpc.py +++ b/test/integration/048_rpc_test/test_rpc.py @@ -922,6 +922,34 @@ def test_postgres_gc_change_interval(self): result = self.assertIsResult(resp) self.assertEqual(len(result['rows']), 2) + @use_profile('postgres') + def test_docs_generate_postgres(self): + self.run_dbt_with_vars(['seed']) + self.run_dbt_with_vars(['run']) + self.assertFalse(os.path.exists('target/catalog.json')) + result = self.async_query('docs.generate').json() + dct = self.assertIsResult(result) + self.assertTrue(os.path.exists('target/catalog.json')) + self.assertIn('status', dct) + self.assertTrue(dct['status']) + self.assertIn('nodes', dct) + nodes = dct['nodes'] + self.assertEqual(len(nodes), 10) + expected = { + 'model.test.descendant_model', + 'model.test.multi_source_model', + 'model.test.nonsource_descendant', + 'seed.test.expected_multi_source', + 'seed.test.other_source_table', + 'seed.test.other_table', + 'seed.test.source', + 'source.test.other_source.test_table', + 'source.test.test_source.other_test_table', + 'source.test.test_source.test_table', + } + for uid in expected: + self.assertIn(uid, nodes) + class FailedServerProcess(ServerProcess): def _compare_result(self, result):