Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stevejgordon committed Oct 22, 2024
1 parent 160eff8 commit 2300050
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,32 @@ public TResponse BuildResponse<TResponse>(RequestData requestData, byte[] respon
{
var body = responseBody ?? _responseBody;
var data = requestData.PostData;
if (data != null)

if (data is not null)
{
using (var stream = requestData.MemoryStreamFactory.Create())
using var stream = requestData.MemoryStreamFactory.Create();
if (requestData.HttpCompression)
{
using var zipStream = new GZipStream(stream, CompressionMode.Compress);
data.Write(zipStream, requestData.ConnectionSettings);
}
else
{
if (requestData.HttpCompression)
{
using var zipStream = new GZipStream(stream, CompressionMode.Compress);
data.Write(zipStream, requestData.ConnectionSettings);
}
else
data.Write(stream, requestData.ConnectionSettings);
data.Write(stream, requestData.ConnectionSettings);
}
}
requestData.MadeItToResponse = true;

var sc = statusCode ?? _statusCode;
Stream s = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody);
return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse<TResponse>(requestData, _exception, sc, _headers, s, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null);
Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody);

var isStreamResponse = typeof(TResponse) == typeof(StreamResponse);

using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null)
{
return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder
.ToResponse<TResponse>(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null);
}
}

/// <inheritdoc cref="BuildResponse{TResponse}"/>>
Expand All @@ -93,17 +101,19 @@ public async Task<TResponse> BuildResponseAsync<TResponse>(RequestData requestDa
{
var body = responseBody ?? _responseBody;
var data = requestData.PostData;
if (data != null)

if (data is not null)
{
using (var stream = requestData.MemoryStreamFactory.Create())
using var stream = requestData.MemoryStreamFactory.Create();

if (requestData.HttpCompression)
{
using var zipStream = new GZipStream(stream, CompressionMode.Compress);
await data.WriteAsync(zipStream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false);
}
else
{
if (requestData.HttpCompression)
{
using var zipStream = new GZipStream(stream, CompressionMode.Compress);
await data.WriteAsync(zipStream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false);
}
else
await data.WriteAsync(stream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false);
await data.WriteAsync(stream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false);
}
}
requestData.MadeItToResponse = true;
Expand All @@ -117,8 +127,8 @@ public async Task<TResponse> BuildResponseAsync<TResponse>(RequestData requestDa
using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null)
{
return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder
.ToResponseAsync<TResponse>(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken)
.ConfigureAwait(false);
.ToResponseAsync<TResponse>(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken)
.ConfigureAwait(false);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to Elasticsearch B.V under one or more agreements.
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information

using System.IO;
using System.Text.Json;
using System.Threading.Tasks;
using Elastic.Transport.IntegrationTests.Plumbing;
using Elastic.Transport.Products.Elasticsearch;
using Microsoft.AspNetCore.Mvc;
using Xunit;

namespace Elastic.Transport.IntegrationTests.Http;

public class StreamResponseTests(TransportTestServer instance) : AssemblyServerTestsBase(instance)
{
private const string Path = "/streamresponse";

[Fact]
public async Task StreamResponse_ShouldNotBeDisposed()
{
var nodePool = new SingleNodePool(Server.Uri);
var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient)));
var transport = new DistributedTransport(config);

var response = await transport.PostAsync<StreamResponse>(Path, PostData.String("{}"));

var sr = new StreamReader(response.Body);
var responseString = sr.ReadToEndAsync();
}
}

[ApiController, Route("[controller]")]
public class StreamResponseController : ControllerBase
{
[HttpPost]
public Task<JsonElement> Post([FromBody] JsonElement body) => Task.FromResult(body);
}
152 changes: 38 additions & 114 deletions tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,146 +3,70 @@
// See the LICENSE file in the project root for more information

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Elastic.Transport.Tests.Plumbing;
using FluentAssertions;
using Xunit;

namespace Elastic.Transport.Tests
namespace Elastic.Transport.Tests;

