Skip to content

Commit

Permalink
Refactor grpc tls (alibaba#10759)
Browse files Browse the repository at this point in the history
* Move Tls negotiator to GrpcSdkServer.

* use protocol negotiator builder replace directly create.

* use SPI load negotiator and set tls as default negotiator.

* Remove tlsconfig in BaseRpcServer.

* Add some ut.

* For checkstyle.
  • Loading branch information
KomachiSion authored Jul 11, 2023
1 parent a83e2cc commit 9069730
Show file tree
Hide file tree
Showing 28 changed files with 648 additions and 272 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void testScheduleUpdateIfAbsent() throws InterruptedException, NacosExcep
notifyer);

serviceInfoUpdateService.scheduleUpdateIfAbsent("aa", "bb", "cc");
TimeUnit.SECONDS.sleep(2);
TimeUnit.MILLISECONDS.sleep(1500);
Mockito.verify(proxy).queryInstancesOfService(serviceName, group, clusters, 0, false);
}

Expand Down
30 changes: 10 additions & 20 deletions core/src/main/java/com/alibaba/nacos/core/remote/BaseRpcServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@

import com.alibaba.nacos.common.remote.ConnectionType;
import com.alibaba.nacos.common.remote.PayloadRegistry;
import com.alibaba.nacos.common.utils.JacksonUtils;
import com.alibaba.nacos.core.remote.tls.RpcServerSslContextRefresherHolder;
import com.alibaba.nacos.core.utils.Loggers;
import com.alibaba.nacos.sys.env.EnvUtil;
import org.springframework.beans.factory.annotation.Autowired;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
Expand All @@ -38,27 +37,21 @@ public abstract class BaseRpcServer {
PayloadRegistry.init();
}

@Autowired
protected RpcServerTlsConfig rpcServerTlsConfig;

