diff --git a/astronomer/providers/core/example_dags/example_astro.py b/astronomer/providers/core/example_dags/example_astro.py new file mode 100644 index 000000000..bbcd50be0 --- /dev/null +++ b/astronomer/providers/core/example_dags/example_astro.py @@ -0,0 +1,63 @@ +import time +from datetime import datetime + +from airflow import DAG +from airflow.decorators import task +from airflow.operators.trigger_dagrun import TriggerDagRunOperator + +from astronomer.providers.core.sensors.astro import ExternalDeploymentSensor + +with DAG( + dag_id="example_astro_task", + start_date=datetime(2022, 1, 1), + schedule=None, + catchup=False, + tags=["example", "async", "core"], +): + ExternalDeploymentSensor( + task_id="test1", + external_dag_id="example_wait_to_test_example_astro_task", + ) + + ExternalDeploymentSensor( + task_id="test2", + external_dag_id="example_wait_to_test_example_astro_task", + external_task_id="wait_for_2_min", + ) + +with DAG( + dag_id="example_wait_to_test_example_astro_task", + start_date=datetime(2022, 1, 1), + schedule=None, + catchup=False, + tags=["example", "async", "core"], +): + + @task + def wait_for_2_min() -> None: + """Wait for 2 min.""" + time.sleep(120) + + wait_for_2_min() + + +with DAG( + dag_id="trigger_astro_test_and_example", + start_date=datetime(2022, 1, 1), + schedule=None, + catchup=False, + tags=["example", "async", "core"], +): + run_wait_dag = TriggerDagRunOperator( + task_id="run_wait_dag", + trigger_dag_id="example_wait_to_test_example_astro_task", + wait_for_completion=False, + ) + + run_astro_dag = TriggerDagRunOperator( + task_id="run_astro_dag", + trigger_dag_id="example_astro_task", + wait_for_completion=False, + ) + + run_wait_dag >> run_astro_dag diff --git a/astronomer/providers/core/hooks/__init__.py b/astronomer/providers/core/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/astronomer/providers/core/hooks/astro.py b/astronomer/providers/core/hooks/astro.py new file mode 100644 index 000000000..b83b88e13 --- /dev/null +++ b/astronomer/providers/core/hooks/astro.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import os +from typing import Any +from urllib.parse import quote + +import requests +from aiohttp import ClientSession +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class AstroHook(BaseHook): + """ + Custom Apache Airflow Hook for interacting with Astro Cloud API. + + :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials. + """ + + conn_name_attr = "astro_cloud_conn_id" + default_conn_name = "astro_cloud_default" + conn_type = "Astro Cloud" + hook_name = "Astro Cloud" + + def __init__(self, astro_cloud_conn_id: str = "astro_cloud_conn_id"): + super().__init__() + self.astro_cloud_conn_id = astro_cloud_conn_id + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + """ + Returns UI field behavior customization for the Astro Cloud connection. + + This method defines hidden fields, relabeling, and placeholders for UI display. + """ + return { + "hidden_fields": ["login", "port", "schema", "extra"], + "relabeling": { + "password": "Astro Cloud API Token", + }, + "placeholders": { + "host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x", + "password": "Astro API JWT Token", + }, + } + + def get_conn(self) -> tuple[str, str]: + """Retrieves the Astro Cloud connection details.""" + conn = BaseHook.get_connection(self.astro_cloud_conn_id) + base_url = conn.host or os.environ.get("AIRFLOW__WEBSERVER__BASE_URL") + if base_url is None: + raise AirflowException(f"Airflow host is missing in connection {self.astro_cloud_conn_id}") + token = conn.password + if token is None: + raise AirflowException(f"Astro API token is missing in connection {self.astro_cloud_conn_id}") + return base_url, token + + @property + def _headers(self) -> dict[str, str]: + """Generates and returns headers for Astro Cloud API requests.""" + _, token = self.get_conn() + headers = {"accept": "application/json", "Authorization": f"Bearer {token}"} + return headers + + def get_dag_runs(self, external_dag_id: str) -> list[dict[str, str]]: + """ + Retrieves information about running or queued DAG runs. + + :param external_dag_id: External ID of the DAG. + """ + base_url, _ = self.get_conn() + path = f"/api/v1/dags/{external_dag_id}/dagRuns" + params: dict[str, int | str | list[str]] = { + "limit": 1, + "state": ["running", "queued"], + "order_by": "-execution_date", + } + url = f"{base_url}{path}" + response = requests.get(url, headers=self._headers, params=params) + response.raise_for_status() + data: dict[str, list[dict[str, str]]] = response.json() + return data["dag_runs"] + + def get_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None: + """ + Retrieves information about a specific DAG run. + + :param external_dag_id: External ID of the DAG. + :param dag_run_id: ID of the DAG run. + """ + base_url, _ = self.get_conn() + dag_run_id = quote(dag_run_id) + path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}" + url = f"{base_url}{path}" + response = requests.get(url, headers=self._headers) + response.raise_for_status() + dr: dict[str, Any] = response.json() + return dr + + async def get_a_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None: + """ + Retrieves information about a specific DAG run. + + :param external_dag_id: External ID of the DAG. + :param dag_run_id: ID of the DAG run. + """ + base_url, _ = self.get_conn() + dag_run_id = quote(dag_run_id) + path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}" + url = f"{base_url}{path}" + + async with ClientSession(headers=self._headers) as session: + async with session.get(url) as response: + response.raise_for_status() + dr: dict[str, Any] = await response.json() + return dr + + def get_task_instance( + self, external_dag_id: str, dag_run_id: str, external_task_id: str + ) -> dict[str, Any] | None: + """ + Retrieves information about a specific task instance within a DAG run. + + :param external_dag_id: External ID of the DAG. + :param dag_run_id: ID of the DAG run. + :param external_task_id: External ID of the task. + """ + base_url, _ = self.get_conn() + dag_run_id = quote(dag_run_id) + path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}" + url = f"{base_url}{path}" + response = requests.get(url, headers=self._headers) + response.raise_for_status() + ti: dict[str, Any] = response.json() + return ti + + async def get_a_task_instance( + self, external_dag_id: str, dag_run_id: str, external_task_id: str + ) -> dict[str, Any] | None: + """ + Retrieves information about a specific task instance within a DAG run. + + :param external_dag_id: External ID of the DAG. + :param dag_run_id: ID of the DAG run. + :param external_task_id: External ID of the task. + """ + base_url, _ = self.get_conn() + dag_run_id = quote(dag_run_id) + path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}" + url = f"{base_url}{path}" + + async with ClientSession(headers=self._headers) as session: + async with session.get(url) as response: + response.raise_for_status() + ti: dict[str, Any] = await response.json() + return ti diff --git a/astronomer/providers/core/sensors/astro.py b/astronomer/providers/core/sensors/astro.py new file mode 100644 index 000000000..681f348f8 --- /dev/null +++ b/astronomer/providers/core/sensors/astro.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import datetime +from typing import Any, cast + +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.sensors.base import BaseSensorOperator, PokeReturnValue + +from astronomer.providers.core.hooks.astro import AstroHook +from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger +from astronomer.providers.utils.typing_compat import Context + + +class ExternalDeploymentSensor(BaseSensorOperator): + """ + Custom Apache Airflow sensor for monitoring external deployments using Astro Cloud. + + :param external_dag_id: External ID of the DAG being monitored. + :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials. + Defaults to "astro_cloud_default". + :param external_task_id: External ID of the task being monitored. If None, monitors the entire DAG. + :param kwargs: Additional keyword arguments passed to the BaseSensorOperator constructor. + """ + + def __init__( + self, + external_dag_id: str, + astro_cloud_conn_id: str = "astro_cloud_default", + external_task_id: str | None = None, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.astro_cloud_conn_id = astro_cloud_conn_id + self.external_task_id = external_task_id + self.external_dag_id = external_dag_id + self._dag_run_id: str = "" + + def poke(self, context: Context) -> bool | PokeReturnValue: + """ + Check the status of a DAG/task in another deployment. + + Queries Airflow's REST API for the status of the specified DAG or task instance. + Returns True if successful, False otherwise. + + :param context: The task execution context. + """ + hook = AstroHook(self.astro_cloud_conn_id) + dag_runs: list[dict[str, Any]] = hook.get_dag_runs(self.external_dag_id) + if not dag_runs: + self.log.info("No DAG runs found for DAG %s", self.external_dag_id) + return True + self._dag_run_id = cast(str, dag_runs[0]["dag_run_id"]) + if self.external_task_id is not None: + task_instance = hook.get_task_instance( + self.external_dag_id, self._dag_run_id, self.external_task_id + ) + task_state = task_instance.get("state") if task_instance else None + if task_state == "success": + return True + else: + state = dag_runs[0].get("state") + if state == "success": + return True + return False + + def execute(self, context: Context) -> Any: + """ + Executes the sensor. + + If the external deployment is not successful, it defers the execution using an AstroDeploymentTrigger. + + :param context: The task execution context. + """ + if not self.poke(context): + self.defer( + timeout=datetime.timedelta(seconds=self.timeout), + trigger=AstroDeploymentTrigger( + astro_cloud_conn_id=self.astro_cloud_conn_id, + external_task_id=self.external_task_id, + external_dag_id=self.external_dag_id, + poke_interval=self.poke_interval, + dag_run_id=self._dag_run_id, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, str]) -> None: + """ + Handles the completion event from the deferred execution. + + Raises AirflowSkipException if the upstream job failed and `soft_fail` is True. + Otherwise, raises AirflowException. + + :param context: The task execution context. + :param event: The event dictionary received from the deferred execution. + """ + if event.get("status") == "failed": + if self.soft_fail: + raise AirflowSkipException("Upstream job failed. Skipping the task.") + raise AirflowException("Upstream job failed.") diff --git a/astronomer/providers/core/triggers/astro.py b/astronomer/providers/core/triggers/astro.py new file mode 100644 index 000000000..2eb1f1396 --- /dev/null +++ b/astronomer/providers/core/triggers/astro.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import asyncio +from typing import Any, AsyncIterator + +from airflow.triggers.base import BaseTrigger, TriggerEvent + +from astronomer.providers.core.hooks.astro import AstroHook + + +class AstroDeploymentTrigger(BaseTrigger): + """ + Custom Apache Airflow trigger for monitoring the completion status of an external deployment using Astro Cloud. + + :param external_dag_id: External ID of the DAG being monitored. + :param dag_run_id: ID of the DAG run being monitored. + :param external_task_id: External ID of the task being monitored. If None, monitors the entire DAG. + :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials. Defaults to "astro_cloud_default". + :param poke_interval: Time in seconds to wait between consecutive checks for completion status. + :param kwargs: Additional keyword arguments passed to the BaseTrigger constructor. + """ + + def __init__( + self, + external_dag_id: str, + dag_run_id: str, + external_task_id: str | None = None, + astro_cloud_conn_id: str = "astro_cloud_default", + poke_interval: float = 5.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.external_dag_id = external_dag_id + self.dag_run_id = dag_run_id + self.external_task_id = external_task_id + self.astro_cloud_conn_id = astro_cloud_conn_id + self.poke_interval = poke_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for storage in the database.""" + return ( + "astronomer.providers.core.triggers.astro.AstroDeploymentTrigger", + { + "external_dag_id": self.external_dag_id, + "external_task_id": self.external_task_id, + "dag_run_id": self.dag_run_id, + "astro_cloud_conn_id": self.astro_cloud_conn_id, + "poke_interval": self.poke_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """ + Asynchronously runs the trigger and yields completion status events. + + Checks the status of the external deployment using Astro Cloud at regular intervals. + Yields TriggerEvent with the status "done" if successful, "failed" if failed. + """ + hook = AstroHook(self.astro_cloud_conn_id) + while True: + if self.external_task_id is not None: + task_instance = await hook.get_a_task_instance( + self.external_dag_id, self.dag_run_id, self.external_task_id + ) + state = task_instance.get("state") if task_instance else None + if state in ("success", "skipped"): + yield TriggerEvent({"status": "done"}) + elif state in ("failed", "upstream_failed"): + yield TriggerEvent({"status": "failed"}) + else: + dag_run = await hook.get_a_dag_run(self.external_dag_id, self.dag_run_id) + state = dag_run.get("state") if dag_run else None + if state == "success": + yield TriggerEvent({"status": "done"}) + elif state == "failed": + yield TriggerEvent({"status": "failed"}) + self.log.info("Job status is %s sleeping for %s seconds.", state, self.poke_interval) + await asyncio.sleep(self.poke_interval) diff --git a/astronomer/providers/package.py b/astronomer/providers/package.py index de4ead1d6..a5644c911 100644 --- a/astronomer/providers/package.py +++ b/astronomer/providers/package.py @@ -10,6 +10,11 @@ def get_provider_info() -> Dict[str, Any]: "description": "Apache Airflow Providers containing Deferrable Operators & Sensors from Astronomer", "versions": "1.18.4", # Optional. - "hook-class-names": [], + "connection-types": [ + { + "hook-class-name": "astronomer.providers.core.hooks.astro.AstroHook", + "connection-type": "Astro Cloud", + } + ], "extra-links": [], } diff --git a/tests/core/hooks/__init__.py b/tests/core/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/core/hooks/test_astro.py b/tests/core/hooks/test_astro.py new file mode 100644 index 000000000..e9ed9ef90 --- /dev/null +++ b/tests/core/hooks/test_astro.py @@ -0,0 +1,242 @@ +from unittest import mock +from unittest.mock import MagicMock, Mock, patch + +import pytest +from aioresponses import aioresponses +from airflow.exceptions import AirflowException + +from astronomer.providers.core.hooks.astro import AstroHook + + +class TestAstroHook: + + def test_get_ui_field_behaviour(self): + hook = AstroHook() + + result = hook.get_ui_field_behaviour() + + expected_result = { + "hidden_fields": ["login", "port", "schema", "extra"], + "relabeling": { + "password": "Astro Cloud API Token", + }, + "placeholders": { + "host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x", + "password": "Astro API JWT Token", + }, + } + + assert result == expected_result + + @patch("airflow.hooks.base.BaseHook.get_connection") + @patch("os.environ.get") + def test_get_conn(self, mock_os_get, mock_get_connection): + # Create an instance of your class + hook = AstroHook() + + # Mock the return values for BaseHook.get_connection and os.environ.get + mock_conn = MagicMock() + mock_conn.host = "http://example.com" + mock_conn.password = "your_api_token" + mock_get_connection.return_value = mock_conn + mock_os_get.return_value = "http://example.com" + + result = hook.get_conn() + + expected_result = ("http://example.com", "your_api_token") + + # Assert that the actual result matches the expected result + assert result == expected_result + + mock_get_connection.assert_called_once_with(hook.astro_cloud_conn_id) + + # Reset the mocks + mock_get_connection.reset_mock() + mock_os_get.reset_mock() + + # Test case where conn.host is None + mock_conn.host = None + mock_os_get.return_value = None + + with pytest.raises(AirflowException): + hook.get_conn() + + mock_get_connection.assert_called_once_with(hook.astro_cloud_conn_id) + + # Reset the mocks + mock_get_connection.reset_mock() + + # Test case where conn.password is None + mock_conn.host = "http://example.com" + mock_conn.password = None + + with pytest.raises(AirflowException): + hook.get_conn() + + mock_get_connection.assert_called_once_with(hook.astro_cloud_conn_id) + + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_conn") + def test_headers(self, mock_get_conn): + # Create an instance of your class + your_instance = AstroHook() + + # Mock the return value for the get_conn method + mock_get_conn.return_value = ("http://example.com", "your_api_token") + + # Call the property and get the result + result = your_instance._headers + + # Define the expected result based on the method implementation + expected_result = {"accept": "application/json", "Authorization": "Bearer your_api_token"} + + # Assert that the actual result matches the expected result + assert result == expected_result + + # Assert that get_conn was called once + mock_get_conn.assert_called_once() + + @patch("airflow.hooks.base.BaseHook.get_connection") + @patch("astronomer.providers.core.hooks.astro.requests.get") + def test_get_dag_runs(self, mock_requests_get, mock_get_connection): + hook = AstroHook() + + # Mocking the response from requests.get + mock_response = MagicMock() + mock_response.json.return_value = {"dag_runs": [{"dag_run_id": "123", "state": "running"}]} + mock_requests_get.return_value = mock_response + + # Calling the method to be tested + result = hook.get_dag_runs("external_dag_id") + + # Assertions + mock_requests_get.assert_called_once() + assert result == [{"dag_run_id": "123", "state": "running"}] + + @patch("airflow.hooks.base.BaseHook.get_connection") + @patch("astronomer.providers.core.hooks.astro.requests.get") + def test_get_dag_run(self, mock_requests_get, mock_get_connection): + hook = AstroHook() + + # Mocking the response from requests.get + mock_response = MagicMock() + mock_response.json.return_value = {"dag_run_id": "123", "state": "running"} + mock_requests_get.return_value = mock_response + + # Calling the method to be tested + result = hook.get_dag_run("external_dag_id", "123") + + # Assertions + mock_requests_get.assert_called_once() + assert result == {"dag_run_id": "123", "state": "running"} + + @patch("airflow.hooks.base.BaseHook.get_connection") + @patch("astronomer.providers.core.hooks.astro.requests.get") + def test_get_task_instance(self, mock_requests_get, mock_get_connection): + hook = AstroHook() + + # Mocking the response from requests.get + mock_response = MagicMock() + mock_response.json.return_value = {"task_instance_id": "456", "state": "success"} + mock_requests_get.return_value = mock_response + + # Calling the method to be tested + result = hook.get_task_instance("external_dag_id", "123", "external_task_id") + + # Assertions + mock_requests_get.assert_called_once() + assert result == {"task_instance_id": "456", "state": "success"} + + @pytest.mark.asyncio + @mock.patch("astronomer.providers.core.hooks.astro.AstroHook._headers") + async def test_get_a_dag_run(self, mock_headers): + external_dag_id = "your_external_dag_id" + dag_run_id = "your_dag_run_id" + url = f"https://test.com/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}" + + # Mocking necessary objects + your_class_instance = AstroHook() + your_class_instance.get_conn = Mock(return_value=("https://test.com", "Test Token")) + mock_headers.return_value = {"accept": "application/json", "Authorization": "Bearer Token"} + response_data = { + "conf": {}, + "dag_id": "my_dag", + "dag_run_id": "manual__2024-02-14T19:06:32.053905+00:00", + "data_interval_end": "2024-02-14T19:06:32.053905+00:00", + "data_interval_start": "2024-02-14T19:06:32.053905+00:00", + "end_date": "2024-02-14T19:16:33.987139+00:00", + "execution_date": "2024-02-14T19:06:32.053905+00:00", + "external_trigger": True, + "last_scheduling_decision": "2024-02-14T19:16:33.985973+00:00", + "logical_date": "2024-02-14T19:06:32.053905+00:00", + "note": None, + "run_type": "manual", + "start_date": "2024-02-14T19:06:33.004299+00:00", + "state": "success", + } + + with aioresponses() as mock_session: + mock_session.get( + url, + headers=your_class_instance._headers, + status=200, + payload=response_data, + ) + + result = await your_class_instance.get_a_dag_run(external_dag_id, dag_run_id) + + assert result == response_data + + @pytest.mark.asyncio + @mock.patch("astronomer.providers.core.hooks.astro.AstroHook._headers") + async def test_get_a_task_instance(self, mock_headers): + external_dag_id = "your_external_dag_id" + dag_run_id = "your_dag_run_id" + external_task_id = "your_external_task_id" + url = f"https://test.com/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}" + + # Mocking necessary objects + your_class_instance = AstroHook() + your_class_instance.get_conn = Mock(return_value=("https://test.com", "Test Token")) + mock_headers.return_value = {"accept": "application/json", "Authorization": "Bearer Token"} + response_data = { + "dag_id": "my_dag", + "dag_run_id": "manual__2024-02-14T19:06:32.053905+00:00", + "duration": 600.233105, + "end_date": "2024-02-14T19:16:33.459676+00:00", + "execution_date": "2024-02-14T19:06:32.053905+00:00", + "executor_config": "{}", + "hostname": "d10fc8b0ad27", + "map_index": -1, + "max_tries": 0, + "note": None, + "operator": "_PythonDecoratedOperator", + "pid": 927, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 1, + "queue": "default", + "queued_when": "2024-02-14T19:06:33.036108+00:00", + "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "sla_miss": None, + "start_date": "2024-02-14T19:06:33.226571+00:00", + "state": "success", + "task_id": "my_python_function", + "trigger": None, + "triggerer_job": None, + "try_number": 1, + "unixname": "astro", + } + + with aioresponses() as mock_session: + mock_session.get( + url, + headers=your_class_instance._headers, + status=200, + payload=response_data, + ) + + result = await your_class_instance.get_a_task_instance( + external_dag_id, dag_run_id, external_task_id + ) + + assert result == response_data diff --git a/tests/core/sensors/test_astro.py b/tests/core/sensors/test_astro.py new file mode 100644 index 000000000..f7324c310 --- /dev/null +++ b/tests/core/sensors/test_astro.py @@ -0,0 +1,90 @@ +from unittest import mock + +import pytest +from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred + +from astronomer.providers.core.sensors.astro import ExternalDeploymentSensor +from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger + + +class TestExternalDeploymentSensor: + + @pytest.mark.parametrize( + "get_dag_runs_response", + [ + None, + [], + [{"dag_run_id": "run_id", "state": "success"}], + [{"dag_run_id": "run_id", "state": "running"}], + [{"dag_run_id": "run_id", "state": "queued"}], + ], + ) + @mock.patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_runs") + def test_poke_dag(self, mock_get_dag_runs_response, get_dag_runs_response, context): + mock_get_dag_runs_response.return_value = get_dag_runs_response + sensor = ExternalDeploymentSensor(task_id="test_me", external_dag_id="test_dag") + response = sensor.poke(context) + if get_dag_runs_response in [None, []]: + assert response is True + elif get_dag_runs_response[0].get("state") == "success": + assert response is True + else: + assert response is False + + @pytest.mark.parametrize("task_state", ["success", "running"]) + @mock.patch("astronomer.providers.core.hooks.astro.AstroHook.get_task_instance") + @mock.patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_runs") + def test_poke_task(self, mock_get_dag_runs_response, mock_get_task_instance, task_state, context): + mock_get_dag_runs_response.return_value = [{"dag_run_id": "run_id", "state": "running"}] + mock_get_task_instance.return_value = {"state": task_state} + sensor = ExternalDeploymentSensor( + task_id="test_me", external_dag_id="test_dag", external_task_id="task_id" + ) + response = sensor.poke(context) + if task_state == "success": + assert response is True + else: + assert response is False + + @pytest.mark.parametrize("poke_response", [True, False]) + @mock.patch("astronomer.providers.core.sensors.astro.ExternalDeploymentSensor.poke") + def test_execute(self, mock_poke, poke_response, context): + mock_poke.return_value = poke_response + + sensor = ExternalDeploymentSensor(task_id="test_me", external_dag_id="test_dag") + if poke_response: + response = sensor.execute(context) + assert response is None + else: + with pytest.raises(TaskDeferred) as exc: + sensor.execute(context) + assert isinstance( + exc.value.trigger, AstroDeploymentTrigger + ), "Trigger is not a AstroDeploymentTrigger" + + @pytest.mark.parametrize( + "event,soft_fail", + [ + ({"status": "done"}, False), + ({"status": "done"}, True), + ({"status": "failed"}, False), + ({"status": "failed"}, True), + ], + ) + def test_execute_complete(self, event, soft_fail, context): + sensor = ExternalDeploymentSensor(task_id="test_me", external_dag_id="test_dag", soft_fail=soft_fail) + + if soft_fail: + if event.get("status") == "failed": + with pytest.raises(AirflowSkipException) as exc: + sensor.execute_complete(context, event) + assert str(exc.value) == "Upstream job failed. Skipping the task." + if event.get("status") == "done": + assert sensor.execute_complete(context, event) is None + else: + if event.get("status") == "failed": + with pytest.raises(AirflowException) as exc: + sensor.execute_complete(context, event) + assert str(exc.value) == "Upstream job failed." + if event.get("status") == "done": + assert sensor.execute_complete(context, event) is None diff --git a/tests/core/triggers/test_astro.py b/tests/core/triggers/test_astro.py new file mode 100644 index 000000000..8fc8f962d --- /dev/null +++ b/tests/core/triggers/test_astro.py @@ -0,0 +1,120 @@ +import asyncio +from unittest.mock import patch + +import pytest +from airflow.triggers.base import TriggerEvent + +from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger + + +class TestAstroDeploymentTrigger: + + def test_serialize(self): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + external_task_id="external_task_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + serialized_data = trigger.serialize() + + expected_result = ( + "astronomer.providers.core.triggers.astro.AstroDeploymentTrigger", + { + "external_dag_id": "external_dag_id", + "external_task_id": "external_task_id", + "dag_run_id": "dag_run_id", + "astro_cloud_conn_id": "astro_cloud_conn_id", + "poke_interval": 1.0, + }, + ) + + assert serialized_data == expected_result + + @pytest.mark.asyncio + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_task_instance") + async def test_run_task_successful(self, mock_get_task_instance): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + external_task_id="external_task_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + mock_get_task_instance.return_value = {"state": "success"} + + generator = trigger.run() + actual = await generator.asend(None) + assert actual == TriggerEvent({"status": "done"}) + + @pytest.mark.asyncio + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_task_instance") + async def test_run_task_failed(self, mock_get_task_instance): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + external_task_id="external_task_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + mock_get_task_instance.return_value = {"state": "failed"} + + generator = trigger.run() + actual = await generator.asend(None) + assert actual == TriggerEvent({"status": "failed"}) + + @pytest.mark.asyncio + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_dag_run") + async def test_run_dag_successful(self, mock_get_dag_run): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + # Mocking AstroHook responses for a successful DAG run + mock_get_dag_run.return_value = {"state": "success"} + + generator = trigger.run() + actual = await generator.asend(None) + assert actual == TriggerEvent({"status": "done"}) + + @pytest.mark.asyncio + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_dag_run") + async def test_run_dag_failed(self, mock_get_dag_run): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + mock_get_dag_run.return_value = {"state": "failed"} + + generator = trigger.run() + actual = await generator.asend(None) + assert actual == TriggerEvent({"status": "failed"}) + + @pytest.mark.asyncio + @patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_dag_run") + async def test_run_dag_wait(self, mock_get_dag_run): + trigger = AstroDeploymentTrigger( + external_dag_id="external_dag_id", + dag_run_id="dag_run_id", + astro_cloud_conn_id="astro_cloud_conn_id", + poke_interval=1.0, + ) + + mock_get_dag_run.return_value = {"state": "running"} + + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + asyncio.get_event_loop().stop()