Accept task_key as an argument in DatabricksNotebookOperator #43106

Draft · wants to merge 4 commits into base: main
47 changes: 27 additions & 20 deletions providers/src/airflow/providers/databricks/operators/databricks.py
@@ -19,6 +19,7 @@

from __future__ import annotations

import hashlib
import time
from abc import ABC, abstractmethod
from functools import cached_property
@@ -992,6 +993,7 @@ def __init__(
self,
caller: str = "DatabricksTaskBaseOperator",
databricks_conn_id: str = "databricks_default",
databricks_task_key: str | None = None,
databricks_retry_args: dict[Any, Any] | None = None,
databricks_retry_delay: int = 1,
databricks_retry_limit: int = 3,
@@ -1006,6 +1008,7 @@
):
self.caller = caller
self.databricks_conn_id = databricks_conn_id
self._databricks_task_key = databricks_task_key
self.databricks_retry_args = databricks_retry_args
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_limit = databricks_retry_limit
@@ -1034,6 +1037,24 @@ def __init__(
def _hook(self) -> DatabricksHook:
return self._get_hook(caller=self.caller)

@cached_property
def databricks_task_key(self):
if self._databricks_task_key is None:
self.log.info(
"No databricks_task_key provided. Generating task key by concatenation of dag_id and task_id."
)
self._databricks_task_key = f"{self.dag_id}__{self.task_id.replace('.', '__')}"

if len(self._databricks_task_key) > 100:
self.log.warning(
"The databricks_task_key '%s' exceeds 100 characters and will be truncated by the Databricks API. "
"This will cause failure when trying to monitor the task. Hence, task key will be hashed.",
self._databricks_task_key,
)
return hashlib.md5(self._databricks_task_key.encode()).hexdigest()

return self._databricks_task_key

def _get_hook(self, caller: str) -> DatabricksHook:
return DatabricksHook(
self.databricks_conn_id,
@@ -1043,18 +1064,6 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

def _get_databricks_task_id(self, task_id: str) -> str:
"""Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
task_id = f"{self.dag_id}__{task_id.replace('.', '__')}"
if len(task_id) > 100:
self.log.warning(
"The generated task_key '%s' exceeds 100 characters and will be truncated by the Databricks API. "
"This will cause failure when trying to monitor the task. task_key is generated by ",
"concatenating dag_id and task_id.",
task_id,
)
return task_id

@property
def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
"""
@@ -1083,7 +1092,7 @@ def _get_task_base_json(self) -> dict[str, Any]:
def _get_run_json(self) -> dict[str, Any]:
"""Get run json to be used for task submissions."""
run_json = {
"run_name": self._get_databricks_task_id(self.task_id),
"run_name": self.databricks_task_key,
**self._get_task_base_json(),
}
if self.new_cluster and self.existing_cluster_id:
@@ -1133,19 +1142,17 @@ def _get_current_databricks_task(self) -> dict[str, Any]:
# building the {task_key: task} map below.
sorted_task_runs = sorted(tasks, key=lambda x: x["start_time"])

return {task["task_key"]: task for task in sorted_task_runs}[
self._get_databricks_task_id(self.task_id)
]
return {task["task_key"]: task for task in sorted_task_runs}[self.databricks_task_key]

def _convert_to_databricks_workflow_task(
self, relevant_upstreams: list[BaseOperator], context: Context | None = None
) -> dict[str, object]:
"""Convert the operator to a Databricks workflow task that can be a task in a workflow."""
base_task_json = self._get_task_base_json()
result = {
"task_key": self._get_databricks_task_id(self.task_id),
"task_key": self.databricks_task_key,
"depends_on": [
{"task_key": self._get_databricks_task_id(task_id)}
{"task_key": self.databricks_task_key}
for task_id in self.upstream_task_ids
if task_id in relevant_upstreams
],
@@ -1178,7 +1185,7 @@ def monitor_databricks_job(self) -> None:
run_state = RunState(**run["state"])
self.log.info(
"Current state of the the databricks task %s is %s",
self._get_databricks_task_id(self.task_id),
self.databricks_task_key,
run_state.life_cycle_state,
)
if self.deferrable and not run_state.is_terminal:
@@ -1200,7 +1207,7 @@ def monitor_databricks_job(self) -> None:
run_state = RunState(**run["state"])
self.log.info(
"Current state of the databricks task %s is %s",
self._get_databricks_task_id(self.task_id),
self.databricks_task_key,
run_state.life_cycle_state,
)
self._handle_terminal_run_state(run_state)
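
For reference, a minimal standalone sketch of the key-resolution logic introduced by the databricks_task_key property above; the helper name generate_task_key is illustrative and not part of the provider API:

from __future__ import annotations

import hashlib


def generate_task_key(dag_id: str, task_id: str, custom_key: str | None = None) -> str:
    # Mirror the cached property: prefer an explicit key, otherwise build one
    # from dag_id and task_id, replacing dots (not allowed in Databricks task keys).
    task_key = custom_key or f"{dag_id}__{task_id.replace('.', '__')}"
    if len(task_key) > 100:
        # Databricks truncates task_key at 100 characters, which would break run
        # monitoring, so fall back to a stable 32-character MD5 hex digest.
        return hashlib.md5(task_key.encode()).hexdigest()
    return task_key


print(generate_task_key("my_dag", "workflow.notebook_task"))  # my_dag__workflow__notebook_task
print(generate_task_key("d" * 60, "t" * 60))  # 32-character md5 digest, since 122 > 100
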
providers/src/airflow/providers/databricks/plugins/databricks_workflow.py
@@ -44,6 +44,8 @@
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.providers.databricks.operators.databricks import DatabricksTaskBaseOperator


REPAIR_WAIT_ATTEMPTS = os.getenv("DATABRICKS_REPAIR_WAIT_ATTEMPTS", 20)
REPAIR_WAIT_DELAY = os.getenv("DATABRICKS_REPAIR_WAIT_DELAY", 0.5)
@@ -57,18 +59,8 @@ def get_auth_decorator():
return auth.has_access_dag("POST", DagAccessEntity.RUN)


def _get_databricks_task_id(task: BaseOperator) -> str:
"""
Get the databricks task ID using dag_id and task_id. removes illegal characters.
:param task: The task to get the databricks task ID for.
:return: The databricks task ID.
"""
return f"{task.dag_id}__{task.task_id.replace('.', '__')}"


def get_databricks_task_ids(
group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger
group_id: str, task_map: dict[str, DatabricksTaskBaseOperator], log: logging.Logger
) -> list[str]:
"""
Return a list of all Databricks task IDs for a dictionary of Airflow tasks.
@@ -83,7 +75,7 @@ for task_id, task in task_map.items():
for task_id, task in task_map.items():
if task_id == f"{group_id}.launch":
continue
databricks_task_id = _get_databricks_task_id(task)
databricks_task_id = task.databricks_task_key
log.debug("databricks task id for task %s is %s", task_id, databricks_task_id)
task_ids.append(databricks_task_id)
return task_ids
@@ -112,7 +104,7 @@ def _clear_task_instances(
dag = airflow_app.dag_bag.get_dag(dag_id)
log.debug("task_ids %s to clear", str(task_ids))
dr: DagRun = _get_dagrun(dag, run_id, session=session)
tis_to_clear = [ti for ti in dr.get_task_instances() if _get_databricks_task_id(ti) in task_ids]
tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
clear_task_instances(tis_to_clear, session)


@@ -324,7 +316,7 @@ def get_tasks_to_run(self, ti_key: TaskInstanceKey, operator: BaseOperator, log:

tasks_to_run = {ti: t for ti, t in task_group_sub_tasks if ti in failed_and_skipped_tasks}

return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log))
return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log)) # type: ignore[arg-type]

@staticmethod
def _get_failed_and_skipped_tasks(dr: DagRun) -> list[str]:
@@ -387,7 +379,7 @@ def get_link(
"databricks_conn_id": metadata.conn_id,
"databricks_run_id": metadata.run_id,
"run_id": ti_key.run_id,
"tasks_to_repair": _get_databricks_task_id(task),
"tasks_to_repair": task.databricks_task_key,
}
return url_for("RepairDatabricksTasks.repair", **query_params)