public class ResponseBuilderDisposeTests
{
public class ResponseBuilderDisposeTests
{
private readonly ITransportConfiguration _settings = InMemoryConnectionFactory.Create().DisableDirectStreaming(false);
private readonly ITransportConfiguration _settingsDisableDirectStream = InMemoryConnectionFactory.Create().DisableDirectStreaming();
private readonly ITransportConfiguration _settings = InMemoryConnectionFactory.Create().DisableDirectStreaming(false);

[Fact] public async Task ResponseWithHttpStatusCode() => await AssertRegularResponse(false, 1);
[Fact]
public async Task ResponseWithPotentialBody_StreamIsNotDisposed() => await AssertResponse(expectedDisposed: false);

[Fact] public async Task ResponseBuilderWithNoHttpStatusCode() => await AssertRegularResponse(false);
[Fact]
public async Task ResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(204);

[Fact] public async Task ResponseWithHttpStatusCodeDisableDirectStreaming() =>
await AssertRegularResponse(true, 1);
[Fact]
public async Task ResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(httpMethod: HttpMethod.HEAD);

[Fact] public async Task ResponseBuilderWithNoHttpStatusCodeDisableDirectStreaming() =>
await AssertRegularResponse(true);
[Fact]
public async Task ResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(contentLength: 0);

private async Task AssertRegularResponse(bool disableDirectStreaming, int? statusCode = null)
private async Task AssertResponse(int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, bool expectedDisposed = true)
{
var settings = _settings;
var requestData = new RequestData(httpMethod, "/", null, settings, null, null, default)
{
var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings;
var memoryStreamFactory = new TrackMemoryStreamFactory();
var requestData = new RequestData(HttpMethod.GET, "/", null, settings, null, memoryStreamFactory, default)
{
Node = new Node(new Uri("http://localhost:9200"))
};

var stream = new TrackDisposeStream();
var response = _settings.ProductRegistration.ResponseBuilder.ToResponse<TestResponse>(requestData, null, statusCode, null, stream, null, -1, null, null);
response.Should().NotBeNull();

memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 1 : 0);
if (disableDirectStreaming)
{
var memoryStream = memoryStreamFactory.Created[0];
memoryStream.IsDisposed.Should().BeTrue();
}
stream.IsDisposed.Should().BeTrue();


stream = new TrackDisposeStream();
var ct = new CancellationToken();
response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync<TestResponse>(requestData, null, statusCode, null, stream, null, -1, null, null,
cancellationToken: ct);
response.Should().NotBeNull();
memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 2 : 0);
if (disableDirectStreaming)
{
var memoryStream = memoryStreamFactory.Created[1];
memoryStream.IsDisposed.Should().BeTrue();
}
stream.IsDisposed.Should().BeTrue();
}
Node = new Node(new Uri("http://localhost:9200"))
};

[Fact] public async Task StreamResponseWithHttpStatusCode() => await AssertStreamResponse(false, 200);
var stream = new TrackDisposeStream();

[Fact] public async Task StreamResponseBuilderWithNoHttpStatusCode() => await AssertStreamResponse(false);
var response = _settings.ProductRegistration.ResponseBuilder.ToResponse<TestResponse>(requestData, null, statusCode, null, stream, null, contentLength, null, null);

[Fact] public async Task StreamResponseWithHttpStatusCodeDisableDirectStreaming() =>
await AssertStreamResponse(true, 1);
response.Should().NotBeNull();
stream.IsDisposed.Should().Be(expectedDisposed);

[Fact] public async Task StreamResponseBuilderWithNoHttpStatusCodeDisableDirectStreaming() =>
await AssertStreamResponse(true);
stream = new TrackDisposeStream();
var ct = new CancellationToken();

private async Task AssertStreamResponse(bool disableDirectStreaming, int? statusCode = null)
{
var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings;
var memoryStreamFactory = new TrackMemoryStreamFactory();

var requestData = new RequestData(HttpMethod.GET, "/", null, settings, null, memoryStreamFactory, default)
{
Node = new Node(new Uri("http://localhost:9200"))
};

var stream = new TrackDisposeStream();
var response = _settings.ProductRegistration.ResponseBuilder.ToResponse<TestResponse>(requestData, null, statusCode, null, stream, null, -1, null, null);
response.Should().NotBeNull();

memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 1 : 0);
stream.IsDisposed.Should().Be(true);

stream = new TrackDisposeStream();
var ct = new CancellationToken();
response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync<TestResponse>(requestData, null, statusCode, null, stream, null, -1, null, null,
cancellationToken: ct);
response.Should().NotBeNull();
memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 2 : 0);
stream.IsDisposed.Should().Be(true);
}
response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync<TestResponse>(requestData, null, statusCode, null, stream, null, contentLength, null, null,
cancellationToken: ct);

response.Should().NotBeNull();
stream.IsDisposed.Should().Be(expectedDisposed);
}

private class TrackDisposeStream : MemoryStream
{
public TrackDisposeStream() { }

public TrackDisposeStream(byte[] bytes) : base(bytes) { }
private class TrackDisposeStream : MemoryStream
{
public TrackDisposeStream() { }

public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { }
public TrackDisposeStream(byte[] bytes) : base(bytes) { }

public bool IsDisposed { get; private set; }
public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { }

protected override void Dispose(bool disposing)
{
IsDisposed = true;
base.Dispose(disposing);
}
}
public bool IsDisposed { get; private set; }

private class TrackMemoryStreamFactory : MemoryStreamFactory
protected override void Dispose(bool disposing)
{
public IList<TrackDisposeStream> Created { get; } = new List<TrackDisposeStream>();

public override MemoryStream Create()
{
var stream = new TrackDisposeStream();
Created.Add(stream);
return stream;
}

public override MemoryStream Create(byte[] bytes)
{
var stream = new TrackDisposeStream(bytes);
Created.Add(stream);
return stream;
}

public override MemoryStream Create(byte[] bytes, int index, int count)
{
var stream = new TrackDisposeStream(bytes, index, count);
Created.Add(stream);
return stream;
}
IsDisposed = true;
base.Dispose(disposing);
}
}
}

0 comments on commit 2300050

Please sign in to comment.