Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to select using graph-operators when using LoadMode.CUSTOM or LoadMode.DBT_MANIFEST #728

Merged
merged 15 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 197 additions & 18 deletions cosmos/dbt/selector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations
from pathlib import Path
import copy

import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

from cosmos.constants import DbtResourceType
Expand All @@ -16,11 +18,154 @@
# Prefixes that identify the kind of a single select/exclude item.
PATH_SELECTOR = "path:"
TAG_SELECTOR = "tag:"
CONFIG_SELECTOR = "config."

# Character used by dbt graph operators (e.g. "+model", "model+2").
PLUS_SELECTOR = "+"
# Captures three groups: optional "<depth>+" prefix, the node name (no "+" allowed),
# and optional "+<depth>" suffix. The pattern is fully anchored and has no empty
# alternative, so strings that are not graph selectors (e.g. "a+b" or "") do not match.
GRAPH_SELECTOR_REGEX = r"^([0-9]*\+)?([^\+]+)(\+[0-9]*)?$"

# Module-level logger used by the selector classes in this file to report warnings.
logger = get_logger(__name__)


@dataclass
class GraphSelector:
    """
    Implements dbt graph operator selectors:
        model_a
        +model_b
        model_c+
        +model_d+
        2+model_e
        model_f+3

    https://docs.getdbt.com/reference/node-selection/graph-operators

    :param node_name: Name of the node the selector is anchored on
    :param precursors: Raw text captured before the node name ("+", "2+", ...) or None
    :param descendants: Raw text captured after the node name ("+", "+3", ...) or None
    """

    node_name: str
    precursors: str | None
    descendants: str | None

    @property
    def precursors_depth(self) -> int:
        """
        Calculates the depth/degrees/generations of precursors (parents).
        Return:
            -1: if it should return all the generations of precursors
            0: if it shouldn't return any precursors
            >0: upperbound number of parent generations
        """
        if not self.precursors:
            return 0
        if self.precursors == "+":
            return -1
        # The prefix has the form "<number>+": drop the trailing "+".
        return int(self.precursors[:-1])

    @property
    def descendants_depth(self) -> int:
        """
        Calculates the depth/degrees/generations of descendants (children).
        Return:
            -1: if it should return all the generations of children
            0: if it shouldn't return any children
            >0: upperbound of children generations
        """
        if not self.descendants:
            return 0
        if self.descendants == "+":
            return -1
        # The suffix has the form "+<number>": drop the leading "+".
        return int(self.descendants[1:])

    @staticmethod
    def parse(text: str) -> GraphSelector | None:
        """
        Parse a string and identify if there are graph selectors, including the desired node name, descendants and
        precursors.

        :param text: A single select/exclude item, such as "+model_a" or "model_b+2"
        :return: A GraphSelector instance if the pattern matches and a node name was captured, None otherwise
        """
        regex_match = re.search(GRAPH_SELECTOR_REGEX, text)
        if regex_match is None:
            return None
        precursors, node_name, descendants = regex_match.groups()
        if not node_name:
            # A match that captured no node name is not a usable selector; returning None lets the
            # caller report the statement as unsupported instead of building GraphSelector(None, ...).
            return None
        return GraphSelector(node_name, precursors, descendants)

    @staticmethod
    def _add_generations(
        adjacency: dict[str, set[str]], root_id: str, depth: int, selected_nodes: set[str]
    ) -> None:
        """
        Walk the graph generation by generation starting at root_id, adding every reached id to selected_nodes.

        :param adjacency: Maps a node id to the ids of its neighbours in the direction being walked
        :param root_id: Id the traversal starts from (not added by this method)
        :param depth: Number of generations to walk; a negative value walks until no new generation is found
        :param selected_nodes: Set where the reached node ids will be added to.
        """
        previous_generation = {root_id}
        processed_nodes: set[str] = set()
        # A negative depth never reaches zero, so the walk only stops once a generation is empty.
        while depth and previous_generation:
            new_generation: set[str] = set()
            # Nodes already expanded are skipped, which also guarantees termination on cycles.
            for node_id in previous_generation - processed_nodes:
                # Ids without an adjacency entry (e.g. not part of the given nodes) have no neighbours.
                new_generation.update(adjacency.get(node_id, ()))
            processed_nodes.update(previous_generation)
            selected_nodes.update(new_generation)
            previous_generation = new_generation
            depth -= 1

    def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, selected_nodes: set[str]) -> None:
        """
        Parse original nodes and add the precursor nodes related to this config to the selected_nodes set.

        Dependencies whose id is not a key of ``nodes`` are still added to ``selected_nodes`` but are not
        expanded further (previously they raised ``KeyError``).

        :param nodes: Original dbt nodes list
        :param root_id: Unique identifier of self.node_name
        :param selected_nodes: Set where precursor nodes will be added to.
        """
        if not self.precursors:
            return
        parents_by_node = {node_id: set(node.depends_on) for node_id, node in nodes.items()}
        self._add_generations(parents_by_node, root_id, self.precursors_depth, selected_nodes)

    def select_node_descendants(self, nodes: dict[str, DbtNode], root_id: str, selected_nodes: set[str]) -> None:
        """
        Parse original nodes and add the descendant nodes related to this config to the selected_nodes set.

        :param nodes: Original dbt nodes list
        :param root_id: Unique identifier of self.node_name
        :param selected_nodes: Set where descendant nodes will be added to.
        """
        if not self.descendants:
            return
        # Index nodes by parent id
        # We could optimize by doing this only once for the dbt project and giving it
        # as a parameter to the GraphSelector
        children_by_node: dict[str, set[str]] = defaultdict(set)
        for node_id, node in nodes.items():
            for parent_id in node.depends_on:
                children_by_node[parent_id].add(node_id)
        self._add_generations(children_by_node, root_id, self.descendants_depth, selected_nodes)

    def filter_nodes(self, nodes: dict[str, DbtNode]) -> set[str]:
        """
        Given a dictionary with the original dbt project nodes, applies the current graph selector to
        identify the subset of nodes that matches the selection criteria.

        :param nodes: dbt project nodes
        :return: set of node ids that matches current graph selector; empty if no node is named self.node_name
        """
        # Index nodes by name, we can improve performance by doing this once
        # for multiple GraphSelectors
        node_by_name = {node.name: node_id for node_id, node in nodes.items()}

        root_id = node_by_name.get(self.node_name)
        if root_id is None:
            logger.warning("Selector %s not found.", self.node_name)
            return set()

        selected_nodes: set[str] = {root_id}
        self.select_node_precursors(nodes, root_id, selected_nodes)
        self.select_node_descendants(nodes, root_id, selected_nodes)
        return selected_nodes


