From 57c7ee07eaa7d3c5c2240448bd038d0a990485a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Gurhem?= Date: Tue, 7 Feb 2023 16:15:16 +0100 Subject: [PATCH] feat: Add endpoint on Polling Agent to kill the current task if task is in cancelling status --- Common/src/Pollster/Pollster.cs | 8 +- Common/src/Pollster/TaskHandler.cs | 76 +++++++---- Common/tests/Helpers/ExceptionAsyncPipe.cs | 60 +++++++++ .../Helpers/ExceptionWorkerStreamHandler.cs | 61 +++++++++ Common/tests/Pollster/PollsterTest.cs | 73 ++++++++++ Common/tests/Pollster/TaskHandlerTest.cs | 125 +++++++++++------- Compute/PollingAgent/src/Program.cs | 12 ++ 7 files changed, 341 insertions(+), 74 deletions(-) create mode 100644 Common/tests/Helpers/ExceptionAsyncPipe.cs create mode 100644 Common/tests/Helpers/ExceptionWorkerStreamHandler.cs diff --git a/Common/src/Pollster/Pollster.cs b/Common/src/Pollster/Pollster.cs index e84b77ccf..ce486326e 100644 --- a/Common/src/Pollster/Pollster.cs +++ b/Common/src/Pollster/Pollster.cs @@ -67,6 +67,7 @@ public class Pollster : IInitializable private readonly IWorkerStreamHandler workerStreamHandler_; private bool endLoopReached_; private HealthCheckResult? healthCheckFailedResult_; + public Func? StopCancelledTask; public string TaskProcessing; public Pollster(IPullQueueStorage pullQueueStorage, @@ -295,6 +296,8 @@ void RecordError(Exception e) pollsterOptions_, cts); + StopCancelledTask = taskHandler.StopCancelledTask; + var precondition = await taskHandler.AcquireTask() .ConfigureAwait(false); @@ -311,6 +314,8 @@ await taskHandler.ExecuteTask() await taskHandler.PostProcessing() .ConfigureAwait(false); + StopCancelledTask = null; + logger_.LogDebug("Task returned"); // If the task was successful, we can remove a failure @@ -331,7 +336,8 @@ await taskHandler.PostProcessing() } finally { - TaskProcessing = string.Empty; + StopCancelledTask = null; + TaskProcessing = string.Empty; } } } diff --git a/Common/src/Pollster/TaskHandler.cs b/Common/src/Pollster/TaskHandler.cs index a46985edb..c5d6e48c7 100644 --- a/Common/src/Pollster/TaskHandler.cs +++ b/Common/src/Pollster/TaskHandler.cs @@ -87,33 +87,35 @@ public TaskHandler(ISessionTable sessionTable, Injection.Options.Pollster pollsterOptions, CancellationTokenSource cancellationTokenSource) { - sessionTable_ = sessionTable; - taskTable_ = taskTable; - resultTable_ = resultTable; - messageHandler_ = messageHandler; - taskProcessingChecker_ = taskProcessingChecker; - submitter_ = submitter; - dataPrefetcher_ = dataPrefetcher; - workerStreamHandler_ = workerStreamHandler; - activitySource_ = activitySource; - agentHandler_ = agentHandler; - logger_ = logger; - cancellationTokenSource_ = cancellationTokenSource; - ownerPodId_ = ownerPodId; - ownerPodName_ = ownerPodName; - taskData_ = null; - sessionData_ = null; + sessionTable_ = sessionTable; + taskTable_ = taskTable; + resultTable_ = resultTable; + messageHandler_ = messageHandler; + taskProcessingChecker_ = taskProcessingChecker; + submitter_ = submitter; + dataPrefetcher_ = dataPrefetcher; + workerStreamHandler_ = workerStreamHandler; + activitySource_ = activitySource; + agentHandler_ = agentHandler; + logger_ = logger; + ownerPodId_ = ownerPodId; + ownerPodName_ = ownerPodName; + taskData_ = null; + sessionData_ = null; token_ = Guid.NewGuid() .ToString(); - workerConnectionCts_ = new CancellationTokenSource(); + workerConnectionCts_ = new CancellationTokenSource(); + cancellationTokenSource_ = new CancellationTokenSource(); - reg1_ = cancellationTokenSource_.Token.Register(() => - { - logger_.LogWarning("Cancellation triggered, waiting {waitingTime} before cancelling task", - pollsterOptions.GraceDelay); - workerConnectionCts_.CancelAfter(pollsterOptions.GraceDelay); - }); + reg1_ = cancellationTokenSource.Token.Register(() => cancellationTokenSource_.Cancel()); + + cancellationTokenSource_.Token.Register(() => + { + logger_.LogWarning("Cancellation triggered, waiting {waitingTime} before cancelling task", + pollsterOptions.GraceDelay); + workerConnectionCts_.CancelAfter(pollsterOptions.GraceDelay); + }); workerConnectionCts_.Token.Register(() => logger_.LogWarning("Cancellation triggered, start to properly cancel task")); } @@ -131,10 +133,32 @@ await messageHandler_.DisposeAsync() reg1_.Unregister(); await reg1_.DisposeAsync() .ConfigureAwait(false); + cancellationTokenSource_.Dispose(); workerConnectionCts_.Dispose(); agent_?.Dispose(); } + /// + /// Refresh task metadata and stop execution if current task should be cancelled + /// + /// + /// Task representing the asynchronous execution of the method + /// + public async Task StopCancelledTask() + { + if (taskData_?.Status is not null or TaskStatus.Cancelled or TaskStatus.Cancelling) + { + taskData_ = await taskTable_.ReadTaskAsync(messageHandler_.TaskId, + CancellationToken.None) + .ConfigureAwait(false); + if (taskData_.Status is TaskStatus.Cancelling) + { + logger_.LogWarning("Task has been cancelled, trigger cancellation from exterior."); + cancellationTokenSource_.Cancel(); + } + } + } + /// /// Acquisition of the task in the message given to the constructor /// @@ -618,6 +642,12 @@ private async Task HandleErrorInternalAsync(Exception e, bool requeueIfUnavailable, CancellationToken cancellationToken) { + if (taskData.Status is TaskStatus.Cancelled or TaskStatus.Cancelling) + { + messageHandler_.Status = QueueMessageStatus.Processed; + return; + } + if (cancellationToken.IsCancellationRequested || (requeueIfUnavailable && e is RpcException { StatusCode: StatusCode.Unavailable, diff --git a/Common/tests/Helpers/ExceptionAsyncPipe.cs b/Common/tests/Helpers/ExceptionAsyncPipe.cs new file mode 100644 index 000000000..8dec020ed --- /dev/null +++ b/Common/tests/Helpers/ExceptionAsyncPipe.cs @@ -0,0 +1,60 @@ +// This file is part of the ArmoniK project +// +// Copyright (C) ANEO, 2021-2023. All rights reserved. +// W. Kirschenmann +// J. Gurhem +// D. Dubuc +// L. Ziane Khodja +// F. Lemaitre +// S. Djebbar +// J. Fonseca +// D. Brasseur +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY, without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +using ArmoniK.Api.gRPC.V1.Worker; +using ArmoniK.Core.Common.Utils; + +namespace ArmoniK.Core.Common.Tests.Helpers; + +public class ExceptionAsyncPipe : IAsyncPipe + where T : Exception, new() +{ + private readonly int delay_; + + public ExceptionAsyncPipe(int delay) + => delay_ = delay; + + public async Task ReadAsync(CancellationToken cancellationToken) + { + await Task.Delay(TimeSpan.FromMilliseconds(delay_), + cancellationToken) + .ConfigureAwait(false); + cancellationToken.ThrowIfCancellationRequested(); + throw new T(); + } + + public Task WriteAsync(ProcessRequest message) + => Task.CompletedTask; + + public Task WriteAsync(IEnumerable message) + => Task.CompletedTask; + + public Task CompleteAsync() + => Task.CompletedTask; +} diff --git a/Common/tests/Helpers/ExceptionWorkerStreamHandler.cs b/Common/tests/Helpers/ExceptionWorkerStreamHandler.cs new file mode 100644 index 000000000..b30da51a0 --- /dev/null +++ b/Common/tests/Helpers/ExceptionWorkerStreamHandler.cs @@ -0,0 +1,61 @@ +// This file is part of the ArmoniK project +// +// Copyright (C) ANEO, 2021-2023. All rights reserved. +// W. Kirschenmann +// J. Gurhem +// D. Dubuc +// L. Ziane Khodja +// F. Lemaitre +// S. Djebbar +// J. Fonseca +// D. Brasseur +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY, without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +using System; +using System.Threading; +using System.Threading.Tasks; + +using ArmoniK.Api.gRPC.V1.Worker; +using ArmoniK.Core.Common.Storage; +using ArmoniK.Core.Common.Stream.Worker; +using ArmoniK.Core.Common.Utils; + +using Microsoft.Extensions.Diagnostics.HealthChecks; + +namespace ArmoniK.Core.Common.Tests.Helpers; + +public class ExceptionWorkerStreamHandler : IWorkerStreamHandler + where T : Exception, new() +{ + private readonly int delay_; + + public ExceptionWorkerStreamHandler(int delay) + => delay_ = delay; + + public Task Check(HealthCheckTag tag) + => Task.FromResult(HealthCheckResult.Healthy()); + + public Task Init(CancellationToken cancellationToken) + => Task.CompletedTask; + + public void Dispose() + { + } + + public IAsyncPipe? Pipe { get; private set; } + + public void StartTaskProcessing(TaskData taskData, + CancellationToken cancellationToken) + => Pipe = new ExceptionAsyncPipe(delay_); +} diff --git a/Common/tests/Pollster/PollsterTest.cs b/Common/tests/Pollster/PollsterTest.cs index 3582fe66a..0a0b1e641 100644 --- a/Common/tests/Pollster/PollsterTest.cs +++ b/Common/tests/Pollster/PollsterTest.cs @@ -501,6 +501,79 @@ await testServiceProvider.Pollster.Init(CancellationToken.None) testServiceProvider.Pollster.TaskProcessing); } + [Test] + public async Task CancelLongTaskShouldSucceed() + { + var mockPullQueueStorage = new Mock(); + var waitWorkerStreamHandler = new ExceptionWorkerStreamHandler(15000); + var simpleAgentHandler = new SimpleAgentHandler(); + + using var testServiceProvider = new TestPollsterProvider(waitWorkerStreamHandler, + simpleAgentHandler, + mockPullQueueStorage.Object); + + var tuple = await InitSubmitter(testServiceProvider.Submitter, + testServiceProvider.PartitionTable, + CancellationToken.None) + .ConfigureAwait(false); + + mockPullQueueStorage.Setup(storage => storage.PullMessagesAsync(It.IsAny(), + It.IsAny())) + .Returns(() => new List + { + new SimpleQueueMessageHandler + { + CancellationToken = CancellationToken.None, + Status = QueueMessageStatus.Waiting, + MessageId = Guid.NewGuid() + .ToString(), + TaskId = tuple.taskSubmitted, + }, + }.ToAsyncEnumerable()); + + await testServiceProvider.Pollster.Init(CancellationToken.None) + .ConfigureAwait(false); + + var source = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + + var mainLoopTask = testServiceProvider.Pollster.MainLoop(source.Token); + + await Task.Delay(TimeSpan.FromMilliseconds(200), + CancellationToken.None) + .ConfigureAwait(false); + + await testServiceProvider.TaskTable.CancelTaskAsync(new List + { + tuple.taskSubmitted, + }, + CancellationToken.None) + .ConfigureAwait(false); + + await Task.Delay(TimeSpan.FromMilliseconds(200), + CancellationToken.None) + .ConfigureAwait(false); + + await testServiceProvider.Pollster.StopCancelledTask!.Invoke() + .ConfigureAwait(false); + + Assert.DoesNotThrowAsync(() => mainLoopTask); + Assert.False(testServiceProvider.Pollster.Failed); + Assert.True(source.Token.IsCancellationRequested); + + Assert.AreEqual(TaskStatus.Cancelled, + (await testServiceProvider.TaskTable.GetTaskStatus(new[] + { + tuple.taskSubmitted, + }, + CancellationToken.None) + .ConfigureAwait(false)).Single() + .Status); + Assert.AreEqual(string.Empty, + testServiceProvider.Pollster.TaskProcessing); + Assert.AreSame(string.Empty, + testServiceProvider.Pollster.TaskProcessing); + } + public static IEnumerable ExecuteTooManyErrorShouldFailTestCase { get diff --git a/Common/tests/Pollster/TaskHandlerTest.cs b/Common/tests/Pollster/TaskHandlerTest.cs index 41b8bec78..3aade492e 100644 --- a/Common/tests/Pollster/TaskHandlerTest.cs +++ b/Common/tests/Pollster/TaskHandlerTest.cs @@ -839,56 +839,6 @@ public async Task AcquireNotReadyTaskShouldFail() Assert.IsFalse(acquired); } - private class ExceptionAsyncPipe : IAsyncPipe - where T : Exception, new() - { - private readonly int delay_; - - public ExceptionAsyncPipe(int delay) - => delay_ = delay; - - public async Task ReadAsync(CancellationToken cancellationToken) - { - await Task.Delay(TimeSpan.FromMilliseconds(delay_)) - .ConfigureAwait(false); - throw new T(); - } - - public Task WriteAsync(ProcessRequest message) - => Task.CompletedTask; - - public Task WriteAsync(IEnumerable message) - => Task.CompletedTask; - - public Task CompleteAsync() - => Task.CompletedTask; - } - - public class ExceptionWorkerStreamHandler : IWorkerStreamHandler - where T : Exception, new() - { - private readonly int delay_; - - public ExceptionWorkerStreamHandler(int delay) - => delay_ = delay; - - public Task Check(HealthCheckTag tag) - => Task.FromResult(HealthCheckResult.Healthy()); - - public Task Init(CancellationToken cancellationToken) - => Task.CompletedTask; - - public void Dispose() - { - } - - public IAsyncPipe? Pipe { get; private set; } - - public void StartTaskProcessing(TaskData taskData, - CancellationToken cancellationToken) - => Pipe = new ExceptionAsyncPipe(delay_); - } - public class ExceptionStartWorkerStreamHandler : IWorkerStreamHandler where T : Exception, new() { @@ -1260,4 +1210,79 @@ await testServiceProvider.TaskHandler.PostProcessing() .ConfigureAwait(false)).Single() .Status); } + + [Test] + public async Task CancelLongTaskShouldSucceed() + { + var sqmh = new SimpleQueueMessageHandler + { + CancellationToken = CancellationToken.None, + Status = QueueMessageStatus.Waiting, + MessageId = Guid.NewGuid() + .ToString(), + }; + + var sh = new ExceptionWorkerStreamHandler(15000); + + var agentHandler = new SimpleAgentHandler(); + using var testServiceProvider = new TestTaskHandlerProvider(sh, + agentHandler, + sqmh, + new CancellationTokenSource()); + + var (taskId, _, _, _) = await InitProviderRunnableTask(testServiceProvider) + .ConfigureAwait(false); + + sqmh.TaskId = taskId; + + var acquired = await testServiceProvider.TaskHandler.AcquireTask() + .ConfigureAwait(false); + + Assert.IsTrue(acquired); + + await testServiceProvider.TaskHandler.PreProcessing() + .ConfigureAwait(false); + + await testServiceProvider.TaskHandler.ExecuteTask() + .ConfigureAwait(false); + + // Cancel task for test + + await Task.Delay(TimeSpan.FromMilliseconds(200)) + .ConfigureAwait(false); + + await testServiceProvider.TaskTable.CancelTaskAsync(new List + { + taskId, + }, + CancellationToken.None) + .ConfigureAwait(false); + + await Task.Delay(TimeSpan.FromMilliseconds(200)) + .ConfigureAwait(false); + + // Make several calls to ensure that it still works + await testServiceProvider.TaskHandler.StopCancelledTask() + .ConfigureAwait(false); + await testServiceProvider.TaskHandler.StopCancelledTask() + .ConfigureAwait(false); + await testServiceProvider.TaskHandler.StopCancelledTask() + .ConfigureAwait(false); + await testServiceProvider.TaskHandler.StopCancelledTask() + .ConfigureAwait(false); + + await testServiceProvider.TaskHandler.PostProcessing() + .ConfigureAwait(false); + + Assert.AreEqual(TaskStatus.Cancelling, + (await testServiceProvider.TaskTable.GetTaskStatus(new[] + { + taskId, + }) + .ConfigureAwait(false)).Single() + .Status); + + Assert.AreEqual(QueueMessageStatus.Processed, + sqmh.Status); + } } diff --git a/Compute/PollingAgent/src/Program.cs b/Compute/PollingAgent/src/Program.cs index 1ed29d8a1..30ef43ce6 100644 --- a/Compute/PollingAgent/src/Program.cs +++ b/Compute/PollingAgent/src/Program.cs @@ -170,6 +170,18 @@ public static async Task Main(string[] args) endpoints.MapGet("/taskprocessing", () => Task.FromResult(app.Services.GetRequiredService() .TaskProcessing)); + + endpoints.MapGet("/stopcancelledtask", + async () => + { + var stopCancelledTask = app.Services.GetRequiredService() + .StopCancelledTask; + if (stopCancelledTask != null) + { + await stopCancelledTask.Invoke() + .ConfigureAwait(false); + } + }); }); var pushQueueStorage = app.Services.GetRequiredService();