diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index acff5d012..9e1d08ac1 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -25,16 +25,15 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None module = importlib.import_module(module_name) Operator = getattr(module, class_name) + task_kwargs = {} if task.owner != "": - task_owner = task.owner - else: - task_owner = dag.owner + task_kwargs["owner"] = task.owner airflow_task = Operator( task_id=task.id, dag=dag, task_group=task_group, - owner=task_owner, + **task_kwargs, **({} if class_name == "EmptyOperator" else {"extra_context": task.extra_context}), **task.arguments, ) diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index ddffa226c..72a09a5e5 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -6,6 +6,7 @@ import pytest from airflow import __version__ as airflow_version from airflow.models import DAG +from airflow.models.abstractoperator import DEFAULT_OWNER from airflow.utils.task_group import TaskGroup from packaging import version @@ -130,8 +131,9 @@ def test_build_airflow_graph_with_after_each(): task_seed_parent_seed = dag.tasks[0] task_parent_run = dag.tasks[1] - assert task_seed_parent_seed.owner == "" + assert task_seed_parent_seed.owner == DEFAULT_OWNER assert task_parent_run.owner == "parent_node" + assert {d for d in dag.owner.split(", ")} == {DEFAULT_OWNER, "parent_node"} @pytest.mark.parametrize( @@ -604,3 +606,49 @@ def test_airflow_kwargs_generation(): result = airflow_kwargs(**task_args) assert "dag" in result + + +@pytest.mark.parametrize( + "dbt_extra_config,expected_owner", + [ + ({}, DEFAULT_OWNER), + ({"meta": {}}, DEFAULT_OWNER), + ({"meta": {"owner": ""}}, DEFAULT_OWNER), + ({"meta": {"owner": "dbt-owner"}}, "dbt-owner"), + ], +) +def test_owner(dbt_extra_config, expected_owner): + with DAG("test-task-group-after-each", start_date=datetime(2022, 1, 1)) as dag: + node = DbtNode( + unique_id=f"{DbtResourceType.MODEL.value}.my_folder.my_model", + resource_type=DbtResourceType.MODEL, + file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql", + tags=["has_child"], + config={"materialized": "view", **dbt_extra_config}, + depends_on=[], + ) + + output: TaskGroup = generate_task_or_group( + dag=dag, + task_group=None, + node=node, + execution_mode=ExecutionMode.LOCAL, + test_indirect_selection=TestIndirectSelection.EAGER, + task_args={ + "project_dir": SAMPLE_PROJ_PATH, + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + }, + test_behavior=TestBehavior.AFTER_EACH, + on_warning_callback=None, + source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR, + ) + + assert len(output.leaves) == 1 + assert output.leaves[0].owner == expected_owner