Commit

Merge branch 'main' into snowflake-encrypted-pk-path

tatiana authored Oct 18, 2023
2 parents a90316a + c82e6bc commit 79eac78
Showing 8 changed files with 373 additions and 18 deletions.
17 changes: 12 additions & 5 deletions cosmos/airflow/graph.py
@@ -56,7 +56,7 @@ def create_test_task_metadata(
     execution_mode: ExecutionMode,
     task_args: dict[str, Any],
     on_warning_callback: Callable[..., Any] | None = None,
-    model_name: str | None = None,
+    node: DbtNode | None = None,
 ) -> TaskMetadata:
     """
     Create the metadata that will be used to instantiate the Airflow Task that will be used to run the Dbt test node.
@@ -66,13 +66,18 @@ def create_test_task_metadata(
     :param task_args: Arguments to be used to instantiate an Airflow Task
     :param on_warning_callback: A callback function called on warnings with additional Context variables “test_names”
         and “test_results” of type List.
-    :param model_name: If the test relates to a specific model, the name of the model it relates to
+    :param node: If the test relates to a specific node, the node reference
     :returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
     """
     task_args = dict(task_args)
     task_args["on_warning_callback"] = on_warning_callback
-    if model_name is not None:
-        task_args["models"] = model_name
+    if node is not None:
+        if node.resource_type == DbtResourceType.MODEL:
+            task_args["models"] = node.name
+        elif node.resource_type == DbtResourceType.SOURCE:
+            task_args["select"] = f"source:{node.unique_id[len('source.'):]}"
+        else:  # tested with node.resource_type == DbtResourceType.SEED or DbtResourceType.SNAPSHOT
+            task_args["select"] = node.name
     return TaskMetadata(
         id=test_task_name,
         operator_class=calculate_operator_class(
@@ -112,6 +117,8 @@ def create_task_metadata(
                 task_id = "run"
         else:
             task_id = f"{node.name}_{node.resource_type.value}"
+            if use_task_group is True:
+                task_id = node.resource_type.value
 
         task_metadata = TaskMetadata(
             id=task_id,
@@ -163,7 +170,7 @@ def generate_task_or_group(
             "test",
             execution_mode,
             task_args=task_args,
-            model_name=node.name,
+            node=node,
             on_warning_callback=on_warning_callback,
         )
         test_task = create_airflow_task(test_meta, dag, task_group=model_task_group)
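
In effect, the test task now derives its dbt selector from the node's resource type instead of always passing a model name. A standalone sketch of that branch logic follows — the trimmed-down DbtNode and the example unique_id are illustrative stand-ins, not the real cosmos classes:

from dataclasses import dataclass
from enum import Enum


class DbtResourceType(Enum):
    MODEL = "model"
    SOURCE = "source"
    SEED = "seed"
    SNAPSHOT = "snapshot"


@dataclass
class DbtNode:  # simplified stand-in for cosmos' DbtNode
    name: str
    unique_id: str
    resource_type: DbtResourceType


def selector_args(node: DbtNode) -> dict[str, str]:
    # Mirrors the branch added to create_test_task_metadata above.
    if node.resource_type == DbtResourceType.MODEL:
        return {"models": node.name}
    if node.resource_type == DbtResourceType.SOURCE:
        # strip the "source." prefix: "source.my_project.raw_orders" -> "source:my_project.raw_orders"
        return {"select": f"source:{node.unique_id[len('source.'):]}"}
    return {"select": node.name}  # seeds and snapshots select by name


src = DbtNode("raw_orders", "source.my_project.raw_orders", DbtResourceType.SOURCE)
print(selector_args(src))  # {'select': 'source:my_project.raw_orders'}
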
8 changes: 4 additions & 4 deletions cosmos/constants.py
@@ -70,7 +70,7 @@ def _missing_value_(cls, value): # type: ignore
 
 DEFAULT_DBT_RESOURCES = DbtResourceType.__members__.values()
 
-
-TESTABLE_DBT_RESOURCES = {
-    DbtResourceType.MODEL
-}  # TODO: extend with DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED)
+# dbt test runs tests defined on models, sources, snapshots, and seeds.
+# It expects that you have already created those resources through the appropriate commands.
+# https://docs.getdbt.com/reference/commands/test
+TESTABLE_DBT_RESOURCES = {DbtResourceType.MODEL, DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED}
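
The widened set is what lets the graph-building code attach test tasks to non-model nodes. A quick sanity check, as a hedged sketch — the exact call site that consults the set lives in cosmos/airflow/graph.py and is not shown in this diff:

from cosmos.constants import DbtResourceType, TESTABLE_DBT_RESOURCES

# Previously only MODEL was a member; sources, seeds and snapshots
# with attached tests now qualify for a test task as well.
for resource in (DbtResourceType.SOURCE, DbtResourceType.SEED, DbtResourceType.SNAPSHOT):
    assert resource in TESTABLE_DBT_RESOURCES
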
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
@@ -21,6 +21,7 @@
 from .trino.certificate import TrinoCertificateProfileMapping
 from .trino.jwt import TrinoJWTProfileMapping
 from .trino.ldap import TrinoLDAPProfileMapping
+from .vertica.user_pass import VerticaUserPasswordProfileMapping
 
 profile_mappings: list[Type[BaseProfileMapping]] = [
     AthenaAccessKeyProfileMapping,
@@ -38,6 +39,7 @@
     TrinoLDAPProfileMapping,
     TrinoCertificateProfileMapping,
     TrinoJWTProfileMapping,
+    VerticaUserPasswordProfileMapping,
 ]
 
 
@@ -74,4 +76,5 @@ def get_automatic_profile_mapping(
     "TrinoLDAPProfileMapping",
     "TrinoCertificateProfileMapping",
     "TrinoJWTProfileMapping",
+    "VerticaUserPasswordProfileMapping",
 ]
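
Since the mapping is appended to profile_mappings, automatic profile resolution can now pick it up. A hedged sketch — the connection id is a placeholder, and the full signature of get_automatic_profile_mapping is elided in the hunk above:

from cosmos.profiles import get_automatic_profile_mapping

# Assumes an Airflow connection "my_vertica_conn" of type "vertica" is
# configured; resolution tries each registered mapping in turn and
# returns the first one that claims the connection.
mapping = get_automatic_profile_mapping("my_vertica_conn")
print(type(mapping).__name__)  # expected: VerticaUserPasswordProfileMapping
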
5 changes: 5 additions & 0 deletions cosmos/profiles/vertica/__init__.py
@@ -0,0 +1,5 @@
+"Vertica Airflow connection -> dbt profile mappings"
+
+from .user_pass import VerticaUserPasswordProfileMapping
+
+__all__ = ["VerticaUserPasswordProfileMapping"]
76 changes: 76 additions & 0 deletions cosmos/profiles/vertica/user_pass.py
@@ -0,0 +1,76 @@
+"Maps Airflow Vertica connections using user + password authentication to dbt profiles."
+from __future__ import annotations
+
+from typing import Any
+
+from ..base import BaseProfileMapping
+
+
+class VerticaUserPasswordProfileMapping(BaseProfileMapping):
+    """
+    Maps Airflow Vertica connections using user + password authentication to dbt profiles.
+    https://docs.getdbt.com/reference/warehouse-setups/vertica-setup
+    https://airflow.apache.org/docs/apache-airflow-providers-vertica/stable/connections/vertica.html
+    """
+
+    airflow_connection_type: str = "vertica"
+    dbt_profile_type: str = "vertica"
+
+    required_fields = [
+        "host",
+        "user",
+        "password",
+        "database",
+        "schema",
+    ]
+    secret_fields = [
+        "password",
+    ]
+    airflow_param_mapping = {
+        "host": "host",
+        "user": "login",
+        "password": "password",
+        "port": "port",
+        "schema": "schema",
+        "database": "extra.database",
+        "autocommit": "extra.autocommit",
+        "backup_server_node": "extra.backup_server_node",
+        "binary_transfer": "extra.binary_transfer",
+        "connection_load_balance": "extra.connection_load_balance",
+        "connection_timeout": "extra.connection_timeout",
+        "disable_copy_local": "extra.disable_copy_local",
+        "kerberos_host_name": "extra.kerberos_host_name",
+        "kerberos_service_name": "extra.kerberos_service_name",
+        "log_level": "extra.log_level",
+        "log_path": "extra.log_path",
+        "oauth_access_token": "extra.oauth_access_token",
+        "request_complex_types": "extra.request_complex_types",
+        "session_label": "extra.session_label",
+        "ssl": "extra.ssl",
+        "unicode_error": "extra.unicode_error",
+        "use_prepared_statements": "extra.use_prepared_statements",
+        "workload": "extra.workload",
+    }
+
+    @property
+    def profile(self) -> dict[str, Any | None]:
+        "Gets profile. The password is stored in an environment variable."
+        profile = {
+            "port": 5433,
+            **self.mapped_params,
+            **self.profile_args,
+            # password should always get set as env var
+            "password": self.get_env_var_format("password"),
+        }
+
+        return self.filter_null(profile)
+
+    @property
+    def mock_profile(self) -> dict[str, Any | None]:
+        "Gets mock profile. Defaults port to 5433."
+        parent_mock = super().mock_profile
+
+        return {
+            "port": 5433,
+            **parent_mock,
+        }
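
For orientation, a usage sketch wiring the new mapping into a ProfileConfig, following the same pattern the tests below use for the Postgres mapping; the connection id and schema are placeholders:

from cosmos.config import ProfileConfig
from cosmos.profiles import VerticaUserPasswordProfileMapping

profile_config = ProfileConfig(
    profile_name="default",
    target_name="default",
    profile_mapping=VerticaUserPasswordProfileMapping(
        conn_id="my_vertica_conn",          # placeholder Airflow connection id
        profile_args={"schema": "public"},  # overrides merged over fields mapped from the connection
    ),
)
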
14 changes: 9 additions & 5 deletions pyproject.toml
@@ -55,6 +55,7 @@ dbt-all = [
     "dbt-redshift",
     "dbt-snowflake",
     "dbt-spark",
+    "dbt-vertica",
 ]
 dbt-athena = [
     "dbt-athena-community",
@@ -80,6 +81,9 @@ dbt-snowflake = [
 dbt-spark = [
     "dbt-spark",
 ]
+dbt-vertica = [
+    "dbt-vertica<=1.5.4",
+]
 openlineage = [
     "openlineage-integration-common",
     "openlineage-airflow",
@@ -165,18 +169,18 @@ test = 'pytest -vv --durations=0 . -m "not integration" --ignore=tests/test_exam
 test-cov = """pytest -vv --cov=cosmos --cov-report=term-missing --cov-report=xml --durations=0 -m "not integration" --ignore=tests/test_example_dags.py --ignore=tests/test_example_dags_no_connections.py"""
 # we install using the following workaround to overcome installation conflicts, such as:
 # apache-airflow 2.3.0 and dbt-core [0.13.0 - 1.5.2] and jinja2>=3.0.0 because these package versions have conflicting dependencies
-test-integration-setup = """pip uninstall -y dbt-core dbt-databricks dbt-sqlite dbt-postgres dbt-sqlite; \
+test-integration-setup = """pip uninstall dbt-postgres dbt-databricks dbt-vertica; \
 rm -rf airflow.*; \
 airflow db init; \
-pip install 'dbt-core' 'dbt-databricks' 'dbt-postgres' 'openlineage-airflow'"""
-test-integration = """pytest -vv \
+pip install 'dbt-core' 'dbt-databricks' 'dbt-postgres' 'dbt-vertica' 'openlineage-airflow'"""
+test-integration = """rm -rf dbt/jaffle_shop/dbt_packages;
+pytest -vv \
 --cov=cosmos \
 --cov-report=term-missing \
 --cov-report=xml \
 --durations=0 \
 -m integration \
--k 'not (sqlite or example_cosmos_sources or example_cosmos_python_models or example_virtualenv or cosmos_manifest_example)'
-"""
+-k 'not (sqlite or example_cosmos_sources or example_cosmos_python_models or example_virtualenv or cosmos_manifest_example)'"""
 test-integration-expensive = """pytest -vv \
 --cov=cosmos \
 --cov-report=term-missing \
77 changes: 73 additions & 4 deletions tests/airflow/test_graph.py
@@ -5,6 +5,7 @@
 import pytest
 from airflow import __version__ as airflow_version
 from airflow.models import DAG
+from airflow.utils.task_group import TaskGroup
 from packaging import version
 
 from cosmos.airflow.graph import (
@@ -13,6 +14,7 @@
     calculate_operator_class,
     create_task_metadata,
     create_test_task_metadata,
+    generate_task_or_group,
 )
 from cosmos.config import ProfileConfig
 from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior
@@ -101,6 +103,50 @@ def test_build_airflow_graph_with_after_each():
     assert dag.leaves[0].task_id == "child_run"
 
 
+@pytest.mark.parametrize(
+    "node_type,task_suffix",
+    [(DbtResourceType.MODEL, "run"), (DbtResourceType.SEED, "seed"), (DbtResourceType.SNAPSHOT, "snapshot")],
+)
+def test_create_task_group_for_after_each_supported_nodes(node_type, task_suffix):
+    """
+    dbt test runs tests defined on models, sources, snapshots, and seeds.
+    It expects that you have already created those resources through the appropriate commands.
+    https://docs.getdbt.com/reference/commands/test
+    """
+    with DAG("test-task-group-after-each", start_date=datetime(2022, 1, 1)) as dag:
+        node = DbtNode(
+            name="dbt_node",
+            unique_id="dbt_node",
+            resource_type=node_type,
+            file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql",
+            tags=["has_child"],
+            config={"materialized": "view"},
+            depends_on=[],
+            has_test=True,
+        )
+        output = generate_task_or_group(
+            dag=dag,
+            task_group=None,
+            node=node,
+            execution_mode=ExecutionMode.LOCAL,
+            task_args={
+                "project_dir": SAMPLE_PROJ_PATH,
+                "profile_config": ProfileConfig(
+                    profile_name="default",
+                    target_name="default",
+                    profile_mapping=PostgresUserPasswordProfileMapping(
+                        conn_id="fake_conn",
+                        profile_args={"schema": "public"},
+                    ),
+                ),
+            },
+            test_behavior=TestBehavior.AFTER_EACH,
+            on_warning_callback=None,
+        )
+        assert isinstance(output, TaskGroup)
+        assert list(output.children.keys()) == [f"dbt_node.{task_suffix}", "dbt_node.test"]
+
+
 @pytest.mark.skipif(
     version.parse(airflow_version) < version.parse("2.4"),
     reason="Airflow DAG did not have task_group_dict until the 2.4 release",
@@ -259,7 +305,12 @@ def test_create_task_metadata_seed(caplog, use_task_group):
         args={},
         use_task_group=use_task_group,
     )
-    assert metadata.id == "my_seed_seed"
+
+    if not use_task_group:
+        assert metadata.id == "my_seed_seed"
+    else:
+        assert metadata.id == "seed"
+
     assert metadata.operator_class == "cosmos.operators.docker.DbtSeedDockerOperator"
     assert metadata.arguments == {"models": "my_seed"}
 
@@ -280,14 +331,32 @@ def test_create_task_metadata_snapshot(caplog):
     assert metadata.arguments == {"models": "my_snapshot"}
 
 
-def test_create_test_task_metadata():
+@pytest.mark.parametrize(
+    "node_type,node_unique_id,selector_key,selector_value",
+    [
+        (DbtResourceType.MODEL, "node_name", "models", "node_name"),
+        (DbtResourceType.SEED, "node_name", "select", "node_name"),
+        (DbtResourceType.SOURCE, "source.node_name", "select", "source:node_name"),
+        (DbtResourceType.SNAPSHOT, "node_name", "select", "node_name"),
+    ],
+)
+def test_create_test_task_metadata(node_type, node_unique_id, selector_key, selector_value):
+    sample_node = DbtNode(
+        name="node_name",
+        unique_id=node_unique_id,
+        resource_type=node_type,
+        depends_on=[],
+        file_path="",
+        tags=[],
+        config={},
+    )
     metadata = create_test_task_metadata(
         test_task_name="test_no_nulls",
        execution_mode=ExecutionMode.LOCAL,
         task_args={"task_arg": "value"},
         on_warning_callback=True,
-        model_name="my_model",
+        node=sample_node,
    )
     assert metadata.id == "test_no_nulls"
     assert metadata.operator_class == "cosmos.operators.local.DbtTestLocalOperator"
-    assert metadata.arguments == {"task_arg": "value", "on_warning_callback": True, "models": "my_model"}
+    assert metadata.arguments == {"task_arg": "value", "on_warning_callback": True, selector_key: selector_value}