Skip to content

Commit

Permalink
Add ExternalDeploymentSensor
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro committed Feb 14, 2024
1 parent 1686086 commit d578ccb
Show file tree
Hide file tree
Showing 10 changed files with 728 additions and 1 deletion.
63 changes: 63 additions & 0 deletions astronomer/providers/core/example_dags/example_astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import time
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.operators.trigger_dagrun import TriggerDagRunOperator

from astronomer.providers.core.sensors.astro import ExternalDeploymentSensor

with DAG(
dag_id="example_astro_task",
start_date=datetime(2022, 1, 1),
schedule=None,
catchup=False,
tags=["example", "async", "core"],
):
ExternalDeploymentSensor(
task_id="test1",
external_dag_id="example_wait_to_test_example_astro_task",
)

ExternalDeploymentSensor(
task_id="test2",
external_dag_id="example_wait_to_test_example_astro_task",
external_task_id="wait_for_2_min",
)

with DAG(
dag_id="wait_to_test_example_astro_dag",
start_date=datetime(2022, 1, 1),
schedule=None,
catchup=False,
tags=["example", "async", "core"],
):

@task
def wait_for_2_min() -> None:
"""Wait for 2 min."""
time.sleep(120)

wait_for_2_min()


with DAG(
dag_id="trigger_astro_test_and_example",
start_date=datetime(2022, 1, 1),
schedule=None,
catchup=False,
tags=["example", "async", "core"],
):
run_wait_dag = TriggerDagRunOperator(
task_id="run_wait_dag",
trigger_dag_id="example_external_task_async_waits_for_me",
wait_for_completion=False,
)

run_astro_dag = TriggerDagRunOperator(
task_id="run_astro_dag",
trigger_dag_id="example_astro_task",
wait_for_completion=False,
)

run_wait_dag >> run_astro_dag
Empty file.
112 changes: 112 additions & 0 deletions astronomer/providers/core/hooks/astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from __future__ import annotations

import os
from typing import Any
from urllib.parse import quote

import requests
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook


class AstroHook(BaseHook):
"""
Custom Apache Airflow Hook for interacting with Astro Cloud API.
:param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
"""

conn_name_attr = "astro_cloud_conn_id"
default_conn_name = "astro_cloud_default"
conn_type = "Astro Cloud"
hook_name = "Astro Cloud"

def __init__(self, astro_cloud_conn_id: str = "astro_cloud_conn_id"):
super().__init__()
self.astro_cloud_conn_id = astro_cloud_conn_id

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""
Returns UI field behavior customization for the Astro Cloud connection.
This method defines hidden fields, relabeling, and placeholders for UI display.
"""
return {
"hidden_fields": ["login", "port", "schema", "extra"],
"relabeling": {
"password": "Astro Cloud API Token",
},
"placeholders": {
"host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x",
"password": "JWT API Token",
},
}

def get_conn(self) -> tuple[str, str]:
"""Retrieves the Astro Cloud connection details."""
conn = BaseHook.get_connection(self.astro_cloud_conn_id)
base_url = conn.host or os.environ.get("AIRFLOW__WEBSERVER__BASE_URL")
if base_url is None:
raise AirflowException(f"Airflow host is missing in connection {self.astro_cloud_conn_id}")
token = conn.password
if token is None:
raise AirflowException(f"Astro API token is missing in connection {self.astro_cloud_conn_id}")
return base_url, token

@property
def _headers(self) -> dict[str, str]:
"""Generates and returns headers for Astro Cloud API requests."""
_, token = self.get_conn()
headers = {"accept": "application/json", "Authorization": f"Bearer {token}"}
return headers

def get_dag_runs(self, external_dag_id: str) -> list[dict[str, str]]:
"""
Retrieves information about running or queued DAG runs.
:param external_dag_id: External ID of the DAG.
"""
base_url, _ = self.get_conn()
path = f"/api/v1/dags/{external_dag_id}/dagRuns"
params: dict[str, int | str | list[str]] = {"limit": 1, "state": ["running", "queued"]}
url = f"{base_url}{path}"
response = requests.get(url, headers=self._headers, params=params)
response.raise_for_status()
data: dict[str, list[dict[str, str]]] = response.json()
return data["dag_runs"]

def get_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
"""
Retrieves information about a specific DAG run.
:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}"
url = f"{base_url}{path}"
response = requests.get(url, headers=self._headers)
response.raise_for_status()
dr: dict[str, Any] = response.json()
return dr

