From c7a25ac6c2426a403a02e99bf917d1f50757503e Mon Sep 17 00:00:00 2001 From: Pankaj Date: Tue, 6 Feb 2024 15:03:49 +0530 Subject: [PATCH 1/5] Add ExternalDeploymentSensor --- .../core/example_dags/example_astro.py | 63 ++++++++ astronomer/providers/core/hooks/__init__.py | 0 astronomer/providers/core/hooks/astro.py | 112 ++++++++++++++ astronomer/providers/core/sensors/astro.py | 115 ++++++++++++++ astronomer/providers/core/triggers/astro.py | 78 ++++++++++ astronomer/providers/package.py | 7 +- tests/core/hooks/__init__.py | 0 tests/core/hooks/test_astro.py | 144 ++++++++++++++++++ tests/core/sensors/test_astro.py | 90 +++++++++++ tests/core/triggers/test_astro.py | 120 +++++++++++++++ 10 files changed, 728 insertions(+), 1 deletion(-) create mode 100644 astronomer/providers/core/example_dags/example_astro.py create mode 100644 astronomer/providers/core/hooks/__init__.py create mode 100644 astronomer/providers/core/hooks/astro.py create mode 100644 astronomer/providers/core/sensors/astro.py create mode 100644 astronomer/providers/core/triggers/astro.py create mode 100644 tests/core/hooks/__init__.py create mode 100644 tests/core/hooks/test_astro.py create mode 100644 tests/core/sensors/test_astro.py create mode 100644 tests/core/triggers/test_astro.py diff --git a/astronomer/providers/core/example_dags/example_astro.py b/astronomer/providers/core/example_dags/example_astro.py new file mode 100644 index 000000000..fc4b60b11 --- /dev/null +++ b/astronomer/providers/core/example_dags/example_astro.py @@ -0,0 +1,63 @@ +import time +from datetime import datetime + +from airflow import DAG +from airflow.decorators import task +from airflow.operators.trigger_dagrun import TriggerDagRunOperator + +from astronomer.providers.core.sensors.astro import ExternalDeploymentSensor + +with DAG( + dag_id="example_astro_task", + start_date=datetime(2022, 1, 1), + schedule=None, + catchup=False, + tags=["example", "async", "core"], +): + ExternalDeploymentSensor( + task_id="test1", + 
external_dag_id="example_wait_to_test_example_astro_task", + ) + + ExternalDeploymentSensor( + task_id="test2", + external_dag_id="example_wait_to_test_example_astro_task", + external_task_id="wait_for_2_min", + ) + +with DAG( + dag_id="wait_to_test_example_astro_dag", + start_date=datetime(2022, 1, 1), + schedule=None, + catchup=False, + tags=["example", "async", "core"], +): + + @task + def wait_for_2_min() -> None: + """Wait for 2 min.""" + time.sleep(120) + + wait_for_2_min() + + +with DAG( + dag_id="trigger_astro_test_and_example", + start_date=datetime(2022, 1, 1), + schedule=None, + catchup=False, + tags=["example", "async", "core"], +): + run_wait_dag = TriggerDagRunOperator( + task_id="run_wait_dag", + trigger_dag_id="example_external_task_async_waits_for_me", + wait_for_completion=False, + ) + + run_astro_dag = TriggerDagRunOperator( + task_id="run_astro_dag", + trigger_dag_id="example_astro_task", + wait_for_completion=False, + ) + + run_wait_dag >> run_astro_dag diff --git a/astronomer/providers/core/hooks/__init__.py b/astronomer/providers/core/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/astronomer/providers/core/hooks/astro.py b/astronomer/providers/core/hooks/astro.py new file mode 100644 index 000000000..eaedcc908 --- /dev/null +++ b/astronomer/providers/core/hooks/astro.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import os +from typing import Any +from urllib.parse import quote + +import requests +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class AstroHook(BaseHook): + """ + Custom Apache Airflow Hook for interacting with Astro Cloud API. + + :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials. 
+ """ + + conn_name_attr = "astro_cloud_conn_id" + default_conn_name = "astro_cloud_default" + conn_type = "Astro Cloud" + hook_name = "Astro Cloud" + + def __init__(self, astro_cloud_conn_id: str = "astro_cloud_conn_id"): + super().__init__() + self.astro_cloud_conn_id = astro_cloud_conn_id + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + """ + Returns UI field behavior customization for the Astro Cloud connection. + + This method defines hidden fields, relabeling, and placeholders for UI display. + """ + return { + "hidden_fields": ["login", "port", "schema", "extra"], + "relabeling": { + "password": "Astro Cloud API Token", + }, + "placeholders": { + "host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x", + "password": "JWT API Token", + }, + } + + def get_conn(self) -> tuple[str, str]: + """Retrieves the Astro Cloud connection details.""" + conn = BaseHook.get_connection(self.astro_cloud_conn_id) + base_url = conn.host or os.environ.get("AIRFLOW__WEBSERVER__BASE_URL") + if base_url is None: + raise AirflowException(f"Airflow host is missing in connection {self.astro_cloud_conn_id}") + token = conn.password + if token is None: + raise AirflowException(f"Astro API token is missing in connection {self.astro_cloud_conn_id}") + return base_url, token + + @property + def _headers(self) -> dict[str, str]: + """Generates and returns headers for Astro Cloud API requests.""" + _, token = self.get_conn() + headers = {"accept": "application/json", "Authorization": f"Bearer {token}"} + return headers + + def get_dag_runs(self, external_dag_id: str) -> list[dict[str, str]]: + """ + Retrieves information about running or queued DAG runs. + + :param external_dag_id: External ID of the DAG. 
+ """ + base_url, _ = self.get_conn() + path = f"/api/v1/dags/{external_dag_id}/dagRuns" + params: dict[str, int | str | list[str]] = {"limit": 1, "state": ["running", "queued"]} + url = f"{base_url}{path}" + response = requests.get(url, headers=self._headers, params=params) + response.raise_for_status() + data: dict[str, list[dict[str, str]]] = response.json() + return data["dag_runs"] + + def get_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None: + """ + Retrieves information about a specific DAG run. + + :param external_dag_id: External ID of the DAG. + :param dag_run_id: ID of the DAG run. + """ + base_url, _ = self.get_conn() + dag_run_id = quote(dag_run_id) + path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}" + url = f"{base_url}{path}" + response = requests.get(url, headers=self._headers) + response.raise_for_status() + dr: dict[str, Any] = response.json() + return dr + + def get_task_instance( + self, external_dag_id: str, dag_run_id: str, external_task_id: str + ) -> dict[str, Any] | None: + """ + Retrieves information about a specific task instance within a DAG run. + + :param external_dag_id: External ID of the DAG. + :param dag_run_id: ID of the DAG run. + :param external_task_id: External ID of the task. 
+ """ + base_url, _ = self.get_conn() + dag_run_id = quote(dag_run_id) + path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}" + url = f"{base_url}{path}" + response = requests.get(url, headers=self._headers) + response.raise_for_status() + ti: dict[str, Any] = response.json() + return ti diff --git a/astronomer/providers/core/sensors/astro.py b/astronomer/providers/core/sensors/astro.py new file mode 100644 index 000000000..e863ae693 --- /dev/null +++ b/astronomer/providers/core/sensors/astro.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import datetime + +# import time +from typing import Any, cast + +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.sensors.base import BaseSensorOperator, PokeReturnValue + +from astronomer.providers.core.hooks.astro import AstroHook +from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger +from astronomer.providers.utils.typing_compat import Context + + +class ExternalDeploymentSensor(BaseSensorOperator): + """ + Custom Apache Airflow sensor for monitoring external deployments using Astro Cloud. + + :param external_dag_id: External ID of the DAG being monitored. + :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials. + Defaults to "astro_cloud_default". + :param external_task_id: External ID of the task being monitored. If None, monitors the entire DAG. + :param kwargs: Additional keyword arguments passed to the BaseSensorOperator constructor. 
+ """ + + def __init__( + self, + external_dag_id: str, + astro_cloud_conn_id: str = "astro_cloud_default", + external_task_id: str | None = None, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.astro_cloud_conn_id = astro_cloud_conn_id + self.external_task_id = external_task_id + self.external_dag_id = external_dag_id + self._dag_run_id: str = "" + + # def wait_for_dag_start(self, second_to_wait: int = 120, sleep: int = 5) -> None: + # """TODO""" + # hook = AstroHook(self.astro_cloud_conn_id) + # end_time = datetime.datetime.now() + datetime.timedelta(seconds=second_to_wait) + # while end_time >= datetime.datetime.now(): + # try: + # dag_runs = hook.get_dag_runs(self.external_dag_id) + # if dag_runs is not None: + # return + # except Exception: + # time.sleep(sleep) + + def poke(self, context: Context) -> bool | PokeReturnValue: + """ + Check the status of a DAG/task in another deployment. + + Queries Airflow's REST API for the status of the specified DAG or task instance. + Returns True if successful, False otherwise. + + :param context: The task execution context. + """ + hook = AstroHook(self.astro_cloud_conn_id) + dag_runs: list[dict[str, Any]] = hook.get_dag_runs(self.external_dag_id) + if dag_runs is None or len(dag_runs) == 0: + self.log.info("No DAG runs found for DAG %s", self.external_dag_id) + return True + self._dag_run_id = cast(str, dag_runs[0]["dag_run_id"]) + if self.external_task_id is not None: + task_instance = hook.get_task_instance( + self.external_dag_id, self._dag_run_id, self.external_task_id + ) + task_state = task_instance.get("state") if task_instance else None + if task_state == "success": + return True + else: + state = dag_runs[0].get("state") + if state == "success": + return True + return False + + def execute(self, context: Context) -> Any: + """ + Executes the sensor. + + If the external deployment is not successful, it defers the execution using an AstroDeploymentTrigger. + + :param context: The task execution context. 
+ """ + if not self.poke(context): + self.defer( + timeout=datetime.timedelta(seconds=self.timeout), + trigger=AstroDeploymentTrigger( + astro_cloud_conn_id=self.astro_cloud_conn_id, + external_task_id=self.external_task_id, + external_dag_id=self.external_dag_id, + poke_interval=self.poke_interval, + dag_run_id=self._dag_run_id, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, str]) -> None: + """ + Handles the completion event from the deferred execution. + + Raises AirflowSkipException if the upstream job failed and `soft_fail` is True. + Otherwise, raises AirflowException. + + :param context: The task execution context. + :param event: The event dictionary received from the deferred execution. + """ + if event.get("status") == "failed": + if self.soft_fail: + raise AirflowSkipException("Upstream job failed. Skipping the task.") + else: + raise AirflowException("Upstream job failed.") diff --git a/astronomer/providers/core/triggers/astro.py b/astronomer/providers/core/triggers/astro.py new file mode 100644 index 000000000..11a8693bd --- /dev/null +++ b/astronomer/providers/core/triggers/astro.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import asyncio +from typing import Any, AsyncIterator + +from airflow.triggers.base import BaseTrigger, TriggerEvent + +from astronomer.providers.core.hooks.astro import AstroHook + + +class AstroDeploymentTrigger(BaseTrigger): + """ + Custom Apache Airflow trigger for monitoring the completion status of an external deployment using Astro Cloud. + + :param external_dag_id: External ID of the DAG being monitored. + :param dag_run_id: ID of the DAG run being monitored. + :param external_task_id: External ID of the task being monitored. If None, monitors the entire DAG. + :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials. Defaults to "astro_cloud_default". 
+ :param poke_interval: Time in seconds to wait between consecutive checks for completion status. + :param kwargs: Additional keyword arguments passed to the BaseTrigger constructor. + """ + + def __init__( + self, + external_dag_id: str, + dag_run_id: str, + external_task_id: str | None = None, + astro_cloud_conn_id: str = "astro_cloud_default", + poke_interval: float = 5.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.external_dag_id = external_dag_id + self.dag_run_id = dag_run_id + self.external_task_id = external_task_id + self.astro_cloud_conn_id = astro_cloud_conn_id + self.poke_interval = poke_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for storage in the database.""" + return ( + "astronomer.providers.core.triggers.astro.AstroDeploymentTrigger", + { + "external_dag_id": self.external_dag_id, + "external_task_id": self.external_task_id, + "dag_run_id": self.dag_run_id, + "astro_cloud_conn_id": self.astro_cloud_conn_id, + "poke_interval": self.poke_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """ + Asynchronously runs the trigger and yields completion status events. + + Checks the status of the external deployment using Astro Cloud at regular intervals. + Yields TriggerEvent with the status "done" if successful, "failed" if failed. 
+ """ + hook = AstroHook(self.astro_cloud_conn_id) + while True: + if self.external_task_id is not None: + task_instance = hook.get_task_instance( + self.external_dag_id, self.dag_run_id, self.external_task_id + ) + state = task_instance.get("state") if task_instance else None + if state in ("success", "skipped"): + yield TriggerEvent({"status": "done"}) + elif state in ("failed", "upstream_failed"): + yield TriggerEvent({"status": "failed"}) + else: + dag_run = hook.get_dag_run(self.external_dag_id, self.dag_run_id) + state = dag_run.get("state") if dag_run else None + if state == "success": + yield TriggerEvent({"status": "done"}) + elif state == "failed": + yield TriggerEvent({"status": "failed"}) + self.log.info("Job status is %s sleeping for %s seconds.", state, self.poke_interval) + await asyncio.sleep(self.poke_interval) diff --git a/astronomer/providers/package.py b/astronomer/providers/package.py index de4ead1d6..a5644c911 100644 --- a/astronomer/providers/package.py +++ b/astronomer/providers/package.py @@ -10,6 +10,11 @@ def get_provider_info() -> Dict[str, Any]: "description": "Apache Airflow Providers containing Deferrable Operators & Sensors from Astronomer", "versions": "1.18.4", # Optional. 
- "hook-class-names": [], + "connection-types": [ + { + "hook-class-name": "astronomer.providers.core.hooks.astro.AstroHook", + "connection-type": "Astro Cloud", + } + ], "extra-links": [], } diff --git a/tests/core/hooks/__init__.py b/tests/core/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/core/hooks/test_astro.py b/tests/core/hooks/test_astro.py new file mode 100644 index 000000000..cae9789f5 --- /dev/null +++ b/tests/core/hooks/test_astro.py @@ -0,0 +1,144 @@ +from unittest.mock import MagicMock, patch + +import pytest +from airflow.exceptions import AirflowException + +from astronomer.providers.core.hooks.astro import AstroHook + + +class TestAstroHook: + + def test_get_ui_field_behaviour(self): + hook = AstroHook() + + result = hook.get_ui_field_behaviour() + + expected_result = { + "hidden_fields": ["login", "port", "schema", "extra"], + "relabeling": { + "password": "Astro Cloud API Token", + }, + "placeholders": { + "password": "ey...xz.ey...fq.tw...ap", + }, + } + + assert result == expected_result + + @patch("airflow.hooks.base.BaseHook.get_connection") + @patch("os.environ.get") + def test_get_conn(self, mock_os_get, mock_get_connection): + # Create an instance of your class + hook = AstroHook() + + # Mock the return values for BaseHook.get_connection and os.environ.get + mock_conn = MagicMock() + mock_conn.host = "http://example.com" + mock_conn.password = "your_api_token" + mock_get_connection.return_value = mock_conn + mock_os_get.return_value = "http://example.com" + + result = hook.get_conn() + + expected_result = ("http://example.com", "your_api_token") + + # Assert that the actual result matches the expected result + assert result == expected_result + + mock_get_connection.assert_called_once_with(hook.astro_cloud_conn_id) + + # Reset the mocks + mock_get_connection.reset_mock() + mock_os_get.reset_mock() + + # Test case where conn.host is None + mock_conn.host = None + mock_os_get.return_value = None + + 
with pytest.raises(AirflowException): + hook.get_conn() + + mock_get_connection.assert_called_once_with(hook.astro_cloud_conn_id) + + # Reset the mocks + mock_get_connection.reset_mock() + + # Test case where conn.password is None + mock_conn.host = "http://example.com" + mock_conn.password = None + + with pytest.raises(AirflowException): + hook.get_conn() + + mock_get_connection.assert_called_once_with(hook.astro_cloud_conn_id) + + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_conn") + def test_headers(self, mock_get_conn): + # Create an instance of your class + your_instance = AstroHook() + + # Mock the return value for the get_conn method + mock_get_conn.return_value = ("http://example.com", "your_api_token") + + # Call the property and get the result + result = your_instance._headers + + # Define the expected result based on the method implementation + expected_result = {"accept": "application/json", "Authorization": "Bearer your_api_token"} + + # Assert that the actual result matches the expected result + assert result == expected_result + + # Assert that get_conn was called once + mock_get_conn.assert_called_once() + + @patch("airflow.hooks.base.BaseHook.get_connection") + @patch("astronomer.providers.core.hooks.astro.requests.get") + def test_get_dag_runs(self, mock_requests_get, mock_get_connection): + hook = AstroHook() + + # Mocking the response from requests.get + mock_response = MagicMock() + mock_response.json.return_value = {"dag_runs": [{"dag_run_id": "123", "state": "running"}]} + mock_requests_get.return_value = mock_response + + # Calling the method to be tested + result = hook.get_dag_runs("external_dag_id") + + # Assertions + mock_requests_get.assert_called_once() + assert result == [{"dag_run_id": "123", "state": "running"}] + + @patch("airflow.hooks.base.BaseHook.get_connection") + @patch("astronomer.providers.core.hooks.astro.requests.get") + def test_get_dag_run(self, mock_requests_get, mock_get_connection): + hook = 
AstroHook() + + # Mocking the response from requests.get + mock_response = MagicMock() + mock_response.json.return_value = {"dag_run_id": "123", "state": "running"} + mock_requests_get.return_value = mock_response + + # Calling the method to be tested + result = hook.get_dag_run("external_dag_id", "123") + + # Assertions + mock_requests_get.assert_called_once() + assert result == {"dag_run_id": "123", "state": "running"} + + @patch("airflow.hooks.base.BaseHook.get_connection") + @patch("astronomer.providers.core.hooks.astro.requests.get") + def test_get_task_instance(self, mock_requests_get, mock_get_connection): + hook = AstroHook() + + # Mocking the response from requests.get + mock_response = MagicMock() + mock_response.json.return_value = {"task_instance_id": "456", "state": "success"} + mock_requests_get.return_value = mock_response + + # Calling the method to be tested + result = hook.get_task_instance("external_dag_id", "123", "external_task_id") + + # Assertions + mock_requests_get.assert_called_once() + assert result == {"task_instance_id": "456", "state": "success"} diff --git a/tests/core/sensors/test_astro.py b/tests/core/sensors/test_astro.py new file mode 100644 index 000000000..f7324c310 --- /dev/null +++ b/tests/core/sensors/test_astro.py @@ -0,0 +1,90 @@ +from unittest import mock + +import pytest +from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred + +from astronomer.providers.core.sensors.astro import ExternalDeploymentSensor +from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger + + +class TestExternalDeploymentSensor: + + @pytest.mark.parametrize( + "get_dag_runs_response", + [ + None, + [], + [{"dag_run_id": "run_id", "state": "success"}], + [{"dag_run_id": "run_id", "state": "running"}], + [{"dag_run_id": "run_id", "state": "queued"}], + ], + ) + @mock.patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_runs") + def test_poke_dag(self, mock_get_dag_runs_response, 
get_dag_runs_response, context): + mock_get_dag_runs_response.return_value = get_dag_runs_response + sensor = ExternalDeploymentSensor(task_id="test_me", external_dag_id="test_dag") + response = sensor.poke(context) + if get_dag_runs_response in [None, []]: + assert response is True + elif get_dag_runs_response[0].get("state") == "success": + assert response is True + else: + assert response is False + + @pytest.mark.parametrize("task_state", ["success", "running"]) + @mock.patch("astronomer.providers.core.hooks.astro.AstroHook.get_task_instance") + @mock.patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_runs") + def test_poke_task(self, mock_get_dag_runs_response, mock_get_task_instance, task_state, context): + mock_get_dag_runs_response.return_value = [{"dag_run_id": "run_id", "state": "running"}] + mock_get_task_instance.return_value = {"state": task_state} + sensor = ExternalDeploymentSensor( + task_id="test_me", external_dag_id="test_dag", external_task_id="task_id" + ) + response = sensor.poke(context) + if task_state == "success": + assert response is True + else: + assert response is False + + @pytest.mark.parametrize("poke_response", [True, False]) + @mock.patch("astronomer.providers.core.sensors.astro.ExternalDeploymentSensor.poke") + def test_execute(self, mock_poke, poke_response, context): + mock_poke.return_value = poke_response + + sensor = ExternalDeploymentSensor(task_id="test_me", external_dag_id="test_dag") + if poke_response: + response = sensor.execute(context) + assert response is None + else: + with pytest.raises(TaskDeferred) as exc: + sensor.execute(context) + assert isinstance( + exc.value.trigger, AstroDeploymentTrigger + ), "Trigger is not a AstroDeploymentTrigger" + + @pytest.mark.parametrize( + "event,soft_fail", + [ + ({"status": "done"}, False), + ({"status": "done"}, True), + ({"status": "failed"}, False), + ({"status": "failed"}, True), + ], + ) + def test_execute_complete(self, event, soft_fail, context): + sensor = 
ExternalDeploymentSensor(task_id="test_me", external_dag_id="test_dag", soft_fail=soft_fail) + + if soft_fail: + if event.get("status") == "failed": + with pytest.raises(AirflowSkipException) as exc: + sensor.execute_complete(context, event) + assert str(exc.value) == "Upstream job failed. Skipping the task." + if event.get("status") == "done": + assert sensor.execute_complete(context, event) is None + else: + if event.get("status") == "failed": + with pytest.raises(AirflowException) as exc: + sensor.execute_complete(context, event) + assert str(exc.value) == "Upstream job failed." + if event.get("status") == "done": + assert sensor.execute_complete(context, event) is None diff --git a/tests/core/triggers/test_astro.py b/tests/core/triggers/test_astro.py new file mode 100644 index 000000000..2dddf0260 --- /dev/null +++ b/tests/core/triggers/test_astro.py @@ -0,0 +1,120 @@ +import asyncio +from unittest.mock import patch + +import pytest +from airflow.triggers.base import TriggerEvent + +from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger + + +class TestAstroDeploymentTrigger: + + def test_serialize(self): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + external_task_id="external_task_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + serialized_data = trigger.serialize() + + expected_result = ( + "astronomer.providers.core.triggers.astro.AstroDeploymentTrigger", + { + "external_dag_id": "external_dag_id", + "external_task_id": "external_task_id", + "dag_run_id": "dag_run_id", + "astro_cloud_conn_id": "astro_cloud_conn_id", + "poke_interval": 1.0, + }, + ) + + assert serialized_data == expected_result + + @pytest.mark.asyncio + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_task_instance") + async def test_run_task_successful(self, mock_get_task_instance): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + 
dag_run_id="dag_run_id", + external_task_id="external_task_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + mock_get_task_instance.return_value = {"state": "success"} + + generator = trigger.run() + actual = await generator.asend(None) + assert actual == TriggerEvent({"status": "done"}) + + @pytest.mark.asyncio + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_task_instance") + async def test_run_task_failed(self, mock_get_task_instance): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + external_task_id="external_task_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + mock_get_task_instance.return_value = {"state": "failed"} + + generator = trigger.run() + actual = await generator.asend(None) + assert actual == TriggerEvent({"status": "failed"}) + + @pytest.mark.asyncio + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_run") + async def test_run_dag_successful(self, mock_get_dag_run): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + # Mocking AstroHook responses for a successful DAG run + mock_get_dag_run.return_value = {"state": "success"} + + generator = trigger.run() + actual = await generator.asend(None) + assert actual == TriggerEvent({"status": "done"}) + + @pytest.mark.asyncio + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_run") + async def test_run_dag_failed(self, mock_get_dag_run): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + mock_get_dag_run.return_value = {"state": "failed"} + + generator = trigger.run() + actual = await generator.asend(None) + assert actual == TriggerEvent({"status": "failed"}) + + @pytest.mark.asyncio + 
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_run") + async def test_run_dag_wait(self, mock_get_dag_run): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + mock_get_dag_run.return_value = {"state": "running"} + + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + asyncio.get_event_loop().stop() From c90cb543f6a75ed685f21d4868ffc6ec311e596e Mon Sep 17 00:00:00 2001 From: Pankaj Date: Wed, 14 Feb 2024 22:11:57 +0530 Subject: [PATCH 2/5] Update tests --- astronomer/providers/core/hooks/astro.py | 2 +- tests/core/hooks/test_astro.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/astronomer/providers/core/hooks/astro.py b/astronomer/providers/core/hooks/astro.py index eaedcc908..ea14faa64 100644 --- a/astronomer/providers/core/hooks/astro.py +++ b/astronomer/providers/core/hooks/astro.py @@ -39,7 +39,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: }, "placeholders": { "host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x", - "password": "JWT API Token", + "password": "Astro API JWT Token", }, } diff --git a/tests/core/hooks/test_astro.py b/tests/core/hooks/test_astro.py index cae9789f5..b48b61a06 100644 --- a/tests/core/hooks/test_astro.py +++ b/tests/core/hooks/test_astro.py @@ -19,7 +19,8 @@ def test_get_ui_field_behaviour(self): "password": "Astro Cloud API Token", }, "placeholders": { - "password": "ey...xz.ey...fq.tw...ap", + "host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x", + "password": "Astro API JWT Token", }, } From 25adf93b8066e472a76e313727fb245b9c7453ff Mon Sep 17 00:00:00 2001 From: Pankaj Date: Thu, 15 Feb 2024 00:40:40 +0530 Subject: [PATCH 3/5] Fix bug fetching last dagrun --- astronomer/providers/core/example_dags/example_astro.py | 4 ++--
astronomer/providers/core/hooks/astro.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/astronomer/providers/core/example_dags/example_astro.py b/astronomer/providers/core/example_dags/example_astro.py index fc4b60b11..bbcd50be0 100644 --- a/astronomer/providers/core/example_dags/example_astro.py +++ b/astronomer/providers/core/example_dags/example_astro.py @@ -26,7 +26,7 @@ ) with DAG( - dag_id="wait_to_test_example_astro_dag", + dag_id="example_wait_to_test_example_astro_task", start_date=datetime(2022, 1, 1), schedule=None, catchup=False, @@ -50,7 +50,7 @@ def wait_for_2_min() -> None: ): run_wait_dag = TriggerDagRunOperator( task_id="run_wait_dag", - trigger_dag_id="example_external_task_async_waits_for_me", + trigger_dag_id="example_wait_to_test_example_astro_task", wait_for_completion=False, ) diff --git a/astronomer/providers/core/hooks/astro.py b/astronomer/providers/core/hooks/astro.py index ea14faa64..e7955e8c6 100644 --- a/astronomer/providers/core/hooks/astro.py +++ b/astronomer/providers/core/hooks/astro.py @@ -69,7 +69,11 @@ def get_dag_runs(self, external_dag_id: str) -> list[dict[str, str]]: """ base_url, _ = self.get_conn() path = f"/api/v1/dags/{external_dag_id}/dagRuns" - params: dict[str, int | str | list[str]] = {"limit": 1, "state": ["running", "queued"]} + params: dict[str, int | str | list[str]] = { + "limit": 1, + "state": ["running", "queued"], + "order_by": "-execution_date", + } url = f"{base_url}{path}" response = requests.get(url, headers=self._headers, params=params) response.raise_for_status() From ec5cdface3c4f1bce3073e2715cf40b2752298fb Mon Sep 17 00:00:00 2001 From: Pankaj Date: Thu, 15 Feb 2024 16:31:37 +0530 Subject: [PATCH 4/5] Change http call to async --- astronomer/providers/core/hooks/astro.py | 40 +++++++++ astronomer/providers/core/triggers/astro.py | 4 +- tests/core/hooks/test_astro.py | 99 ++++++++++++++++++++- tests/core/triggers/test_astro.py | 10 +-- 4 files changed, 145 insertions(+), 8 
deletions(-) diff --git a/astronomer/providers/core/hooks/astro.py b/astronomer/providers/core/hooks/astro.py index e7955e8c6..b83b88e13 100644 --- a/astronomer/providers/core/hooks/astro.py +++ b/astronomer/providers/core/hooks/astro.py @@ -5,6 +5,7 @@ from urllib.parse import quote import requests +from aiohttp import ClientSession from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook @@ -96,6 +97,24 @@ def get_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | dr: dict[str, Any] = response.json() return dr + async def get_a_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None: + """ + Retrieves information about a specific DAG run. + + :param external_dag_id: External ID of the DAG. + :param dag_run_id: ID of the DAG run. + """ + base_url, _ = self.get_conn() + dag_run_id = quote(dag_run_id) + path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}" + url = f"{base_url}{path}" + + async with ClientSession(headers=self._headers) as session: + async with session.get(url) as response: + response.raise_for_status() + dr: dict[str, Any] = await response.json() + return dr + def get_task_instance( self, external_dag_id: str, dag_run_id: str, external_task_id: str ) -> dict[str, Any] | None: @@ -114,3 +133,24 @@ def get_task_instance( response.raise_for_status() ti: dict[str, Any] = response.json() return ti + + async def get_a_task_instance( + self, external_dag_id: str, dag_run_id: str, external_task_id: str + ) -> dict[str, Any] | None: + """ + Retrieves information about a specific task instance within a DAG run. + + :param external_dag_id: External ID of the DAG. + :param dag_run_id: ID of the DAG run. + :param external_task_id: External ID of the task. 
+ """ + base_url, _ = self.get_conn() + dag_run_id = quote(dag_run_id) + path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}" + url = f"{base_url}{path}" + + async with ClientSession(headers=self._headers) as session: + async with session.get(url) as response: + response.raise_for_status() + ti: dict[str, Any] = await response.json() + return ti diff --git a/astronomer/providers/core/triggers/astro.py b/astronomer/providers/core/triggers/astro.py index 11a8693bd..2eb1f1396 100644 --- a/astronomer/providers/core/triggers/astro.py +++ b/astronomer/providers/core/triggers/astro.py @@ -59,7 +59,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: hook = AstroHook(self.astro_cloud_conn_id) while True: if self.external_task_id is not None: - task_instance = hook.get_task_instance( + task_instance = await hook.get_a_task_instance( self.external_dag_id, self.dag_run_id, self.external_task_id ) state = task_instance.get("state") if task_instance else None @@ -68,7 +68,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: elif state in ("failed", "upstream_failed"): yield TriggerEvent({"status": "failed"}) else: - dag_run = hook.get_dag_run(self.external_dag_id, self.dag_run_id) + dag_run = await hook.get_a_dag_run(self.external_dag_id, self.dag_run_id) state = dag_run.get("state") if dag_run else None if state == "success": yield TriggerEvent({"status": "done"}) diff --git a/tests/core/hooks/test_astro.py b/tests/core/hooks/test_astro.py index b48b61a06..e9ed9ef90 100644 --- a/tests/core/hooks/test_astro.py +++ b/tests/core/hooks/test_astro.py @@ -1,6 +1,8 @@ -from unittest.mock import MagicMock, patch +from unittest import mock +from unittest.mock import MagicMock, Mock, patch import pytest +from aioresponses import aioresponses from airflow.exceptions import AirflowException from astronomer.providers.core.hooks.astro import AstroHook @@ -143,3 +145,98 @@ def test_get_task_instance(self, mock_requests_get, 
mock_get_connection): # Assertions mock_requests_get.assert_called_once() assert result == {"task_instance_id": "456", "state": "success"} + + @pytest.mark.asyncio + @mock.patch("astronomer.providers.core.hooks.astro.AstroHook._headers") + async def test_get_a_dag_run(self, mock_headers): + external_dag_id = "your_external_dag_id" + dag_run_id = "your_dag_run_id" + url = f"https://test.com/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}" + + # Mocking necessary objects + your_class_instance = AstroHook() + your_class_instance.get_conn = Mock(return_value=("https://test.com", "Test Token")) + mock_headers.return_value = {"accept": "application/json", "Authorization": "Bearer Token"} + response_data = { + "conf": {}, + "dag_id": "my_dag", + "dag_run_id": "manual__2024-02-14T19:06:32.053905+00:00", + "data_interval_end": "2024-02-14T19:06:32.053905+00:00", + "data_interval_start": "2024-02-14T19:06:32.053905+00:00", + "end_date": "2024-02-14T19:16:33.987139+00:00", + "execution_date": "2024-02-14T19:06:32.053905+00:00", + "external_trigger": True, + "last_scheduling_decision": "2024-02-14T19:16:33.985973+00:00", + "logical_date": "2024-02-14T19:06:32.053905+00:00", + "note": None, + "run_type": "manual", + "start_date": "2024-02-14T19:06:33.004299+00:00", + "state": "success", + } + + with aioresponses() as mock_session: + mock_session.get( + url, + headers=your_class_instance._headers, + status=200, + payload=response_data, + ) + + result = await your_class_instance.get_a_dag_run(external_dag_id, dag_run_id) + + assert result == response_data + + @pytest.mark.asyncio + @mock.patch("astronomer.providers.core.hooks.astro.AstroHook._headers") + async def test_get_a_task_instance(self, mock_headers): + external_dag_id = "your_external_dag_id" + dag_run_id = "your_dag_run_id" + external_task_id = "your_external_task_id" + url = f"https://test.com/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}" + + # Mocking necessary objects + 
your_class_instance = AstroHook() + your_class_instance.get_conn = Mock(return_value=("https://test.com", "Test Token")) + mock_headers.return_value = {"accept": "application/json", "Authorization": "Bearer Token"} + response_data = { + "dag_id": "my_dag", + "dag_run_id": "manual__2024-02-14T19:06:32.053905+00:00", + "duration": 600.233105, + "end_date": "2024-02-14T19:16:33.459676+00:00", + "execution_date": "2024-02-14T19:06:32.053905+00:00", + "executor_config": "{}", + "hostname": "d10fc8b0ad27", + "map_index": -1, + "max_tries": 0, + "note": None, + "operator": "_PythonDecoratedOperator", + "pid": 927, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 1, + "queue": "default", + "queued_when": "2024-02-14T19:06:33.036108+00:00", + "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "sla_miss": None, + "start_date": "2024-02-14T19:06:33.226571+00:00", + "state": "success", + "task_id": "my_python_function", + "trigger": None, + "triggerer_job": None, + "try_number": 1, + "unixname": "astro", + } + + with aioresponses() as mock_session: + mock_session.get( + url, + headers=your_class_instance._headers, + status=200, + payload=response_data, + ) + + result = await your_class_instance.get_a_task_instance( + external_dag_id, dag_run_id, external_task_id + ) + + assert result == response_data diff --git a/tests/core/triggers/test_astro.py b/tests/core/triggers/test_astro.py index 2dddf0260..8fc8f962d 100644 --- a/tests/core/triggers/test_astro.py +++ b/tests/core/triggers/test_astro.py @@ -34,7 +34,7 @@ def test_serialize(self): assert serialized_data == expected_result @pytest.mark.asyncio - @patch("astronomer.providers.core.hooks.astro.AstroHook.get_task_instance") + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_task_instance") async def test_run_task_successful(self, mock_get_task_instance): trigger = AstroDeploymentTrigger( external_dag_id="external_dag_id", @@ -51,7 +51,7 @@ async def 
test_run_task_successful(self, mock_get_task_instance): assert actual == TriggerEvent({"status": "done"}) @pytest.mark.asyncio - @patch("astronomer.providers.core.hooks.astro.AstroHook.get_task_instance") + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_task_instance") async def test_run_task_failed(self, mock_get_task_instance): trigger = AstroDeploymentTrigger( external_dag_id="external_dag_id", @@ -68,7 +68,7 @@ async def test_run_task_failed(self, mock_get_task_instance): assert actual == TriggerEvent({"status": "failed"}) @pytest.mark.asyncio - @patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_run") + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_dag_run") async def test_run_dag_successful(self, mock_get_dag_run): trigger = AstroDeploymentTrigger( external_dag_id="external_dag_id", @@ -85,7 +85,7 @@ async def test_run_dag_successful(self, mock_get_dag_run): assert actual == TriggerEvent({"status": "done"}) @pytest.mark.asyncio - @patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_run") + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_dag_run") async def test_run_dag_failed(self, mock_get_dag_run): trigger = AstroDeploymentTrigger( external_dag_id="external_dag_id", @@ -101,7 +101,7 @@ async def test_run_dag_failed(self, mock_get_dag_run): assert actual == TriggerEvent({"status": "failed"}) @pytest.mark.asyncio - @patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_run") + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_dag_run") async def test_run_dag_wait(self, mock_get_dag_run): trigger = AstroDeploymentTrigger( external_dag_id="external_dag_id", From ac900b18885d2c844dc359056a553ce336e873f8 Mon Sep 17 00:00:00 2001 From: Pankaj Date: Thu, 15 Feb 2024 16:50:18 +0530 Subject: [PATCH 5/5] Apply review suggestions --- astronomer/providers/core/sensors/astro.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git 
a/astronomer/providers/core/sensors/astro.py b/astronomer/providers/core/sensors/astro.py index e863ae693..681f348f8 100644 --- a/astronomer/providers/core/sensors/astro.py +++ b/astronomer/providers/core/sensors/astro.py @@ -1,8 +1,6 @@ from __future__ import annotations import datetime - -# import time from typing import Any, cast from airflow.exceptions import AirflowException, AirflowSkipException @@ -37,18 +35,6 @@ def __init__( self.external_dag_id = external_dag_id self._dag_run_id: str = "" - # def wait_for_dag_start(self, second_to_wait: int = 120, sleep: int = 5) -> None: - # """TODO""" - # hook = AstroHook(self.astro_cloud_conn_id) - # end_time = datetime.datetime.now() + datetime.timedelta(seconds=second_to_wait) - # while end_time >= datetime.datetime.now(): - # try: - # dag_runs = hook.get_dag_runs(self.external_dag_id) - # if dag_runs is not None: - # return - # except Exception: - # time.sleep(sleep) - def poke(self, context: Context) -> bool | PokeReturnValue: """ Check the status of a DAG/task in another deployment. @@ -60,7 +46,7 @@ def poke(self, context: Context) -> bool | PokeReturnValue: """ hook = AstroHook(self.astro_cloud_conn_id) dag_runs: list[dict[str, Any]] = hook.get_dag_runs(self.external_dag_id) - if dag_runs is None or len(dag_runs) == 0: + if not dag_runs: self.log.info("No DAG runs found for DAG %s", self.external_dag_id) return True self._dag_run_id = cast(str, dag_runs[0]["dag_run_id"]) @@ -111,5 +97,4 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: if event.get("status") == "failed": if self.soft_fail: raise AirflowSkipException("Upstream job failed. Skipping the task.") - else: - raise AirflowException("Upstream job failed.") + raise AirflowException("Upstream job failed.")