Skip to content

Commit

Permalink
Improve ssl buffers handling (#8165)
Browse files Browse the repository at this point in the history
* Fixes #8161 improve SSLConnection buffers handling

Added memory heuristic to ArrayRetainableByteBufferPool

Signed-off-by: Ludovic Orban <[email protected]>
  • Loading branch information
lorban authored Jun 15, 2022
1 parent 0699bc5 commit 66de7ba
Show file tree
Hide file tree
Showing 5 changed files with 457 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
Expand All @@ -36,31 +38,39 @@
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocket;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.eclipse.jetty.client.api.ContentResponse;
import org.eclipse.jetty.client.http.HttpClientTransportOverHTTP;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpHeaderValue;
import org.eclipse.jetty.http.HttpScheme;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.io.ArrayByteBufferPool;
import org.eclipse.jetty.io.ArrayRetainableByteBufferPool;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.ClientConnectionFactory;
import org.eclipse.jetty.io.ClientConnector;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.ConnectionStatistics;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.io.RetainableByteBuffer;
import org.eclipse.jetty.io.RetainableByteBufferPool;
import org.eclipse.jetty.io.ssl.SslClientConnectionFactory;
import org.eclipse.jetty.io.ssl.SslConnection;
import org.eclipse.jetty.io.ssl.SslHandshakeListener;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.HttpConfiguration;
import org.eclipse.jetty.server.HttpConnectionFactory;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.SecureRequestCustomizer;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.server.SslConnectionFactory;
import org.eclipse.jetty.toolchain.test.Net;
import org.eclipse.jetty.util.Pool;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.ExecutorThreadPool;
Expand All @@ -71,9 +81,14 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledForJreRange;
import org.junit.jupiter.api.condition.JRE;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import static org.awaitility.Awaitility.await;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
Expand Down Expand Up @@ -682,12 +697,7 @@ protected int networkFill(ByteBuffer input) throws IOException
// Trigger the creation of a new connection, but don't use it.
ConnectionPoolHelper.tryCreate(connectionPool);
// Verify that the connection has been created.
while (true)
{
Thread.sleep(50);
if (connectionPool.getConnectionCount() == 1)
break;
}
await().atMost(5, TimeUnit.SECONDS).until(connectionPool::getConnectionCount, is(1));

// Wait for the server to idle timeout the connection.
Thread.sleep(idleTimeout + idleTimeout / 2);
Expand All @@ -698,6 +708,299 @@ protected int networkFill(ByteBuffer input) throws IOException
assertEquals(0, clientBytes.get());
}

@Test
public void testEncryptedInputBufferRepooling() throws Exception
{
SslContextFactory.Server serverTLSFactory = createServerSslContextFactory();
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
var retainableByteBufferPool = new ArrayRetainableByteBufferPool()
{
@Override
public Pool<RetainableByteBuffer> poolFor(int capacity, boolean direct)
{
return super.poolFor(capacity, direct);
}
};
server.addBean(retainableByteBufferPool);
HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.addCustomizer(new SecureRequestCustomizer());
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol())
{
@Override
protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine)
{
ByteBufferPool byteBufferPool = connector.getByteBufferPool();
RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class);
return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption())
{
@Override
protected int networkFill(ByteBuffer input) throws IOException
{
int n = super.networkFill(input);
if (n > 0)
throw new IOException("boom");
return n;
}
};
}
};
connector = new ServerConnector(server, 1, 1, ssl, http);
server.addConnector(connector);
server.setHandler(new EmptyServerHandler());
server.start();

SslContextFactory.Client clientTLSFactory = createClientSslContextFactory();
ClientConnector clientConnector = new ClientConnector();
clientConnector.setSelectors(1);
clientConnector.setSslContextFactory(clientTLSFactory);
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
clientConnector.setExecutor(clientThreads);
client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector));
client.setExecutor(clientThreads);
client.start();

assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send());

Pool<RetainableByteBuffer> bucket = retainableByteBufferPool.poolFor(16 * 1024 + 1, ssl.isDirectBuffersForEncryption());
assertEquals(1, bucket.size());
assertEquals(1, bucket.getIdleCount());
}

@Test
public void testEncryptedOutputBufferRepooling() throws Exception
{
SslContextFactory.Server serverTLSFactory = createServerSslContextFactory();
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
List<ByteBuffer> leakedBuffers = new ArrayList<>();
ArrayByteBufferPool byteBufferPool = new ArrayByteBufferPool()
{
@Override
public ByteBuffer acquire(int size, boolean direct)
{
ByteBuffer acquired = super.acquire(size, direct);
leakedBuffers.add(acquired);
return acquired;
}

@Override
public void release(ByteBuffer buffer)
{
leakedBuffers.remove(buffer);
super.release(buffer);
}
};
server.addBean(byteBufferPool);
HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.addCustomizer(new SecureRequestCustomizer());
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol())
{
@Override
protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine)
{
ByteBufferPool byteBufferPool = connector.getByteBufferPool();
RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class);
return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption())
{
@Override
protected boolean networkFlush(ByteBuffer output) throws IOException
{
throw new IOException("bang");
}
};
}
};
connector = new ServerConnector(server, 1, 1, ssl, http);
server.addConnector(connector);
server.setHandler(new EmptyServerHandler());
server.start();