def get_task_instance(
self, external_dag_id: str, dag_run_id: str, external_task_id: str
) -> dict[str, Any] | None:
"""
Retrieves information about a specific task instance within a DAG run.
:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
:param external_task_id: External ID of the task.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"
url = f"{base_url}{path}"
response = requests.get(url, headers=self._headers)
response.raise_for_status()
ti: dict[str, Any] = response.json()
return ti
115 changes: 115 additions & 0 deletions astronomer/providers/core/sensors/astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

import datetime

# import time
from typing import Any, cast

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue

from astronomer.providers.core.hooks.astro import AstroHook
from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger
from astronomer.providers.utils.typing_compat import Context


class ExternalDeploymentSensor(BaseSensorOperator):
"""
Custom Apache Airflow sensor for monitoring external deployments using Astro Cloud.
:param external_dag_id: External ID of the DAG being monitored.
:param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
Defaults to "astro_cloud_default".
:param external_task_id: External ID of the task being monitored. If None, monitors the entire DAG.
:param kwargs: Additional keyword arguments passed to the BaseSensorOperator constructor.
"""

def __init__(
self,
external_dag_id: str,
astro_cloud_conn_id: str = "astro_cloud_default",
external_task_id: str | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
self.astro_cloud_conn_id = astro_cloud_conn_id
self.external_task_id = external_task_id
self.external_dag_id = external_dag_id
self._dag_run_id: str = ""

# def wait_for_dag_start(self, second_to_wait: int = 120, sleep: int = 5) -> None:
# """TODO"""
# hook = AstroHook(self.astro_cloud_conn_id)
# end_time = datetime.datetime.now() + datetime.timedelta(seconds=second_to_wait)
# while end_time >= datetime.datetime.now():
# try:
# dag_runs = hook.get_dag_runs(self.external_dag_id)
# if dag_runs is not None:
# return
# except Exception:
# time.sleep(sleep)

def poke(self, context: Context) -> bool | PokeReturnValue:
"""
Check the status of a DAG/task in another deployment.
Queries Airflow's REST API for the status of the specified DAG or task instance.
Returns True if successful, False otherwise.
:param context: The task execution context.
"""
hook = AstroHook(self.astro_cloud_conn_id)
dag_runs: list[dict[str, Any]] = hook.get_dag_runs(self.external_dag_id)
if dag_runs is None or len(dag_runs) == 0:
self.log.info("No DAG runs found for DAG %s", self.external_dag_id)
return True
self._dag_run_id = cast(str, dag_runs[0]["dag_run_id"])
if self.external_task_id is not None:
task_instance = hook.get_task_instance(
self.external_dag_id, self._dag_run_id, self.external_task_id
)
task_state = task_instance.get("state") if task_instance else None
if task_state == "success":
return True
else:
state = dag_runs[0].get("state")
if state == "success":
return True
return False

def execute(self, context: Context) -> Any:
"""
Executes the sensor.
If the external deployment is not successful, it defers the execution using an AstroDeploymentTrigger.
:param context: The task execution context.
"""
if not self.poke(context):
self.defer(
timeout=datetime.timedelta(seconds=self.timeout),
trigger=AstroDeploymentTrigger(
astro_cloud_conn_id=self.astro_cloud_conn_id,
external_task_id=self.external_task_id,
external_dag_id=self.external_dag_id,
poke_interval=self.poke_interval,
dag_run_id=self._dag_run_id,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, str]) -> None:
"""
Handles the completion event from the deferred execution.
Raises AirflowSkipException if the upstream job failed and `soft_fail` is True.
Otherwise, raises AirflowException.
:param context: The task execution context.
:param event: The event dictionary received from the deferred execution.
"""
if event.get("status") == "failed":
if self.soft_fail:
raise AirflowSkipException("Upstream job failed. Skipping the task.")
else:
raise AirflowException("Upstream job failed.")
78 changes: 78 additions & 0 deletions astronomer/providers/core/triggers/astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.core.hooks.astro import AstroHook


class AstroDeploymentTrigger(BaseTrigger):
"""
Custom Apache Airflow trigger for monitoring the completion status of an external deployment using Astro Cloud.
:param external_dag_id: External ID of the DAG being monitored.
:param dag_run_id: ID of the DAG run being monitored.
:param external_task_id: External ID of the task being monitored. If None, monitors the entire DAG.
:param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials. Defaults to "astro_cloud_default".
:param poke_interval: Time in seconds to wait between consecutive checks for completion status.
:param kwargs: Additional keyword arguments passed to the BaseTrigger constructor.
"""

def __init__(
self,
external_dag_id: str,
dag_run_id: str,
external_task_id: str | None = None,
astro_cloud_conn_id: str = "astro_cloud_default",
poke_interval: float = 5.0,
**kwargs: Any,
):
super().__init__(**kwargs)
self.external_dag_id = external_dag_id
self.dag_run_id = dag_run_id
self.external_task_id = external_task_id
self.astro_cloud_conn_id = astro_cloud_conn_id
self.poke_interval = poke_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize the trigger for storage in the database."""
return (
"astronomer.providers.core.triggers.astro.AstroDeploymentTrigger",
{
"external_dag_id": self.external_dag_id,
"external_task_id": self.external_task_id,
"dag_run_id": self.dag_run_id,
"astro_cloud_conn_id": self.astro_cloud_conn_id,
"poke_interval": self.poke_interval,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Asynchronously runs the trigger and yields completion status events.
Checks the status of the external deployment using Astro Cloud at regular intervals.
Yields TriggerEvent with the status "done" if successful, "failed" if failed.
"""
hook = AstroHook(self.astro_cloud_conn_id)
while True:
if self.external_task_id is not None:
task_instance = hook.get_task_instance(
self.external_dag_id, self.dag_run_id, self.external_task_id
)
state = task_instance.get("state") if task_instance else None
if state in ("success", "skipped"):
yield TriggerEvent({"status": "done"})
elif state in ("failed", "upstream_failed"):
yield TriggerEvent({"status": "failed"})
else:
dag_run = hook.get_dag_run(self.external_dag_id, self.dag_run_id)
state = dag_run.get("state") if dag_run else None
if state == "success":
yield TriggerEvent({"status": "done"})
elif state == "failed":
yield TriggerEvent({"status": "failed"})
self.log.info("Job status is %s sleeping for %s seconds.", state, self.poke_interval)
await asyncio.sleep(self.poke_interval)
7 changes: 6 additions & 1 deletion astronomer/providers/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ def get_provider_info() -> Dict[str, Any]:
"description": "Apache Airflow Providers containing Deferrable Operators & Sensors from Astronomer",
"versions": "1.18.4",
# Optional.
"hook-class-names": [],
"connection-types": [
{
"hook-class-name": "astronomer.providers.core.hooks.astro.AstroHook",
"connection-type": "Astro Cloud",
}
],
"extra-links": [],
}
Empty file added tests/core/hooks/__init__.py
Empty file.
Loading

0 comments on commit d578ccb

Please sign in to comment.