Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ExternalDeploymentSensor #1472

Merged
merged 5 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions astronomer/providers/core/example_dags/example_astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import time
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.operators.trigger_dagrun import TriggerDagRunOperator

from astronomer.providers.core.sensors.astro import ExternalDeploymentSensor

with DAG(
dag_id="example_astro_task",
start_date=datetime(2022, 1, 1),
schedule=None,
catchup=False,
tags=["example", "async", "core"],
):
ExternalDeploymentSensor(
task_id="test1",
external_dag_id="example_wait_to_test_example_astro_task",
)

ExternalDeploymentSensor(
task_id="test2",
external_dag_id="example_wait_to_test_example_astro_task",
external_task_id="wait_for_2_min",
)

with DAG(
dag_id="example_wait_to_test_example_astro_task",
start_date=datetime(2022, 1, 1),
schedule=None,
catchup=False,
tags=["example", "async", "core"],
):

@task
def wait_for_2_min() -> None:
"""Wait for 2 min."""
time.sleep(120)

wait_for_2_min()


with DAG(
dag_id="trigger_astro_test_and_example",
start_date=datetime(2022, 1, 1),
schedule=None,
catchup=False,
tags=["example", "async", "core"],
):
run_wait_dag = TriggerDagRunOperator(
task_id="run_wait_dag",
pankajastro marked this conversation as resolved.
Show resolved Hide resolved
trigger_dag_id="example_wait_to_test_example_astro_task",
wait_for_completion=False,
)

run_astro_dag = TriggerDagRunOperator(
task_id="run_astro_dag",
trigger_dag_id="example_astro_task",
wait_for_completion=False,
)

run_wait_dag >> run_astro_dag
Empty file.
156 changes: 156 additions & 0 deletions astronomer/providers/core/hooks/astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from __future__ import annotations

import os
from typing import Any
from urllib.parse import quote

import requests
from aiohttp import ClientSession
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook


class AstroHook(BaseHook):
"""
Custom Apache Airflow Hook for interacting with Astro Cloud API.

:param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
"""

conn_name_attr = "astro_cloud_conn_id"
default_conn_name = "astro_cloud_default"
conn_type = "Astro Cloud"
hook_name = "Astro Cloud"

def __init__(self, astro_cloud_conn_id: str = "astro_cloud_conn_id"):
super().__init__()
self.astro_cloud_conn_id = astro_cloud_conn_id

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""
Returns UI field behavior customization for the Astro Cloud connection.

This method defines hidden fields, relabeling, and placeholders for UI display.
"""
return {
"hidden_fields": ["login", "port", "schema", "extra"],
"relabeling": {
"password": "Astro Cloud API Token",
},
"placeholders": {
"host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x",
"password": "Astro API JWT Token",
},
}

def get_conn(self) -> tuple[str, str]:
"""Retrieves the Astro Cloud connection details."""
conn = BaseHook.get_connection(self.astro_cloud_conn_id)
base_url = conn.host or os.environ.get("AIRFLOW__WEBSERVER__BASE_URL")
if base_url is None:
raise AirflowException(f"Airflow host is missing in connection {self.astro_cloud_conn_id}")
token = conn.password
if token is None:
raise AirflowException(f"Astro API token is missing in connection {self.astro_cloud_conn_id}")
return base_url, token

@property
def _headers(self) -> dict[str, str]:
"""Generates and returns headers for Astro Cloud API requests."""
_, token = self.get_conn()
headers = {"accept": "application/json", "Authorization": f"Bearer {token}"}
return headers

def get_dag_runs(self, external_dag_id: str) -> list[dict[str, str]]:
"""
Retrieves information about running or queued DAG runs.

:param external_dag_id: External ID of the DAG.
"""
base_url, _ = self.get_conn()
path = f"/api/v1/dags/{external_dag_id}/dagRuns"
params: dict[str, int | str | list[str]] = {
"limit": 1,
"state": ["running", "queued"],
"order_by": "-execution_date",
}
url = f"{base_url}{path}"
response = requests.get(url, headers=self._headers, params=params)
response.raise_for_status()
data: dict[str, list[dict[str, str]]] = response.json()
return data["dag_runs"]

def get_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
"""
Retrieves information about a specific DAG run.

:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}"
url = f"{base_url}{path}"
response = requests.get(url, headers=self._headers)
response.raise_for_status()
dr: dict[str, Any] = response.json()
return dr

async def get_a_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
"""
Retrieves information about a specific DAG run.

:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}"
url = f"{base_url}{path}"

async with ClientSession(headers=self._headers) as session:
async with session.get(url) as response:
response.raise_for_status()
dr: dict[str, Any] = await response.json()
return dr

def get_task_instance(
self, external_dag_id: str, dag_run_id: str, external_task_id: str
) -> dict[str, Any] | None:
"""
Retrieves information about a specific task instance within a DAG run.

:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
:param external_task_id: External ID of the task.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"
url = f"{base_url}{path}"
response = requests.get(url, headers=self._headers)
response.raise_for_status()
ti: dict[str, Any] = response.json()
return ti

async def get_a_task_instance(
self, external_dag_id: str, dag_run_id: str, external_task_id: str
) -> dict[str, Any] | None:
"""
Retrieves information about a specific task instance within a DAG run.

:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
:param external_task_id: External ID of the task.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"
url = f"{base_url}{path}"

async with ClientSession(headers=self._headers) as session:
async with session.get(url) as response:
response.raise_for_status()
ti: dict[str, Any] = await response.json()
return ti
100 changes: 100 additions & 0 deletions astronomer/providers/core/sensors/astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

import datetime
from typing import Any, cast

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue

from astronomer.providers.core.hooks.astro import AstroHook
from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger
from astronomer.providers.utils.typing_compat import Context


class ExternalDeploymentSensor(BaseSensorOperator):
"""
Custom Apache Airflow sensor for monitoring external deployments using Astro Cloud.

:param external_dag_id: External ID of the DAG being monitored.
:param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
Defaults to "astro_cloud_default".
:param external_task_id: External ID of the task being monitored. If None, monitors the entire DAG.
:param kwargs: Additional keyword arguments passed to the BaseSensorOperator constructor.
"""

def __init__(
self,
external_dag_id: str,
astro_cloud_conn_id: str = "astro_cloud_default",
external_task_id: str | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
self.astro_cloud_conn_id = astro_cloud_conn_id
self.external_task_id = external_task_id
self.external_dag_id = external_dag_id
self._dag_run_id: str = ""

def poke(self, context: Context) -> bool | PokeReturnValue:
"""
Check the status of a DAG/task in another deployment.

Queries Airflow's REST API for the status of the specified DAG or task instance.
Returns True if successful, False otherwise.

:param context: The task execution context.
"""
hook = AstroHook(self.astro_cloud_conn_id)
dag_runs: list[dict[str, Any]] = hook.get_dag_runs(self.external_dag_id)
if not dag_runs:
self.log.info("No DAG runs found for DAG %s", self.external_dag_id)
return True
self._dag_run_id = cast(str, dag_runs[0]["dag_run_id"])
if self.external_task_id is not None:
task_instance = hook.get_task_instance(
self.external_dag_id, self._dag_run_id, self.external_task_id
)
task_state = task_instance.get("state") if task_instance else None
if task_state == "success":
return True
else:
state = dag_runs[0].get("state")
if state == "success":
return True
return False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly this fetches the latest DAG run and checks the status?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it check the status of last dagrun based on execution date


def execute(self, context: Context) -> Any:
"""
Executes the sensor.

If the external deployment is not successful, it defers the execution using an AstroDeploymentTrigger.

:param context: The task execution context.
"""
if not self.poke(context):
self.defer(
timeout=datetime.timedelta(seconds=self.timeout),
trigger=AstroDeploymentTrigger(
astro_cloud_conn_id=self.astro_cloud_conn_id,
external_task_id=self.external_task_id,
external_dag_id=self.external_dag_id,
poke_interval=self.poke_interval,
dag_run_id=self._dag_run_id,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, str]) -> None:
"""
Handles the completion event from the deferred execution.

Raises AirflowSkipException if the upstream job failed and `soft_fail` is True.
Otherwise, raises AirflowException.

:param context: The task execution context.
:param event: The event dictionary received from the deferred execution.
"""
if event.get("status") == "failed":
if self.soft_fail:
raise AirflowSkipException("Upstream job failed. Skipping the task.")
raise AirflowException("Upstream job failed.")
Loading
Loading