From 05c0bcbb95f33e43de6405f51bc9849171ff5ab7 Mon Sep 17 00:00:00 2001 From: Julian Maicher Date: Fri, 6 Sep 2024 08:45:37 +0200 Subject: [PATCH 1/2] Fix task owner fallback `dag.owner` is a computed property that joins owners of existing tasks. We should rely on airflow's existing owner fallback in airflow.models.baseoperator.BaseOperator. Fixes #1194 --- cosmos/core/airflow.py | 7 +++-- tests/airflow/test_graph.py | 51 ++++++++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index acff5d012..9e1d08ac1 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -25,16 +25,15 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None module = importlib.import_module(module_name) Operator = getattr(module, class_name) + task_kwargs = {} if task.owner != "": - task_owner = task.owner - else: - task_owner = dag.owner + task_kwargs["owner"] = task.owner airflow_task = Operator( task_id=task.id, dag=dag, task_group=task_group, - owner=task_owner, + **task_kwargs, **({} if class_name == "EmptyOperator" else {"extra_context": task.extra_context}), **task.arguments, ) diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index ddffa226c..c1f0b2947 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -1,11 +1,13 @@ import os from datetime import datetime from pathlib import Path +from typing import Any from unittest.mock import patch import pytest from airflow import __version__ as airflow_version from airflow.models import DAG +from airflow.models.abstractoperator import DEFAULT_OWNER from airflow.utils.task_group import TaskGroup from packaging import version @@ -130,8 +132,9 @@ def test_build_airflow_graph_with_after_each(): task_seed_parent_seed = dag.tasks[0] task_parent_run = dag.tasks[1] - assert task_seed_parent_seed.owner == "" + assert task_seed_parent_seed.owner == DEFAULT_OWNER assert task_parent_run.owner == "parent_node" + assert {d for d in dag.owner.split(", ")} == {DEFAULT_OWNER, "parent_node"} @pytest.mark.parametrize( @@ -604,3 +607,49 @@ def test_airflow_kwargs_generation(): result = airflow_kwargs(**task_args) assert "dag" in result + + +@pytest.mark.parametrize( + "dbt_extra_config,expected_owner", + [ + ({}, DEFAULT_OWNER), + ({"meta": {}}, DEFAULT_OWNER), + ({"meta": {"owner": ""}}, DEFAULT_OWNER), + ({"meta": {"owner": "dbt-owner"}}, "dbt-owner"), + ], +) +def test_owner(dbt_extra_config, expected_owner): + with DAG("test-task-group-after-each", start_date=datetime(2022, 1, 1)) as dag: + node = DbtNode( + unique_id=f"{DbtResourceType.MODEL.value}.my_folder.my_model", + resource_type=DbtResourceType.MODEL, + file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql", + tags=["has_child"], + config={"materialized": "view", **dbt_extra_config}, + depends_on=[] + ) + + output: TaskGroup = generate_task_or_group( + dag=dag, + task_group=None, + node=node, + execution_mode=ExecutionMode.LOCAL, + test_indirect_selection=TestIndirectSelection.EAGER, + task_args={ + "project_dir": SAMPLE_PROJ_PATH, + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + }, + test_behavior=TestBehavior.AFTER_EACH, + on_warning_callback=None, + source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR, + ) + + assert len(output.leaves) == 1 + assert output.leaves[0].owner == expected_owner From fbd1b5233d9e75181980c776c55152c06cb944ae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 06:45:59 +0000 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/airflow/test_graph.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index c1f0b2947..72a09a5e5 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -1,7 +1,6 @@ import os from datetime import datetime from pathlib import Path -from typing import Any from unittest.mock import patch import pytest @@ -626,7 +625,7 @@ def test_owner(dbt_extra_config, expected_owner): file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql", tags=["has_child"], config={"materialized": "view", **dbt_extra_config}, - depends_on=[] + depends_on=[], ) output: TaskGroup = generate_task_or_group(