Skip to content

Commit

Permalink
xds/server: Fix xDS Server leak (#7664) (#7681)
Browse files Browse the repository at this point in the history
  • Loading branch information
arjan-bal authored Sep 30, 2024
1 parent 935f8cb commit 4f6c5f2
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 15 deletions.
1 change: 1 addition & 0 deletions xds/internal/server/conn_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ func (c *connWrapper) Close() error {
if c.rootProvider != nil {
c.rootProvider.Close()
}
c.parent.removeConn(c)
return c.Conn.Close()
}

Expand Down
34 changes: 19 additions & 15 deletions xds/internal/server/listener_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ func NewListenerWrapper(params ListenerWrapperParams) net.Listener {
xdsC: params.XDSClient,
modeCallback: params.ModeCallback,
isUnspecifiedAddr: params.Listener.Addr().(*net.TCPAddr).IP.IsUnspecified(),
conns: make(map[*connWrapper]bool),

mode: connectivity.ServingModeNotServing,
closed: grpcsync.NewEvent(),
Expand Down Expand Up @@ -135,13 +136,13 @@ type listenerWrapper struct {

// mu guards access to the current serving mode and the active filter chain
// manager.
mu sync.RWMutex
mu sync.Mutex
// Current serving mode.
mode connectivity.ServingMode
// Filter chain manager currently serving.
activeFilterChainManager *xdsresource.FilterChainManager
// conns accepted with configuration from activeFilterChainManager.
conns []*connWrapper
conns map[*connWrapper]bool

// These fields are read/written to in the context of xDS updates, which are
// guaranteed to be emitted synchronously from the xDS Client. Thus, they do
Expand Down Expand Up @@ -202,17 +203,14 @@ func (l *listenerWrapper) maybeUpdateFilterChains() {
// gracefully shut down with a grace period of 10 minutes for long-lived
// RPC's, such that clients will reconnect and have the updated
// configuration apply." - A36
var connsToClose []*connWrapper
if l.activeFilterChainManager != nil { // If there is a filter chain manager to clean up.
connsToClose = l.conns
l.conns = nil
}
connsToClose := l.conns
l.conns = make(map[*connWrapper]bool)
l.activeFilterChainManager = l.pendingFilterChainManager
l.pendingFilterChainManager = nil
l.instantiateFilterChainRoutingConfigurationsLocked()
l.mu.Unlock()
go func() {
for _, conn := range connsToClose {
for conn := range connsToClose {
conn.Drain()
}
}()
Expand Down Expand Up @@ -304,15 +302,15 @@ func (l *listenerWrapper) Accept() (net.Conn, error) {
return nil, fmt.Errorf("received connection with non-TCP address (local: %T, remote %T)", conn.LocalAddr(), conn.RemoteAddr())
}

l.mu.RLock()
l.mu.Lock()
if l.mode == connectivity.ServingModeNotServing {
// Close connections as soon as we accept them when we are in
// "not-serving" mode. Since we accept a net.Listener from the user
// in Serve(), we cannot close the listener when we move to
// "not-serving". Closing the connection immediately upon accepting
// is one of the other ways to implement the "not-serving" mode as
// outlined in gRFC A36.
l.mu.RUnlock()
l.mu.Unlock()
conn.Close()
continue
}
Expand All @@ -324,7 +322,7 @@ func (l *listenerWrapper) Accept() (net.Conn, error) {
SourcePort: srcAddr.Port,
})
if err != nil {
l.mu.RUnlock()
l.mu.Unlock()
// When a matching filter chain is not found, we close the
// connection right away, but do not return an error back to
// `grpc.Serve()` from where this Accept() was invoked. Returning an
Expand All @@ -341,12 +339,18 @@ func (l *listenerWrapper) Accept() (net.Conn, error) {
continue
}
cw := &connWrapper{Conn: conn, filterChain: fc, parent: l, urc: fc.UsableRouteConfiguration}
l.conns = append(l.conns, cw)
l.mu.RUnlock()
l.conns[cw] = true
l.mu.Unlock()
return cw, nil
}
}

func (l *listenerWrapper) removeConn(conn *connWrapper) {
l.mu.Lock()
defer l.mu.Unlock()
delete(l.conns, conn)
}

// Close closes the underlying listener. It also cancels the xDS watch
// registered in Serve() and closes any certificate provider instances created
// based on security configuration received in the LDS response.
Expand Down Expand Up @@ -376,9 +380,9 @@ func (l *listenerWrapper) switchModeLocked(newMode connectivity.ServingMode, err
l.mode = newMode
if l.mode == connectivity.ServingModeNotServing {
connsToClose := l.conns
l.conns = nil
l.conns = make(map[*connWrapper]bool)
go func() {
for _, conn := range connsToClose {
for conn := range connsToClose {
conn.Drain()
}
}()
Expand Down
105 changes: 105 additions & 0 deletions xds/internal/server/listener_wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ import (
"fmt"
"net"
"strconv"
"sync"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/testutils/xds/e2e"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
xdsinternal "google.golang.org/grpc/xds/internal"
"google.golang.org/grpc/xds/internal/xdsclient"
"google.golang.org/grpc/xds/internal/xdsclient/xdsresource"
Expand Down Expand Up @@ -151,5 +157,104 @@ func (s) TestListenerWrapper(t *testing.T) {
t.Fatalf("mode change received: %v, want: %v", mode, connectivity.ServingModeNotServing)
}
}
}

type testService struct {
testgrpc.TestServiceServer
}

func (*testService) EmptyCall(context.Context, *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
}

// TestConnsCleanup tests that the listener wrapper clears it's connection
// references when connections close. It sets up a listener wrapper and gRPC
// Server, and connects to the server 100 times and makes an RPC each time, and
// then closes the connection. After these 100 connections Close, the listener
// wrapper should have no more references to any connections.
func (s) TestConnsCleanup(t *testing.T) {
mgmtServer, nodeID, _, _, xdsC := xdsSetupForTests(t)
lis, err := testutils.LocalTCPListener()
if err != nil {
t.Fatalf("Failed to create a local TCP listener: %v", err)
}

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

modeCh := make(chan connectivity.ServingMode, 1)
vm := verifyMode{
modeCh: modeCh,
}

host, port := hostPortFromListener(t, lis)
lisResourceName := fmt.Sprintf(e2e.ServerListenerResourceNameTemplate, net.JoinHostPort(host, strconv.Itoa(int(port))))
params := ListenerWrapperParams{
Listener: lis,
ListenerResourceName: lisResourceName,
XDSClient: xdsC,
ModeCallback: vm.verifyModeCallback,
}
lw := NewListenerWrapper(params)
if lw == nil {
t.Fatalf("NewListenerWrapper(%+v) returned nil", params)
}
defer lw.Close()

resources := e2e.UpdateOptions{
NodeID: nodeID,
Listeners: []*v3listenerpb.Listener{e2e.DefaultServerListener(host, port, e2e.SecurityLevelNone, route1)},
SkipValidation: true,
}
if err := mgmtServer.Update(ctx, resources); err != nil {
t.Fatal(err)
}

// Wait for Listener Mode to go serving.
select {
case <-ctx.Done():
t.Fatalf("timeout waiting for mode change")
case mode := <-modeCh:
if mode != connectivity.ServingModeServing {
t.Fatalf("mode change received: %v, want: %v", mode, connectivity.ServingModeServing)
}
}

server := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
testgrpc.RegisterTestServiceServer(server, &testService{})
wg := sync.WaitGroup{}
go func() {
if err := server.Serve(lw); err != nil {
t.Errorf("failed to serve: %v", err)
}
}()

// Make 100 connections to the server, and make an RPC on each one.
for i := 0; i < 100; i++ {
cc, err := grpc.NewClient(lw.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient failed with err: %v", err)
}
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("client.EmptyCall() failed: %v", err)
}
cc.Close()
}

lisWrapper := lw.(*listenerWrapper)
// Eventually when the server processes the connection shutdowns, the
// listener wrapper should clear its references to the wrapped connections.
lenConns := 1
for ; ctx.Err() == nil && lenConns > 0; <-time.After(time.Millisecond) {
lisWrapper.mu.Lock()
lenConns = len(lisWrapper.conns)
lisWrapper.mu.Unlock()
}
if lenConns > 0 {
t.Fatalf("timeout waiting for lis wrapper conns to clear, size: %v", lenConns)
}

server.Stop()
wg.Wait()
}

0 comments on commit 4f6c5f2

Please sign in to comment.