class SelectorConfig:
"""
Represents a select/exclude statement.
Expand All @@ -43,11 +188,12 @@ def __init__(self, project_dir: Path | None, statement: str):
self.tags: list[str] = []
self.config: dict[str, str] = {}
self.other: list[str] = []
self.graph_selectors: list[GraphSelector] = []
self.load_from_statement(statement)

@property
def is_empty(self) -> bool:
return not (self.paths or self.tags or self.config or self.other)
return not (self.paths or self.tags or self.config or self.graph_selectors or self.other)

def load_from_statement(self, statement: str) -> None:
"""
Expand All @@ -61,6 +207,7 @@ def load_from_statement(self, statement: str) -> None:
https://docs.getdbt.com/reference/node-selection/yaml-selectors
"""
items = statement.split(",")

for item in items:
if item.startswith(PATH_SELECTOR):
index = len(PATH_SELECTOR)
Expand All @@ -77,11 +224,16 @@ def load_from_statement(self, statement: str) -> None:
if key in SUPPORTED_CONFIG:
self.config[key] = value
else:
self.other.append(item)
logger.warning("Unsupported select statement: %s", item)
if item:
graph_selector = GraphSelector.parse(item)
if graph_selector is not None:
self.graph_selectors.append(graph_selector)
else:
self.other.append(item)
logger.warning("Unsupported select statement: %s", item)

def __repr__(self) -> str:
return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other})"
return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other}, graph_selectors={self.graph_selectors})"


class NodeSelector:
Expand All @@ -95,7 +247,9 @@ class NodeSelector:
def __init__(self, nodes: dict[str, DbtNode], config: SelectorConfig) -> None:
self.nodes = nodes
self.config = config
self.selected_nodes: set[str] = set()

@property
def select_nodes_ids_by_intersection(self) -> set[str]:
"""
Return a list of node ids which matches the configuration defined in config.
Expand All @@ -107,14 +261,19 @@ def select_nodes_ids_by_intersection(self) -> set[str]:
if self.config.is_empty:
return set(self.nodes.keys())

self.selected_nodes: set[str] = set()
selected_nodes: set[str] = set()
self.visited_nodes: set[str] = set()

for node_id, node in self.nodes.items():
if self._should_include_node(node_id, node):
self.selected_nodes.add(node_id)
selected_nodes.add(node_id)

if self.config.graph_selectors:
nodes_by_graph_selector = self.select_by_graph_operator()
selected_nodes = selected_nodes.intersection(nodes_by_graph_selector)

return self.selected_nodes
self.selected_nodes = selected_nodes
return selected_nodes

def _should_include_node(self, node_id: str, node: DbtNode) -> bool:
"Checks if a single node should be included. Only runs once per node with caching."
Expand Down Expand Up @@ -175,6 +334,22 @@ def _is_path_matching(self, node: DbtNode) -> bool:
return self._should_include_node(node.depends_on[0], model_node)
return False

