From e8d387a8e44418bc0b1cf60575597dd10c4d7dc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 13 Jul 2023 10:20:11 -0700 Subject: [PATCH] BUGFIX: AWS ECS trigger returned "success" after max_attempts instead of failing reported by user in issue #32580 the issue is about something else, but the user mentionned this as a "bonus bug" --- airflow/providers/amazon/aws/triggers/ecs.py | 7 ++++--- .../providers/amazon/aws/triggers/test_ecs.py | 21 +++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index 29ad22e13e7..af6f72e7714 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -22,6 +22,7 @@ from botocore.exceptions import ClientError, WaiterError +from airflow import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook @@ -170,7 +171,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: await waiter.wait( cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1} ) - break # we reach this point only if the waiter met a success criteria + # we reach this point only if the waiter met a success criteria + yield TriggerEvent({"status": "success", "task_arn": self.task_arn}) except WaiterError as error: if "terminal failure" in str(error): raise @@ -179,8 +181,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: finally: if self.log_group and self.log_stream: logs_token = await self._forward_logs(logs_client, logs_token) - - yield TriggerEvent({"status": "success", "task_arn": self.task_arn}) + raise AirflowException("Waiter error: max attempts reached") async def _forward_logs(self, logs_client, next_token: str | None = None) -> str | None: """ diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py index 551ab39a448..e897bce7400 100644 --- a/tests/providers/amazon/aws/triggers/test_ecs.py +++ b/tests/providers/amazon/aws/triggers/test_ecs.py @@ -22,6 +22,7 @@ import pytest from botocore.exceptions import WaiterError +from airflow import AirflowException from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.triggers.ecs import ( @@ -56,6 +57,26 @@ async def test_run_until_error(self, _, client_mock): assert wait_mock.call_count == 3 + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + # this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step + @mock.patch.object(AwsLogsHook, "async_conn") + async def test_run_until_timeout(self, _, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError("name", "reason", {"tasks": [{"lastStatus": "my_status"}]}) + a_mock.get_waiter().wait = wait_mock + + trigger = TaskDoneTrigger("cluster", "task_arn", 0, 10, None, None) + + with pytest.raises(AirflowException) as err: + generator = trigger.run() + await generator.asend(None) + + assert wait_mock.call_count == 10 + assert "max attempts" in str(err.value) + @pytest.mark.asyncio @mock.patch.object(EcsHook, "async_conn") # this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step