SslContextFactory.Client clientTLSFactory = createClientSslContextFactory();
ClientConnector clientConnector = new ClientConnector();
clientConnector.setSelectors(1);
clientConnector.setSslContextFactory(clientTLSFactory);
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
clientConnector.setExecutor(clientThreads);
client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector));
client.setExecutor(clientThreads);
client.start();

assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send());

await().atMost(5, TimeUnit.SECONDS).until(() -> leakedBuffers, is(empty()));
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testEncryptedOutputBufferRepoolingAfterNetworkFlushReturnsFalse(boolean close) throws Exception
{
SslContextFactory.Server serverTLSFactory = createServerSslContextFactory();
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
List<ByteBuffer> leakedBuffers = new ArrayList<>();
ArrayByteBufferPool byteBufferPool = new ArrayByteBufferPool()
{
@Override
public ByteBuffer acquire(int size, boolean direct)
{
ByteBuffer acquired = super.acquire(size, direct);
leakedBuffers.add(acquired);
return acquired;
}

@Override
public void release(ByteBuffer buffer)
{
leakedBuffers.remove(buffer);
super.release(buffer);
}
};
server.addBean(byteBufferPool);
HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.addCustomizer(new SecureRequestCustomizer());
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
AtomicBoolean failFlush = new AtomicBoolean(false);
SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol())
{
@Override
protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine)
{
ByteBufferPool byteBufferPool = connector.getByteBufferPool();
RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class);
return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption())
{
@Override
protected boolean networkFlush(ByteBuffer output) throws IOException
{
if (failFlush.get())
return false;
return super.networkFlush(output);
}
};
}
};
connector = new ServerConnector(server, 1, 1, ssl, http);
server.addConnector(connector);
server.setHandler(new EmptyServerHandler()
{
@Override
protected void service(String target, Request jettyRequest, HttpServletRequest request, HttpServletResponse response)
{
failFlush.set(true);
if (close)
jettyRequest.getHttpChannel().getEndPoint().close();
else
jettyRequest.getHttpChannel().getEndPoint().shutdownOutput();
}
});
server.start();

SslContextFactory.Client clientTLSFactory = createClientSslContextFactory();
ClientConnector clientConnector = new ClientConnector();
clientConnector.setSelectors(1);
clientConnector.setSslContextFactory(clientTLSFactory);
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
clientConnector.setExecutor(clientThreads);
client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector));
client.setExecutor(clientThreads);
client.start();

assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send());

await().atMost(5, TimeUnit.SECONDS).until(() -> leakedBuffers, is(empty()));
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testEncryptedOutputBufferRepoolingAfterNetworkFlushThrows(boolean close) throws Exception
{
SslContextFactory.Server serverTLSFactory = createServerSslContextFactory();
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
List<ByteBuffer> leakedBuffers = new ArrayList<>();
ArrayByteBufferPool byteBufferPool = new ArrayByteBufferPool()
{
@Override
public ByteBuffer acquire(int size, boolean direct)
{
ByteBuffer acquired = super.acquire(size, direct);
leakedBuffers.add(acquired);
return acquired;
}

@Override
public void release(ByteBuffer buffer)
{
leakedBuffers.remove(buffer);
super.release(buffer);
}
};
server.addBean(byteBufferPool);
HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.addCustomizer(new SecureRequestCustomizer());
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
AtomicBoolean failFlush = new AtomicBoolean(false);
SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol())
{
@Override
protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine)
{
ByteBufferPool byteBufferPool = connector.getByteBufferPool();
RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class);
return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption())
{
@Override
protected boolean networkFlush(ByteBuffer output) throws IOException
{
if (failFlush.get())
throw new IOException();
return super.networkFlush(output);
}
};
}
};
connector = new ServerConnector(server, 1, 1, ssl, http);
server.addConnector(connector);
server.setHandler(new EmptyServerHandler()
{
@Override
protected void service(String target, Request jettyRequest, HttpServletRequest request, HttpServletResponse response) throws IOException
{
failFlush.set(true);
if (close)
jettyRequest.getHttpChannel().getEndPoint().close();
else
jettyRequest.getHttpChannel().getEndPoint().shutdownOutput();
}
});
server.start();

SslContextFactory.Client clientTLSFactory = createClientSslContextFactory();
ClientConnector clientConnector = new ClientConnector();
clientConnector.setSelectors(1);
clientConnector.setSslContextFactory(clientTLSFactory);
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
clientConnector.setExecutor(clientThreads);
client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector));
client.setExecutor(clientThreads);
client.start();

assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send());

await().atMost(5, TimeUnit.SECONDS).until(() -> leakedBuffers, is(empty()));
}

@Test
public void testNeverUsedConnectionThenClientIdleTimeout() throws Exception
{
Expand Down Expand Up @@ -780,12 +1083,7 @@ protected int networkFill(ByteBuffer input) throws IOException
// Trigger the creation of a new connection, but don't use it.
ConnectionPoolHelper.tryCreate(connectionPool);
// Verify that the connection has been created.
while (true)
{
Thread.sleep(50);
if (connectionPool.getConnectionCount() == 1)
break;
}
await().atMost(5, TimeUnit.SECONDS).until(connectionPool::getConnectionCount, is(1));

// Wait for the client to idle timeout the connection.
Thread.sleep(idleTimeout + idleTimeout / 2);
Expand Down
Loading

0 comments on commit 66de7ba

Please sign in to comment.