def select_by_graph_operator(self) -> set[str]:
    """
    Return the set of node ids that satisfy every graph selector defined in the config.

    Each graph selector is applied to the full set of nodes independently and the
    individual results are then intersected.

    References:
        https://docs.getdbt.com/reference/node-selection/syntax
        https://docs.getdbt.com/reference/node-selection/yaml-selectors
    """
    matches_per_selector: list[set[str]] = [
        selector.filter_nodes(self.nodes) for selector in self.config.graph_selectors
    ]
    # set.intersection needs at least one operand, so this expects a non-empty selector list.
    return set.intersection(*matches_per_selector)


def retrieve_by_label(statement_list: list[str], label: str) -> set[str]:
"""
Expand All @@ -189,7 +364,7 @@ def retrieve_by_label(statement_list: list[str], label: str) -> set[str]:
for statement in statement_list:
config = SelectorConfig(Path(), statement)
item_values = getattr(config, label)
label_values = label_values.union(item_values)
label_values.update(item_values)

return label_values

Expand Down Expand Up @@ -217,20 +392,24 @@ def select_nodes(
filters = [["select", select], ["exclude", exclude]]
for filter_type, filter in filters:
for filter_parameter in filter:
if filter_parameter.startswith(PATH_SELECTOR) or filter_parameter.startswith(TAG_SELECTOR):
if (
filter_parameter.startswith(PATH_SELECTOR)
or filter_parameter.startswith(TAG_SELECTOR)
or PLUS_SELECTOR in filter_parameter
or any([filter_parameter.startswith(CONFIG_SELECTOR + config + ":") for config in SUPPORTED_CONFIG])
):
continue
elif any([filter_parameter.startswith(CONFIG_SELECTOR + config + ":") for config in SUPPORTED_CONFIG]):
continue
else:
elif ":" in filter_parameter:
raise CosmosValueError(f"Invalid {filter_type} filter: {filter_parameter}")

subset_ids: set[str] = set()

for statement in select:
config = SelectorConfig(project_dir, statement)
node_selector = NodeSelector(nodes, config)
select_ids = node_selector.select_nodes_ids_by_intersection()
subset_ids = subset_ids.union(set(select_ids))

select_ids = node_selector.select_nodes_ids_by_intersection
subset_ids.update(set(select_ids))

if select:
nodes = {id_: nodes[id_] for id_ in subset_ids}
Expand All @@ -241,7 +420,7 @@ def select_nodes(
for statement in exclude:
config = SelectorConfig(project_dir, statement)
node_selector = NodeSelector(nodes, config)
exclude_ids = exclude_ids.union(set(node_selector.select_nodes_ids_by_intersection()))
exclude_ids.update(set(node_selector.select_nodes_ids_by_intersection))
subset_ids = set(nodes_ids) - set(exclude_ids)

return {id_: nodes[id_] for id_ in subset_ids}
35 changes: 34 additions & 1 deletion docs/configuration/selecting-excluding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ The ``select`` and ``exclude`` parameters are lists, with values like the follow
- ``tag:my_tag``: include/exclude models with the tag ``my_tag``
- ``config.materialized:table``: include/exclude models with the config ``materialized: table``
- ``path:analytics/tables``: include/exclude models in the ``analytics/tables`` directory

- ``+node_name+1`` (graph operators): include/exclude the node with name ``node_name``, all its parents, and its first generation of children (`dbt graph selector docs <https://docs.getdbt.com/reference/node-selection/graph-operators>`_)
- ``tag:my_tag,+node_name`` (intersection): include/exclude ``node_name`` and its parents if they have the tag ``my_tag`` (`dbt set operator docs <https://docs.getdbt.com/reference/node-selection/set-operators>`_)
- ``['tag:first_tag', 'tag:second_tag']`` (union): include/exclude nodes that have either ``tag:first_tag`` or ``tag:second_tag``

.. note::

Expand Down Expand Up @@ -51,3 +53,34 @@ Examples:
select=["path:analytics/tables"],
)
)


.. code-block:: python

from cosmos import DbtDag, RenderConfig

jaffle_shop = DbtDag(
render_config=RenderConfig(
select=["tag:include_tag1", "tag:include_tag2"], # union
)
)

.. code-block:: python

from cosmos import DbtDag, RenderConfig

jaffle_shop = DbtDag(
render_config=RenderConfig(
select=["tag:include_tag1,tag:include_tag2"], # intersection
)
)

.. code-block:: python

from cosmos import DbtDag, RenderConfig

jaffle_shop = DbtDag(
render_config=RenderConfig(
exclude=["node_name+"], # node_name and its children
)
)
Loading
Loading