Deferrable mode for ECS operators #31881

Status: Merged (26 commits, merged Jun 23, 2023)

Commits
f4ea7d9
add deferrable mode for ECS Create Cluster
vandonr-amz May 25, 2023
4dddc9a
execute task - easy part, without the logs
vandonr-amz Jun 5, 2023
a6e5c0e
add logs support to run task operator
vandonr-amz Jun 6, 2023
bca71da
add tests
vandonr-amz Jun 9, 2023
59ff8f4
add deferrable for delete cluster, by adapting the create cluster tri…
vandonr-amz Jun 13, 2023
5dc163b
add deferrable parameter
vandonr-amz Jun 13, 2023
71b9648
rearranging code around a bit
vandonr-amz Jun 13, 2023
ab0f65d
tests
vandonr-amz Jun 13, 2023
a421457
add dots in comments
vandonr-amz Jun 13, 2023
b2d1dae
fix test
vandonr-amz Jun 13, 2023
913daeb
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 13, 2023
305ea5c
add trigger to yaml
vandonr-amz Jun 13, 2023
eab95e0
return last line of logs
vandonr-amz Jun 13, 2023
02ec64e
is this the right integration name ?
vandonr-amz Jun 13, 2023
960d1c5
add timeouts
vandonr-amz Jun 14, 2023
7141b0e
fix CI + some fix
vandonr-amz Jun 14, 2023
e215315
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 15, 2023
2eb2252
adjust expected value in test
vandonr-amz Jun 15, 2023
8861815
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 16, 2023
6ecd019
fix
vandonr-amz Jun 16, 2023
3038246
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 16, 2023
76a9f39
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 20, 2023
439bbd9
rename method in test
vandonr-amz Jun 20, 2023
5beb6b2
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 22, 2023
c5d50ad
use newly available wait method
vandonr-amz Jun 22, 2023
4e65d26
fix test
vandonr-amz Jun 22, 2023
airflow/providers/amazon/aws/operators/ecs.py (160 changes: 133 additions, 27 deletions)
@@ -35,6 +35,11 @@
EcsHook,
should_retry_eni,
)
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.ecs import (
ClusterWaiterTrigger,
TaskDoneTrigger,
)
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
from airflow.utils.session import provide_session
@@ -67,6 +72,15 @@ def execute(self, context: Context):
"""Must overwrite in child classes."""
raise NotImplementedError("Please implement execute() in subclass")

def _complete_exec_with_cluster_desc(self, context, event=None):
Review comment (vandonr-amz, PR author): this callback is shared between create and delete cluster operators, so I put it there. It felt like a better solution than copy-pasting it for both.

"""To be used as trigger callback for operators that return the cluster description."""
if event["status"] != "success":
raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
cluster_arn = event.get("arn")
# We cannot get the cluster definition from the waiter on success, so we have to query it here.
details = self.hook.conn.describe_clusters(clusters=[cluster_arn])["clusters"][0]
return details
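
As a rough, standalone sketch of this callback's contract: the event shape mirrors what ClusterWaiterTrigger yields in this PR, while the boto3 client call and the ARN below are illustrative assumptions, not the operator's actual wiring.

import boto3

def complete_with_cluster_description(event: dict | None) -> dict:
    # Sketch of the success path: fail fast on a non-success event, otherwise
    # re-query DescribeClusters, because the waiter's final response does not
    # carry the full cluster description.
    if not event or event["status"] != "success":
        raise RuntimeError(f"Error while waiting for operation on cluster to complete: {event}")
    ecs = boto3.client("ecs")
    return ecs.describe_clusters(clusters=[event["arn"]])["clusters"][0]

# Hypothetical trigger event:
# complete_with_cluster_description(
#     {"status": "success", "arn": "arn:aws:ecs:eu-west-1:123456789012:cluster/demo"}
# )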


class EcsCreateClusterOperator(EcsBaseOperator):
"""
@@ -84,18 +98,27 @@ class EcsCreateClusterOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made,
if not set then the default waiter value will be used.
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs", "wait_for_completion")
template_fields: Sequence[str] = (
"cluster_name",
"create_cluster_kwargs",
"wait_for_completion",
"deferrable",
)

def __init__(
self,
*,
cluster_name: str,
create_cluster_kwargs: dict | None = None,
wait_for_completion: bool = True,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
waiter_delay: int = 15,
waiter_max_attempts: int = 60,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
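
For orientation, a minimal DAG snippet showing how a user might opt into the new mode; the DAG id and cluster name are hypothetical, and aiobotocore must be installed for the trigger to run on the triggerer.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.ecs import EcsCreateClusterOperator

with DAG("ecs_deferrable_demo", start_date=datetime(2023, 6, 1), schedule=None):
    create_cluster = EcsCreateClusterOperator(
        task_id="create_cluster",
        cluster_name="demo-cluster",   # hypothetical cluster name
        deferrable=True,               # hand the wait off to the triggerer instead of blocking a worker
        waiter_delay=15,               # seconds between status checks (the new default)
        waiter_max_attempts=60,        # up to ~15 minutes of polling (the new default)
    )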
@@ -104,6 +127,7 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context):
self.log.info(
@@ -119,6 +143,21 @@ def execute(self, context: Context):
# In some circumstances the ECS Cluster is created immediately,
# and there is no reason to wait for completion.
self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
elif self.deferrable:
self.defer(
trigger=ClusterWaiterTrigger(
waiter_name="cluster_active",
cluster_arn=cluster_details["clusterArn"],
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region,
),
method_name="_complete_exec_with_cluster_desc",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
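With the defaults above (waiter_delay=15, waiter_max_attempts=60), this deferral timeout works out to 15 * 60 + 60 = 960 seconds, i.e. the full polling budget plus a one-minute grace period for the trigger to yield its event.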
elif self.wait_for_completion:
waiter = self.hook.get_waiter("cluster_active")
waiter.wait(
@@ -148,24 +187,29 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made,
if not set then the default waiter value will be used.
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

template_fields: Sequence[str] = ("cluster_name", "wait_for_completion")
template_fields: Sequence[str] = ("cluster_name", "wait_for_completion", "deferrable")

def __init__(
self,
*,
cluster_name: str,
wait_for_completion: bool = True,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
waiter_delay: int = 15,
waiter_max_attempts: int = 60,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.cluster_name = cluster_name
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context):
self.log.info("Deleting cluster %r.", self.cluster_name)
@@ -174,9 +218,24 @@ def execute(self, context: Context):
cluster_state = cluster_details.get("status")

if cluster_state == EcsClusterStates.INACTIVE:
# In some circumstances the ECS Cluster is deleted immediately,
# so there is no reason to wait for completion.
# if the cluster doesn't have capacity providers that are associated with it,
# the deletion is instantaneous, and we don't need to wait for it.
Review comment (syedahsn, contributor, Jun 14, 2023): cluster_details has the capacityProviders associated with the nodegroup. Would that be a better way to decide whether we want to wait for completion or not?
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/delete_cluster.html

Reply (vandonr-amz, PR author): hmm, we could do that, but the check on the status above is already taking care of that. We can write a different check, but the result would be the same.

self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
elif self.deferrable:
self.defer(
trigger=ClusterWaiterTrigger(
waiter_name="cluster_inactive",
cluster_arn=cluster_details["clusterArn"],
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region,
),
method_name="_complete_exec_with_cluster_desc",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
elif self.wait_for_completion:
waiter = self.hook.get_waiter("cluster_inactive")
waiter.wait(
@@ -347,6 +406,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
finished.
:param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait
in between each Cloudwatch logs fetches.
If deferrable is set to True, that parameter is ignored and waiter_delay is used instead.
:param quota_retry: Config if and how to retry the launch of a new ECS task, to handle
transient errors.
:param reattach: If set to True, will check if the task previously launched by the task_instance
@@ -361,6 +421,9 @@
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made,
if not set then the default waiter value will be used.
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

ui_color = "#f0ede4"
@@ -384,6 +447,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
"reattach",
"number_logs_exception",
"wait_for_completion",
"deferrable",
)
template_fields_renderers = {
"overrides": "json",
@@ -416,8 +480,9 @@ def __init__(
reattach: bool = False,
number_logs_exception: int = 10,
wait_for_completion: bool = True,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
waiter_delay: int = 6,
waiter_max_attempts: int = 100,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
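
As an illustration of how the run-task operator might be used with the new flag and CloudWatch logs (all names below are hypothetical, the example is shown without the surrounding DAG definition, and task-definition and networking details are trimmed). Note that in deferrable mode awslogs_fetch_interval is ignored in favor of waiter_delay.

from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator

run_task = EcsRunTaskOperator(
    task_id="run_task",
    cluster="demo-cluster",                      # hypothetical cluster
    task_definition="demo-task:1",               # hypothetical task definition
    overrides={},                                # no container overrides in this sketch
    awslogs_group="/ecs/demo-task",              # hypothetical log group
    awslogs_region="eu-west-1",
    awslogs_stream_prefix="ecs/demo-container",  # <prefix>/<container name>
    deferrable=True,
    waiter_delay=6,                              # the new default for this operator
    waiter_max_attempts=100,
)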
@@ -451,6 +516,7 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

if self._aws_logs_enabled() and not self.wait_for_completion:
self.log.warning(
@@ -473,7 +539,35 @@ def execute(self, context, session=None):
if self.reattach:
self._try_reattach_task(context)

self._start_wait_check_task(context)
self._start_wait_task(context)

self._after_execution(session)

if self.do_xcom_push and self.task_log_fetcher:
return self.task_log_fetcher.get_last_log_message()
else:
return None

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error in task execution: {event}")
self.arn = event["task_arn"] # restore arn to its updated value, needed for next steps
self._after_execution()
if self._aws_logs_enabled():
# same behavior as non-deferrable mode, return last line of logs of the task.
logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn
one_log = logs_client.get_log_events(
logGroupName=self.awslogs_group,
logStreamName=self._get_logs_stream_name(),
startFromHead=False,
limit=1,
)
if len(one_log["events"]) > 0:
return one_log["events"][0]["message"]
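
As a standalone sketch of that last-log-line lookup (the log group and stream names are placeholders; the real operator derives the stream name from awslogs_stream_prefix and the ECS task id):

import boto3

logs = boto3.client("logs")
resp = logs.get_log_events(
    logGroupName="/ecs/demo-task",                        # placeholder log group
    logStreamName="ecs/demo-container/0123456789abcdef",  # placeholder stream name
    startFromHead=False,                                   # read from the tail of the stream
    limit=1,                                               # only the most recent event
)
last_message = resp["events"][0]["message"] if resp["events"] else None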

@provide_session
def _after_execution(self, session=None):
Review comment (vandonr-amz, PR author): I wanted to extract this to reuse it in execute and execute_complete, but I couldn't find a great name for it.

self._check_success_task()

self.log.info("ECS Task has been successfully executed")

@@ -482,16 +576,29 @@ def execute(self, context, session=None):
# as we can't reattach it anymore
self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))

if self.do_xcom_push and self.task_log_fetcher:
return self.task_log_fetcher.get_last_log_message()

return None

@AwsBaseHook.retry(should_retry_eni)
def _start_wait_check_task(self, context):
def _start_wait_task(self, context):
Review comment (vandonr-amz, PR author, on lines -491 to +580): the check went to _after_execution.

if not self.arn:
self._start_task(context)

if self.deferrable:
self.defer(
trigger=TaskDoneTrigger(
cluster=self.cluster,
task_arn=self.arn,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region,
log_group=self.awslogs_group,
log_stream=self._get_logs_stream_name(),
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
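With this operator's defaults (waiter_delay=6, waiter_max_attempts=100), the deferral timeout comes out to 6 * 100 + 60 = 660 seconds.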

if not self.wait_for_completion:
return

@@ -508,8 +615,6 @@ def _start_wait_check_task(self, context):
else:
self._wait_for_task_ended()

self._check_success_task()

def _xcom_del(self, session, task_id):
session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete()

@@ -584,33 +689,34 @@ def _wait_for_task_ended(self) -> None:
waiter.wait(
cluster=self.cluster,
tasks=[self.arn],
WaiterConfig=prune_dict(
{
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
}
),
WaiterConfig={
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
},
)

return
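
For reference, the equivalent wait against boto3 directly, outside Airflow, looks roughly like this (the cluster name and task ARN are placeholders; the WaiterConfig values match the new operator defaults):

import boto3

ecs = boto3.client("ecs")
waiter = ecs.get_waiter("tasks_stopped")
waiter.wait(
    cluster="demo-cluster",                                         # placeholder
    tasks=["arn:aws:ecs:eu-west-1:123456789012:task/demo/abc123"],  # placeholder task ARN
    WaiterConfig={"Delay": 6, "MaxAttempts": 100},
)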

def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix

def _get_logs_stream_name(self) -> str:
return f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"

def _get_task_log_fetcher(self) -> AwsTaskLogFetcher:
if not self.awslogs_group:
raise ValueError("must specify awslogs_group to fetch task logs")
log_stream_name = f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"

return AwsTaskLogFetcher(
aws_conn_id=self.aws_conn_id,
region_name=self.awslogs_region,
log_group=self.awslogs_group,
log_stream_name=log_stream_name,
log_stream_name=self._get_logs_stream_name(),
fetch_interval=self.awslogs_fetch_interval,
logger=self.log,
)

@AwsBaseHook.retry(should_retry_eni)
def _check_success_task(self) -> None:
if not self.client or not self.arn:
return