From b37eafded11eff20b6d67e146d17b21e807bb1d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marie=20P=C3=ADchov=C3=A1?= <11718369+ManickaP@users.noreply.github.com> Date: Fri, 3 May 2024 02:06:45 +0200 Subject: [PATCH] [H/3] Fix code passed into QuicConnection.CloseAsync and QuicStream.Abort (#55282) * Fix code passed into QuicConnection.CloseAsync and QuicStream.Abort * Validate that configured error code is used --------- Co-authored-by: Andrew Casey --- .../Core/src/Internal/Http3/Http3Stream.cs | 4 +- ...QuicConnectionContext.FeatureCollection.cs | 6 +- .../src/Internal/QuicConnectionContext.cs | 5 +- .../QuicStreamContext.FeatureCollection.cs | 10 +++- .../src/Internal/QuicStreamContext.cs | 6 +- .../src/QuicTransportOptions.cs | 5 +- .../test/QuicConnectionContextTests.cs | 27 +++++++++ .../test/QuicStreamContextTests.cs | 55 +++++++++++++++++++ .../Http3/Http3RequestTests.cs | 33 ++++++++++- 9 files changed, 138 insertions(+), 13 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs index 945a6e455798..3ce5cdcad632 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs @@ -166,6 +166,9 @@ private void AbortCore(Exception exception, Http3ErrorCode errorCode) abortReason = new ConnectionAbortedException(exception.Message, exception); } + // This has the side-effect of validating the error code, so do it before we consume the error code + _errorCodeFeature.Error = (long)errorCode; + _context.WebTransportSession?.Abort(abortReason, errorCode); Log.Http3StreamAbort(TraceIdentifier, errorCode, abortReason); @@ -181,7 +184,6 @@ private void AbortCore(Exception exception, Http3ErrorCode errorCode) RequestBodyPipe.Writer.Complete(exception); // Abort framewriter and underlying transport after stopping output. - _errorCodeFeature.Error = (long)errorCode; _frameWriter.Abort(abortReason); } } diff --git a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionContext.FeatureCollection.cs b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionContext.FeatureCollection.cs index c563a58368c6..e2603836421a 100644 --- a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionContext.FeatureCollection.cs +++ b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionContext.FeatureCollection.cs @@ -16,7 +16,11 @@ internal sealed partial class QuicConnectionContext : IProtocolErrorCodeFeature, public long Error { get => _error ?? -1; - set => _error = value; + set + { + QuicTransportOptions.ValidateErrorCode(value); + _error = value; + } } public X509Certificate2? ClientCertificate diff --git a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionContext.cs b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionContext.cs index 5e85e81821fb..9fd38e0a8bfa 100644 --- a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionContext.cs +++ b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionContext.cs @@ -56,6 +56,7 @@ public override async ValueTask DisposeAsync() { lock (_shutdownLock) { + // The DefaultCloseErrorCode setter validates that the error code is within the valid range _closeTask ??= _connection.CloseAsync(errorCode: _context.Options.DefaultCloseErrorCode).AsTask(); } @@ -81,7 +82,7 @@ public override void Abort(ConnectionAbortedException abortReason) return; } - var resolvedErrorCode = _error ?? 0; + var resolvedErrorCode = _error ?? 0; // Only valid error codes are assigned to _error _abortReason = ExceptionDispatchInfo.Capture(abortReason); QuicLog.ConnectionAbort(_log, this, resolvedErrorCode, abortReason.Message); _closeTask = _connection.CloseAsync(errorCode: resolvedErrorCode).AsTask(); @@ -130,7 +131,7 @@ public override void Abort(ConnectionAbortedException abortReason) catch (QuicException ex) when (ex.QuicError == QuicError.ConnectionAborted) { // Shutdown initiated by peer, abortive. - _error = ex.ApplicationErrorCode; + _error = ex.ApplicationErrorCode; // Trust Quic to provide us a valid error code QuicLog.ConnectionAborted(_log, this, ex.ApplicationErrorCode.GetValueOrDefault(), ex); ThreadPool.UnsafeQueueUserWorkItem(state => diff --git a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.FeatureCollection.cs b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.FeatureCollection.cs index 48e0095110d7..d3d39e202b49 100644 --- a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.FeatureCollection.cs +++ b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.FeatureCollection.cs @@ -38,7 +38,11 @@ public OnCloseRegistration(Action callback, object? state) public long Error { get => _error ?? -1; - set => _error = value; + set + { + QuicTransportOptions.ValidateErrorCode(value); + _error = value; + } } public long StreamId { get; private set; } @@ -54,6 +58,8 @@ public long Error public void AbortRead(long errorCode, ConnectionAbortedException abortReason) { + QuicTransportOptions.ValidateErrorCode(errorCode); + lock (_shutdownLock) { if (_stream != null) @@ -74,6 +80,8 @@ public void AbortRead(long errorCode, ConnectionAbortedException abortReason) public void AbortWrite(long errorCode, ConnectionAbortedException abortReason) { + QuicTransportOptions.ValidateErrorCode(errorCode); + lock (_shutdownLock) { if (_stream != null) diff --git a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.cs b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.cs index d81f0ec4757a..6d7fd3b3c777 100644 --- a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.cs +++ b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.cs @@ -273,7 +273,7 @@ private async ValueTask DoReceiveAsync() catch (QuicException ex) when (ex.QuicError is QuicError.StreamAborted or QuicError.ConnectionAborted) { // Abort from peer. - _error = ex.ApplicationErrorCode; + _error = ex.ApplicationErrorCode; // Trust Quic to provide us a valid error code QuicLog.StreamAbortedRead(_log, this, ex.ApplicationErrorCode.GetValueOrDefault()); // This could be ignored if _shutdownReason is already set. @@ -434,7 +434,7 @@ private async ValueTask DoSendAsync() catch (QuicException ex) when (ex.QuicError is QuicError.StreamAborted or QuicError.ConnectionAborted) { // Abort from peer. - _error = ex.ApplicationErrorCode; + _error = ex.ApplicationErrorCode; // Trust Quic to provide us a valid error code QuicLog.StreamAbortedWrite(_log, this, ex.ApplicationErrorCode.GetValueOrDefault()); // This could be ignored if _shutdownReason is already set. @@ -501,7 +501,7 @@ public override void Abort(ConnectionAbortedException abortReason) _shutdownReason = abortReason; } - var resolvedErrorCode = _error ?? 0; + var resolvedErrorCode = _error ?? 0; // _error is validated on assignment QuicLog.StreamAbort(_log, this, resolvedErrorCode, abortReason.Message); if (stream.CanRead) diff --git a/src/Servers/Kestrel/Transport.Quic/src/QuicTransportOptions.cs b/src/Servers/Kestrel/Transport.Quic/src/QuicTransportOptions.cs index e2f58f124b0a..e6cdce55817c 100644 --- a/src/Servers/Kestrel/Transport.Quic/src/QuicTransportOptions.cs +++ b/src/Servers/Kestrel/Transport.Quic/src/QuicTransportOptions.cs @@ -68,14 +68,15 @@ public long DefaultCloseErrorCode } } - private static void ValidateErrorCode(long errorCode) + internal static void ValidateErrorCode(long errorCode) { const long MinErrorCode = 0; const long MaxErrorCode = (1L << 62) - 1; if (errorCode < MinErrorCode || errorCode > MaxErrorCode) { - throw new ArgumentOutOfRangeException(nameof(errorCode), errorCode, $"A value between {MinErrorCode} and {MaxErrorCode} is required."); + // Print the values in hex since the max is unintelligible in decimal + throw new ArgumentOutOfRangeException(nameof(errorCode), errorCode, $"A value between 0x{MinErrorCode:x} and 0x{MaxErrorCode:x} is required."); } } diff --git a/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionContextTests.cs b/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionContextTests.cs index c8ee8ab0d226..8a37f8668fb5 100644 --- a/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionContextTests.cs +++ b/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionContextTests.cs @@ -706,6 +706,33 @@ public async Task PersistentState_StreamsReused_StatePersisted() Assert.Equal(true, state); } + [ConditionalTheory] + [MsQuicSupported] + [InlineData(-1L)] // Too small + [InlineData(1L << 62)] // Too big + public async Task IProtocolErrorFeature_InvalidErrorCode(long errorCode) + { + // Arrange + await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory); + + var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint); + await using var clientConnection = await QuicConnection.ConnectAsync(options); + + await using var serverConnection = await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout(); + + // Act + var clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + await clientStream.WriteAsync(TestData).DefaultTimeout(); + + var serverStream = await serverConnection.AcceptAsync().DefaultTimeout(); + + var protocolErrorCodeFeature = serverConnection.Features.Get(); + + // Assert + Assert.IsType(protocolErrorCodeFeature); + Assert.Throws(() => protocolErrorCodeFeature.Error = errorCode); + } + private record RequestState( QuicConnection QuicConnection, MultiplexedConnectionContext ServerConnection, diff --git a/src/Servers/Kestrel/Transport.Quic/test/QuicStreamContextTests.cs b/src/Servers/Kestrel/Transport.Quic/test/QuicStreamContextTests.cs index 7ebb99a5eb02..0dd698dd2e2a 100644 --- a/src/Servers/Kestrel/Transport.Quic/test/QuicStreamContextTests.cs +++ b/src/Servers/Kestrel/Transport.Quic/test/QuicStreamContextTests.cs @@ -526,4 +526,59 @@ public async Task StreamAbortFeature_AbortWrite_ClientReceivesAbort() var serverEx = await Assert.ThrowsAsync(() => serverReadTask).DefaultTimeout(); Assert.Equal("Test reason", serverEx.Message); } + + [ConditionalTheory] + [MsQuicSupported] + [InlineData(-1L)] // Too small + [InlineData(1L << 62)] // Too big + public async Task IProtocolErrorFeature_InvalidErrorCode(long errorCode) + { + // Arrange + await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory); + + var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint); + await using var clientConnection = await QuicConnection.ConnectAsync(options); + + await using var serverConnection = await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout(); + + // Act + var clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + await clientStream.WriteAsync(TestData).DefaultTimeout(); + + var serverStream = await serverConnection.AcceptAsync().DefaultTimeout(); + + var protocolErrorCodeFeature = serverStream.Features.Get(); + + // Assert + Assert.IsType(protocolErrorCodeFeature); + Assert.Throws(() => protocolErrorCodeFeature.Error = errorCode); + } + + [ConditionalTheory] + [MsQuicSupported] + [InlineData(-1L)] // Too small + [InlineData(1L << 62)] // Too big + public async Task IStreamAbortFeature_InvalidErrorCode(long errorCode) + { + // Arrange + await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory); + + var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint); + await using var clientConnection = await QuicConnection.ConnectAsync(options); + + await using var serverConnection = await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout(); + + // Act + var clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + await clientStream.WriteAsync(TestData).DefaultTimeout(); + + var serverStream = await serverConnection.AcceptAsync().DefaultTimeout(); + + var protocolErrorCodeFeature = serverStream.Features.Get(); + + // Assert + Assert.IsType(protocolErrorCodeFeature); + Assert.Throws(() => protocolErrorCodeFeature.AbortRead(errorCode, new ConnectionAbortedException())); + Assert.Throws(() => protocolErrorCodeFeature.AbortWrite(errorCode, new ConnectionAbortedException())); + } } diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs index c00abc6a907f..e16404299f90 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs @@ -13,17 +13,16 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Internal; +using Microsoft.AspNetCore.InternalTesting; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Https; -using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Quic; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Diagnostics.Metrics; using Microsoft.Extensions.Diagnostics.Metrics.Testing; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; using Microsoft.Extensions.Primitives; -using Xunit; namespace Interop.FunctionalTests.Http3; @@ -2031,6 +2030,34 @@ public async Task GET_GracefulServerShutdown_RequestCompleteSuccessfullyInsideHo } } + [ConditionalFact] + [MsQuicSupported] + public async Task ServerReset_InvalidErrorCode() + { + var ranHandler = false; + var hostBuilder = CreateHostBuilder(context => + { + ranHandler = true; + // Can't test a too-large value since it's bigger than int + //Assert.Throws(() => context.Features.Get().Reset(-1)); // Invalid negative value + context.Features.Get().Reset(-1); + return Task.CompletedTask; + }); + + using var host = await hostBuilder.StartAsync().DefaultTimeout(); + using var client = HttpHelpers.CreateClient(); + + var request = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/"); + request.Version = GetProtocol(HttpProtocols.Http3); + request.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + var response = await client.SendAsync(request, CancellationToken.None).DefaultTimeout(); + await host.StopAsync().DefaultTimeout(); + + Assert.True(ranHandler); + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + } + private IHostBuilder CreateHostBuilder(RequestDelegate requestDelegate, HttpProtocols? protocol = null, Action configureKestrel = null) { return HttpHelpers.CreateHostBuilder(AddTestLogging, requestDelegate, protocol, configureKestrel);