Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various fixes on ECS run task operator #31838

Merged
merged 9 commits into from
Jun 16, 2023
Merged
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/hooks/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,14 @@ def _get_log_events(self, skip: int = 0) -> Generator:
except ClientError as error:
if error.response["Error"]["Code"] != "ResourceNotFoundException":
self.logger.warning("Error on retrieving Cloudwatch log events", error)

else:
self.logger.info(
"Cannot find log stream yet, it can take a couple of seconds to show up. "
"If this error persists, check that the log group and stream are correct: "
"group: %s\tstream: %s",
self.log_group,
self.log_stream_name,
)
yield from ()
except ConnectionClosedError as error:
self.logger.warning("ConnectionClosedError on retrieving Cloudwatch log events", error)
Expand Down
28 changes: 18 additions & 10 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,17 @@ def __init__(
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts

if self._aws_logs_enabled() and not self.wait_for_completion:
self.log.warning(
"Trying to get logs without waiting for the task to complete is undefined behavior."
)

@staticmethod
def _get_ecs_task_id(task_arn: str | None) -> str | None:
if task_arn is None:
return None
return task_arn.split("/")[-1]

@provide_session
def execute(self, context, session=None):
self.log.info(
Expand All @@ -506,25 +517,24 @@ def execute(self, context, session=None):

@AwsBaseHook.retry(should_retry_eni)
def _start_wait_check_task(self, context):

if not self.arn:
self._start_task(context)

if not self.wait_for_completion:
return

Comment on lines +523 to +525
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whatever logs users were getting for the short period of time without a wait_for_completion they will no longer get. So we're calling it a bug fix with no deprecation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, idk, we may want to keep the existing behavior, but what I don't like about it is that it made the operator slower just for the sake of maybe getting a couple of logs...
Since we were starting the thread, which slept for 30 seconds (or configured value) before checking if it was stopped, this operator would take 30 seconds to return no matter what, when the job was done in a second and a half.
It's like "don't wait for completion but still wait a bit"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd agree with Raph, I don't think this is a desired behavior but more a forgotten edge case. I would call it a bug fix

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack, I'll call that quorum then, let's call it a bug fix 👍

if self._aws_logs_enabled():
self.log.info("Starting ECS Task Log Fetcher")
self.task_log_fetcher = self._get_task_log_fetcher()
self.task_log_fetcher.start()

try:
if self.wait_for_completion:
self._wait_for_task_ended()
self._wait_for_task_ended()
finally:
self.task_log_fetcher.stop()

self.task_log_fetcher.join()
else:
if self.wait_for_completion:
self._wait_for_task_ended()
self._wait_for_task_ended()

self._check_success_task()

Expand Down Expand Up @@ -566,8 +576,7 @@ def _start_task(self, context):
self.log.info("ECS Task started: %s", response)

self.arn = response["tasks"][0]["taskArn"]
self.ecs_task_id = self.arn.split("/")[-1]
self.log.info("ECS task ID is: %s", self.ecs_task_id)
self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))

if self.reattach:
# Save the task ARN in XCom to be able to reattach it if needed
Expand All @@ -590,7 +599,6 @@ def _try_reattach_task(self, context):
)
if previous_task_arn in running_tasks:
self.arn = previous_task_arn
self.ecs_task_id = self.arn.split("/")[-1]
self.log.info("Reattaching previously launched task: %s", self.arn)
else:
self.log.info("No active previously launched task found to reattach")
Expand Down Expand Up @@ -620,7 +628,7 @@ def _aws_logs_enabled(self):
def _get_task_log_fetcher(self) -> EcsTaskLogFetcher:
if not self.awslogs_group:
raise ValueError("must specify awslogs_group to fetch task logs")
log_stream_name = f"{self.awslogs_stream_prefix}/{self.ecs_task_id}"
log_stream_name = f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"