/**
* Start sever.
*/
@PostConstruct
public void start() throws Exception {
String serverName = getClass().getSimpleName();
String tlsConfig = JacksonUtils.toJson(rpcServerTlsConfig);
Loggers.REMOTE.info("Nacos {} Rpc server starting at port {} and tls config:{}", serverName, getServicePort(),
tlsConfig);
Loggers.REMOTE.info("Nacos {} Rpc server starting at port {}", serverName, getServicePort());

startServer();

if (RpcServerSslContextRefresherHolder.getInstance() != null) {
RpcServerSslContextRefresherHolder.getInstance().refresh(this);
}

Loggers.REMOTE.info("Nacos {} Rpc server started at port {} and tls config:{}", serverName, getServicePort(),
tlsConfig);
Loggers.REMOTE.info("Nacos {} Rpc server started at port {}", serverName, getServicePort());
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
Loggers.REMOTE.info("Nacos {} Rpc server stopping", serverName);
try {
Expand All @@ -78,18 +71,15 @@ public void start() throws Exception {
*/
public abstract ConnectionType getConnectionType();

public RpcServerTlsConfig getRpcServerTlsConfig() {
return rpcServerTlsConfig;
}

public void setRpcServerTlsConfig(RpcServerTlsConfig rpcServerTlsConfig) {
this.rpcServerTlsConfig = rpcServerTlsConfig;
}

/**
* reload ssl context.
* Reload protocol context if necessary.
*
* <p>
* protocol like:
* <li>Tls</li>
* </p>
*/
public abstract void reloadSslContext();
public abstract void reloadProtocolContext();

/**
* Start sever.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,7 @@
package com.alibaba.nacos.core.remote.grpc;

import com.alibaba.nacos.api.grpc.auto.Payload;
import com.alibaba.nacos.common.packagescan.resource.DefaultResourceLoader;
import com.alibaba.nacos.common.packagescan.resource.Resource;
import com.alibaba.nacos.common.packagescan.resource.ResourceLoader;
import com.alibaba.nacos.common.remote.ConnectionType;

import com.alibaba.nacos.common.utils.JacksonUtils;
import com.alibaba.nacos.common.utils.StringUtils;
import com.alibaba.nacos.common.utils.TlsTypeResolve;
import com.alibaba.nacos.core.remote.BaseRpcServer;
import com.alibaba.nacos.core.remote.ConnectionManager;
import com.alibaba.nacos.core.utils.Loggers;
Expand All @@ -37,22 +30,14 @@
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;

import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.ServerCalls;
import io.grpc.util.MutableHandlerRegistry;
import org.springframework.beans.factory.annotation.Autowired;

import javax.net.ssl.SSLException;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Optional;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

Expand All @@ -66,8 +51,6 @@ public abstract class BaseGrpcServer extends BaseRpcServer {

private Server server;

private final ResourceLoader resourceLoader = new DefaultResourceLoader();

@Autowired
private GrpcRequestAcceptor grpcCommonRequestAcceptor;

Expand All @@ -77,8 +60,6 @@ public abstract class BaseGrpcServer extends BaseRpcServer {
@Autowired
private ConnectionManager connectionManager;

private OptionalTlsProtocolNegotiator optionalTlsProtocolNegotiator;

@Override
public ConnectionType getConnectionType() {
return ConnectionType.GRPC;
Expand All @@ -90,10 +71,11 @@ public void startServer() throws Exception {
addServices(handlerRegistry, new GrpcConnectionInterceptor(), new GrpcServerParamCheckInterceptor());
NettyServerBuilder builder = NettyServerBuilder.forPort(getServicePort()).executor(getRpcExecutor());

if (rpcServerTlsConfig.getEnableTls()) {
builder.protocolNegotiator(
new OptionalTlsProtocolNegotiator(getSslContextBuilder(), rpcServerTlsConfig.getCompatibility()));

Optional<InternalProtocolNegotiator.ProtocolNegotiator> negotiator = newProtocolNegotiator();
if (negotiator.isPresent()) {
InternalProtocolNegotiator.ProtocolNegotiator actual = negotiator.get();
Loggers.REMOTE.info("Add protocol negotiator {}", actual.getClass().getCanonicalName());
builder.protocolNegotiator(actual);
}

server = builder.maxInboundMessageSize(getMaxInboundMessageSize()).fallbackHandlerRegistry(handlerRegistry)
Expand All @@ -107,20 +89,26 @@ public void startServer() throws Exception {
server.start();
}

@Override
public void reloadProtocolContext() {
reloadProtocolNegotiator();
}

/**
* reload ssl context.
* Build new one protocol negotiator.
*
* <p>Such as support tls, proxy protocol and so on</p>
*
* @return ProtocolNegotiator
*/
public void reloadSslContext() {
if (optionalTlsProtocolNegotiator != null) {
try {
optionalTlsProtocolNegotiator.setSslContext(getSslContextBuilder());
} catch (Throwable throwable) {
Loggers.REMOTE.info("Nacos {} Rpc server reload ssl context fail at port {} and tls config:{}",
this.getClass().getSimpleName(), getServicePort(),
JacksonUtils.toJson(super.rpcServerTlsConfig));
throw throwable;
}
}
protected Optional<InternalProtocolNegotiator.ProtocolNegotiator> newProtocolNegotiator() {
return Optional.empty();
}

/**
* reload protocol negotiator If necessary.
*/
public void reloadProtocolNegotiator() {
}

protected long getPermitKeepAliveTime() {
Expand All @@ -136,8 +124,8 @@ protected long getKeepAliveTimeout() {
}

protected int getMaxInboundMessageSize() {
Integer property = EnvUtil.getProperty(GrpcServerConstants.GrpcConfig.MAX_INBOUND_MSG_SIZE_PROPERTY,
Integer.class);
Integer property = EnvUtil
.getProperty(GrpcServerConstants.GrpcConfig.MAX_INBOUND_MSG_SIZE_PROPERTY, Integer.class);
if (property != null) {
return property;
}
Expand All @@ -148,32 +136,34 @@ private void addServices(MutableHandlerRegistry handlerRegistry, ServerIntercept

// unary common call register.
final MethodDescriptor<Payload, Payload> unaryPayloadMethod = MethodDescriptor.<Payload, Payload>newBuilder()
.setType(MethodDescriptor.MethodType.UNARY).setFullMethodName(
MethodDescriptor.generateFullMethodName(GrpcServerConstants.REQUEST_SERVICE_NAME,
.setType(MethodDescriptor.MethodType.UNARY).setFullMethodName(MethodDescriptor
.generateFullMethodName(GrpcServerConstants.REQUEST_SERVICE_NAME,
GrpcServerConstants.REQUEST_METHOD_NAME))
.setRequestMarshaller(ProtoUtils.marshaller(Payload.getDefaultInstance()))
.setResponseMarshaller(ProtoUtils.marshaller(Payload.getDefaultInstance())).build();

final ServerCallHandler<Payload, Payload> payloadHandler = ServerCalls.asyncUnaryCall(
(request, responseObserver) -> grpcCommonRequestAcceptor.request(request, responseObserver));

final ServerServiceDefinition serviceDefOfUnaryPayload = ServerServiceDefinition.builder(
GrpcServerConstants.REQUEST_SERVICE_NAME).addMethod(unaryPayloadMethod, payloadHandler).build();
final ServerServiceDefinition serviceDefOfUnaryPayload = ServerServiceDefinition
.builder(GrpcServerConstants.REQUEST_SERVICE_NAME).addMethod(unaryPayloadMethod, payloadHandler)
.build();
handlerRegistry.addService(ServerInterceptors.intercept(serviceDefOfUnaryPayload, serverInterceptor));

// bi stream register.
final ServerCallHandler<Payload, Payload> biStreamHandler = ServerCalls.asyncBidiStreamingCall(
(responseObserver) -> grpcBiStreamRequestAcceptor.requestBiStream(responseObserver));

final MethodDescriptor<Payload, Payload> biStreamMethod = MethodDescriptor.<Payload, Payload>newBuilder()
.setType(MethodDescriptor.MethodType.BIDI_STREAMING).setFullMethodName(
MethodDescriptor.generateFullMethodName(GrpcServerConstants.REQUEST_BI_STREAM_SERVICE_NAME,
.setType(MethodDescriptor.MethodType.BIDI_STREAMING).setFullMethodName(MethodDescriptor
.generateFullMethodName(GrpcServerConstants.REQUEST_BI_STREAM_SERVICE_NAME,
GrpcServerConstants.REQUEST_BI_STREAM_METHOD_NAME))
.setRequestMarshaller(ProtoUtils.marshaller(Payload.newBuilder().build()))
.setResponseMarshaller(ProtoUtils.marshaller(Payload.getDefaultInstance())).build();

final ServerServiceDefinition serviceDefOfBiStream = ServerServiceDefinition.builder(
GrpcServerConstants.REQUEST_BI_STREAM_SERVICE_NAME).addMethod(biStreamMethod, biStreamHandler).build();
final ServerServiceDefinition serviceDefOfBiStream = ServerServiceDefinition
.builder(GrpcServerConstants.REQUEST_BI_STREAM_SERVICE_NAME).addMethod(biStreamMethod, biStreamHandler)
.build();
handlerRegistry.addService(ServerInterceptors.intercept(serviceDefOfBiStream, serverInterceptor));

}
Expand All @@ -185,57 +175,6 @@ public void shutdownServer() {
}
}

private SslContext getSslContextBuilder() {
try {
if (StringUtils.isBlank(rpcServerTlsConfig.getCertChainFile()) || StringUtils.isBlank(
rpcServerTlsConfig.getCertPrivateKey())) {
throw new IllegalArgumentException("Server certChainFile or certPrivateKey must be not null");
}
InputStream certificateChainFile = getInputStream(rpcServerTlsConfig.getCertChainFile(), "certChainFile");
InputStream privateKeyFile = getInputStream(rpcServerTlsConfig.getCertPrivateKey(), "certPrivateKey");
SslContextBuilder sslClientContextBuilder = SslContextBuilder.forServer(certificateChainFile,
privateKeyFile, rpcServerTlsConfig.getCertPrivateKeyPassword());

if (StringUtils.isNotBlank(rpcServerTlsConfig.getProtocols())) {
sslClientContextBuilder.protocols(rpcServerTlsConfig.getProtocols().split(","));
}

if (StringUtils.isNotBlank(rpcServerTlsConfig.getCiphers())) {
sslClientContextBuilder.ciphers(Arrays.asList(rpcServerTlsConfig.getCiphers().split(",")));
}
if (rpcServerTlsConfig.getMutualAuthEnable()) {
// trust all certificate
if (rpcServerTlsConfig.getTrustAll()) {
sslClientContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
} else {
if (StringUtils.isBlank(rpcServerTlsConfig.getTrustCollectionCertFile())) {
throw new IllegalArgumentException(
"enable mutual auth,trustCollectionCertFile must be not null");
}

InputStream clientCert = getInputStream(rpcServerTlsConfig.getTrustCollectionCertFile(),
"trustCollectionCertFile");
sslClientContextBuilder.trustManager(clientCert);
}
sslClientContextBuilder.clientAuth(ClientAuth.REQUIRE);
}
SslContextBuilder configure = GrpcSslContexts.configure(sslClientContextBuilder,
TlsTypeResolve.getSslProvider(rpcServerTlsConfig.getSslProvider()));
return configure.build();
} catch (SSLException e) {
throw new RuntimeException(e);
}
}

private InputStream getInputStream(String path, String config) {
try {
Resource resource = resourceLoader.getResource(path);
return resource.getInputStream();
} catch (IOException e) {
throw new RuntimeException(config + " load fail", e);
}
}

/**
* get rpc executor.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public ThreadPoolExecutor getRpcExecutor() {

@Override
protected long getKeepAliveTime() {
Long property = EnvUtil.getProperty(GrpcServerConstants.GrpcConfig.CLUSTER_KEEP_ALIVE_TIME_PROPERTY,
Long.class);
Long property = EnvUtil
.getProperty(GrpcServerConstants.GrpcConfig.CLUSTER_KEEP_ALIVE_TIME_PROPERTY, Long.class);
if (property != null) {
return property;
}
Expand All @@ -58,8 +58,8 @@ protected long getKeepAliveTime() {

@Override
protected long getKeepAliveTimeout() {
Long property = EnvUtil.getProperty(GrpcServerConstants.GrpcConfig.CLUSTER_KEEP_ALIVE_TIMEOUT_PROPERTY,
Long.class);
Long property = EnvUtil
.getProperty(GrpcServerConstants.GrpcConfig.CLUSTER_KEEP_ALIVE_TIMEOUT_PROPERTY, Long.class);
if (property != null) {
return property;
}
Expand All @@ -68,8 +68,7 @@ protected long getKeepAliveTimeout() {

@Override
protected long getPermitKeepAliveTime() {
Long property = EnvUtil.getProperty(GrpcServerConstants.GrpcConfig.CLUSTER_PERMIT_KEEP_ALIVE_TIME,
Long.class);
Long property = EnvUtil.getProperty(GrpcServerConstants.GrpcConfig.CLUSTER_PERMIT_KEEP_ALIVE_TIME, Long.class);
if (property != null) {
return property;
}
Expand All @@ -78,8 +77,8 @@ protected long getPermitKeepAliveTime() {

@Override
protected int getMaxInboundMessageSize() {
Integer property = EnvUtil.getProperty(GrpcServerConstants.GrpcConfig.CLUSTER_MAX_INBOUND_MSG_SIZE_PROPERTY,
Integer.class);
Integer property = EnvUtil
.getProperty(GrpcServerConstants.GrpcConfig.CLUSTER_MAX_INBOUND_MSG_SIZE_PROPERTY, Integer.class);
if (property != null) {
return property;
}
Expand All @@ -92,5 +91,4 @@ protected int getMaxInboundMessageSize() {
}
return size;
}

}
Loading

0 comments on commit 9069730

Please sign in to comment.