return EcsTaskLogFetcher(
aws_conn_id=self.aws_conn_id,
Expand Down
2 changes: 1 addition & 1 deletion docs/apache-airflow-providers-amazon/operators/ecs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ both can be overridden with provided values. Raises an AirflowException with
the failure reason if a failed state is provided and that state is reached
before the target state.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ecs.py
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ecs_fargate.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_ecs_task_state]
Expand Down
7 changes: 4 additions & 3 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,10 @@ def test_execute_without_failures(
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
assert self.ecs.ecs_task_id == TASK_ID

def test_task_id_parsing(self):
    """The task ID returned is the final path segment of the task ARN."""
    # Renamed from `id` to avoid shadowing the `id` builtin.
    task_id = EcsRunTaskOperator._get_ecs_task_id(f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}")
    assert task_id == TASK_ID

@mock.patch.object(EcsBaseOperator, "client")
def test_execute_with_failures(self, client_mock):
Expand Down Expand Up @@ -571,7 +574,6 @@ def test_reattach_successful(
check_mock.assert_called_once_with()
xcom_del_mock.assert_called_once()
assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
assert self.ecs.ecs_task_id == TASK_ID

@pytest.mark.parametrize(
"launch_type, tags",
Expand Down Expand Up @@ -620,7 +622,6 @@ def test_reattach_save_task_arn_xcom(
check_mock.assert_called_once_with()
xcom_del_mock.assert_called_once()
assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
assert self.ecs.ecs_task_id == TASK_ID

@mock.patch.object(EcsBaseOperator, "client")
@mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
Expand Down
51 changes: 25 additions & 26 deletions tests/system/providers/amazon/aws/example_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsTaskStates
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates
from airflow.providers.amazon.aws.operators.ecs import (
EcsCreateClusterOperator,
EcsDeleteClusterOperator,
Expand All @@ -34,7 +34,6 @@
from airflow.providers.amazon.aws.sensors.ecs import (
EcsClusterStateSensor,
EcsTaskDefinitionStateSensor,
EcsTaskStateSensor,
)
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
Expand Down Expand Up @@ -67,6 +66,15 @@ def get_region():
return boto3.session.Session().region_name


@task(trigger_rule=TriggerRule.ALL_DONE)
def clean_logs(group_name: str):
    """Teardown step: delete the CloudWatch log group created by the ECS task.

    trigger_rule=ALL_DONE makes this run even when upstream tasks failed.
    NOTE(review): delete_log_group presumably raises if the group does not
    exist, which doubles as a check that the log configuration stayed correct.
    """
    client = boto3.client("logs")
    # A bit brutal to delete the whole group, I know,
    # but we don't have the access to the arn of the task which is used in the stream name
    # and also those logs just contain "hello world", which is not very interesting.
    client.delete_log_group(logGroupName=group_name)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is going to fail if the group does not exist, so in a way it makes sure the log configuration stays correct.



with DAG(
dag_id=DAG_ID,
schedule="@once",
Expand All @@ -85,6 +93,7 @@ def get_region():
asg_name = f"{env_id}-asg"

aws_region = get_region()
log_group_name = "/ecs/hello-world"
vandonr-amz marked this conversation as resolved.
Show resolved Hide resolved

# [START howto_operator_ecs_create_cluster]
create_cluster = EcsCreateClusterOperator(
Expand Down Expand Up @@ -114,7 +123,16 @@ def get_region():
"workingDirectory": "/usr/bin",
"entryPoint": ["sh", "-c"],
"command": ["ls"],
}
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": log_group_name,
"awslogs-region": aws_region,
"awslogs-create-group": "true",
"awslogs-stream-prefix": "ecs",
},
},
},
],
register_task_kwargs={
"cpu": "256",
Expand All @@ -140,38 +158,19 @@ def get_region():
"containerOverrides": [
{
"name": container_name,
"command": ["echo", "hello", "world"],
"command": ["echo hello world"],
},
],
},
network_configuration={"awsvpcConfiguration": {"subnets": existing_cluster_subnets}},
# [START howto_awslogs_ecs]
awslogs_group="/ecs/hello-world",
awslogs_group=log_group_name,
awslogs_region=aws_region,
awslogs_stream_prefix="ecs/hello-world-container",
awslogs_stream_prefix=f"ecs/{container_name}",
# [END howto_awslogs_ecs]
# You must set `reattach=True` in order to get ecs_task_arn if you plan to use a Sensor.
reattach=True,
)
# [END howto_operator_ecs_run_task]

# EcsRunTaskOperator waits by default, setting as False to test the Sensor below.
run_task.wait_for_completion = False

# [START howto_sensor_ecs_task_state]
# By default, EcsTaskStateSensor waits until the task has started, but the
# demo task runs so fast that the sensor misses it. This sensor instead
# demonstrates how to wait until the ECS Task has completed by providing
# the target_state and failure_states parameters.
await_task_finish = EcsTaskStateSensor(
task_id="await_task_finish",
cluster=existing_cluster_name,
task=run_task.output["ecs_task_arn"],
target_state=EcsTaskStates.STOPPED,
failure_states={EcsTaskStates.NONE},
)
# [END howto_sensor_ecs_task_state]

# [START howto_operator_ecs_deregister_task_definition]
deregister_task = EcsDeregisterTaskDefinitionOperator(
task_id="deregister_task",
Expand Down Expand Up @@ -209,10 +208,10 @@ def get_region():
register_task,
await_task_definition,
run_task,
await_task_finish,
deregister_task,
delete_cluster,
await_delete_cluster,
clean_logs(log_group_name),
)

from tests.system.utils.watcher import watcher
Expand Down
22 changes: 22 additions & 0 deletions tests/system/providers/amazon/aws/example_ecs_fargate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.hooks.ecs import EcsTaskStates
from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator
from airflow.providers.amazon.aws.sensors.ecs import EcsTaskStateSensor
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder

Expand Down Expand Up @@ -120,16 +122,36 @@ def delete_cluster(cluster_name: str) -> None:
"assignPublicIp": "ENABLED",
},
},
# You must set `reattach=True` in order to get ecs_task_arn if you plan to use a Sensor.
reattach=True,
)
# [END howto_operator_ecs]

# EcsRunTaskOperator waits by default, setting as False to test the Sensor below.
hello_world.wait_for_completion = False

# [START howto_sensor_ecs_task_state]
# By default, EcsTaskStateSensor waits until the task has started, but the
# demo task runs so fast that the sensor misses it. This sensor instead
# demonstrates how to wait until the ECS Task has completed by providing
# the target_state and failure_states parameters.
await_task_finish = EcsTaskStateSensor(
task_id="await_task_finish",
cluster=cluster_name,
task=hello_world.output["ecs_task_arn"],
target_state=EcsTaskStates.STOPPED,
failure_states={EcsTaskStates.NONE},
)
# [END howto_sensor_ecs_task_state]

chain(
# TEST SETUP
test_context,
create_cluster(cluster_name),
create_task_definition,
# TEST BODY
hello_world,
await_task_finish,
# TEST TEARDOWN
delete_task_definition(create_task_definition),
delete_cluster(cluster_name),
Expand Down