From 3f1461899145de7c07611c7fb200726d937e1021 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Thu, 22 Feb 2024 10:53:43 +0800 Subject: [PATCH] Refine TCPMux memory usage Reduce read buf of first stun message. Add timeout to read first message to clean interrupted connection earlier. Add alive duration for gather to access connection created from stun bind, avoid connection leak from malicious client. --- tcp_mux.go | 63 ++++++++++++++++---- tcp_mux_test.go | 141 +++++++++++++++++++++++++++++++++++++++++++++ tcp_packet_conn.go | 28 +++++++-- 3 files changed, 216 insertions(+), 16 deletions(-) diff --git a/tcp_mux.go b/tcp_mux.go index df169be2..ff3afed0 100644 --- a/tcp_mux.go +++ b/tcp_mux.go @@ -10,6 +10,7 @@ import ( "net" "strings" "sync" + "time" "github.com/pion/logging" "github.com/pion/stun" @@ -52,6 +53,16 @@ type TCPMuxParams struct { // if the write buffer is full, the subsequent write packet will be dropped until it has enough space. // a default 4MB is recommended. WriteBufferSize int + + // A new established connection will be removed if the first STUN binding request is not received within this timeout, + // avoiding the client with bad network or attacker to create a lot of empty connections. + // Default 30s timeout will be used if not set. + FirstStunBindTimeout time.Duration + + // TCPMux will create connection from STUN binding request with an unknown username, if + // the connection is not used in the timeout, it will be removed to avoid resource leak / attack. + // Default 30s timeout will be used if not set. + AliveDurationForConnFromStun time.Duration } // NewTCPMuxDefault creates a new instance of TCPMuxDefault. 
@@ -60,6 +71,14 @@ func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault { params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") } + if params.FirstStunBindTimeout == 0 { + params.FirstStunBindTimeout = 30 * time.Second + } + + if params.AliveDurationForConnFromStun == 0 { + params.AliveDurationForConnFromStun = 30 * time.Second + } + m := &TCPMuxDefault{ params: &params, @@ -110,13 +129,14 @@ func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) } if conn, ok := m.getConn(ufrag, isIPv6, local); ok { + conn.ClearAliveTimer() return conn, nil } - return m.createConn(ufrag, isIPv6, local) + return m.createConn(ufrag, isIPv6, local, false) } -func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP) (*tcpPacketConn, error) { +func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, fromStun bool) (*tcpPacketConn, error) { addr, ok := m.LocalAddr().(*net.TCPAddr) if !ok { return nil, ErrGetTransportAddress @@ -124,11 +144,17 @@ func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP) (*tc localAddr := *addr localAddr.IP = local + var alive time.Duration + if fromStun { + alive = m.params.AliveDurationForConnFromStun + } + conn := newTCPPacketConn(tcpPacketParams{ - ReadBuffer: m.params.ReadBufferSize, - WriteBuffer: m.params.WriteBufferSize, - LocalAddr: &localAddr, - Logger: m.params.Logger, + ReadBuffer: m.params.ReadBufferSize, + WriteBuffer: m.params.WriteBufferSize, + LocalAddr: &localAddr, + Logger: m.params.Logger, + AliveDuration: alive, }) var conns map[ipAddr]*tcpPacketConn @@ -163,13 +189,26 @@ func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) { } func (m *TCPMuxDefault) handleConn(conn net.Conn) { - buf := make([]byte, receiveMTU) + buf := make([]byte, 512) + if m.params.FirstStunBindTimeout > 0 { + if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil { + m.params.Logger.Warnf("Failed to set read deadline for first 
STUN message: %s to %s , err: %s", conn.RemoteAddr(), conn.LocalAddr(), err) + } + } n, err := readStreamingPacket(conn, buf) if err != nil { - m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr().String(), err) + if errors.Is(err, io.ErrShortBuffer) { + m.params.Logger.Warnf("Buffer too small for first packet from %s: %s", conn.RemoteAddr(), err) + } else { + m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err) + } + m.closeAndLogError(conn) return } + if err = conn.SetReadDeadline(time.Time{}); err != nil { + m.params.Logger.Warnf("Failed to reset read deadline from %s: %s", conn.RemoteAddr(), err) + } buf = buf[:n] @@ -204,9 +243,6 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) { ufrag := strings.Split(string(attr), ":")[0] m.params.Logger.Debugf("Ufrag: %s", ufrag) - m.mu.Lock() - defer m.mu.Unlock() - host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) if err != nil { m.closeAndLogError(conn) @@ -222,15 +258,18 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) { m.params.Logger.Warnf("Failed to get local tcp address in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) return } + m.mu.Lock() packetConn, ok := m.getConn(ufrag, isIPv6, localAddr.IP) if !ok { - packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP) + packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP, true) if err != nil { + m.mu.Unlock() m.closeAndLogError(conn) m.params.Logger.Warnf("Failed to create packetConn for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) return } } + m.mu.Unlock() if err := packetConn.AddConn(conn, buf); err != nil { m.closeAndLogError(conn) diff --git a/tcp_mux_test.go b/tcp_mux_test.go index 75116871..cd3b6333 100644 --- a/tcp_mux_test.go +++ b/tcp_mux_test.go @@ -6,7 +6,9 @@ package ice import ( "io" "net" + "os" "testing" + "time" "github.com/pion/logging" "github.com/pion/stun" @@ -108,6 +110,10 @@ func 
TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) { ReadBufferSize: 20, }) + defer func() { + _ = tcpMux.Close() + }() + _, err = tcpMux.GetConnByUfrag("test", false, listener.Addr().(*net.TCPAddr).IP) require.NoError(t, err, "error getting conn by ufrag") @@ -117,3 +123,138 @@ func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) { assert.Nil(t, conn, "should receive nil because mux is closed") assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed") } + +func TestTCPMux_FirstPacketTimeout(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + + loggerFactory := logging.NewDefaultLoggerFactory() + + listener, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: net.IP{127, 0, 0, 1}, + Port: 0, + }) + require.NoError(t, err, "error starting listener") + defer func() { + _ = listener.Close() + }() + + tcpMux := NewTCPMuxDefault(TCPMuxParams{ + Listener: listener, + Logger: loggerFactory.NewLogger("ice"), + ReadBufferSize: 20, + FirstStunBindTimeout: time.Second, + }) + + require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") + + conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) + require.NoError(t, err, "error dialing test TCP connection") + defer func() { + _ = conn.Close() + }() + + // Don't send any data, the mux should close the connection after the timeout + time.Sleep(1500 * time.Millisecond) + require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1) + _, err = conn.Read(buf) + require.ErrorIs(t, err, io.EOF) +} + +func TestTCPMux_NoLeakForConnectionFromStun(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + + loggerFactory := logging.NewDefaultLoggerFactory() + + listener, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: net.IP{127, 0, 0, 1}, + Port: 0, + }) + require.NoError(t, err, "error starting listener") + defer func() { + _ = listener.Close() + }() + + tcpMux := NewTCPMuxDefault(TCPMuxParams{ + 
Listener: listener, + Logger: loggerFactory.NewLogger("ice"), + ReadBufferSize: 20, + AliveDurationForConnFromStun: time.Second, + }) + + defer func() { + _ = tcpMux.Close() + }() + + require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") + + t.Run("close connection from stun msg after timeout", func(t *testing.T) { + conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) + require.NoError(t, err, "error dialing test TCP connection") + defer func() { + _ = conn.Close() + }() + + msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, + stun.NewUsername("myufrag:otherufrag"), + stun.NewShortTermIntegrity("myufrag"), + stun.Fingerprint, + ) + require.NoError(t, err, "error building STUN packet") + msg.Encode() + + _, err = writeStreamingPacket(conn, msg.Raw) + require.NoError(t, err, "error writing TCP STUN packet") + + time.Sleep(1500 * time.Millisecond) + require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1) + _, err = conn.Read(buf) + require.ErrorIs(t, err, io.EOF) + }) + + t.Run("connection keep alive if access by user", func(t *testing.T) { + conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) + require.NoError(t, err, "error dialing test TCP connection") + defer func() { + _ = conn.Close() + }() + + msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, + stun.NewUsername("myufrag2:otherufrag2"), + stun.NewShortTermIntegrity("myufrag2"), + stun.Fingerprint, + ) + require.NoError(t, err, "error building STUN packet") + msg.Encode() + + n, err := writeStreamingPacket(conn, msg.Raw) + require.NoError(t, err, "error writing TCP STUN packet") + + // wait for the connection to be created + time.Sleep(100 * time.Millisecond) + + pktConn, err := tcpMux.GetConnByUfrag("myufrag2", false, listener.Addr().(*net.TCPAddr).IP) + require.NoError(t, err, "error retrieving muxed connection for ufrag") + defer func() { + _ = pktConn.Close() + }() + + time.Sleep(1500 * 
time.Millisecond) + + // timeout, not closed + buf := make([]byte, 1024) + require.NoError(t, conn.SetReadDeadline(time.Now().Add(100*time.Millisecond))) + _, err = conn.Read(buf) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + + recv := make([]byte, n) + n2, rAddr, err := pktConn.ReadFrom(recv) + require.NoError(t, err, "error receiving data") + assert.Equal(t, conn.LocalAddr(), rAddr, "remote tcp address mismatch") + assert.Equal(t, n, n2, "received byte size mismatch") + assert.Equal(t, msg.Raw, recv, "received bytes mismatch") + }) +} diff --git a/tcp_packet_conn.go b/tcp_packet_conn.go index 8329cba8..60ea2c80 100644 --- a/tcp_packet_conn.go +++ b/tcp_packet_conn.go @@ -85,6 +85,7 @@ type tcpPacketConn struct { wg sync.WaitGroup closedChan chan struct{} closeOnce sync.Once + aliveTimer *time.Timer } type streamingPacket struct { @@ -94,10 +95,11 @@ type streamingPacket struct { } type tcpPacketParams struct { - ReadBuffer int - LocalAddr net.Addr - Logger logging.LeveledLogger - WriteBuffer int + ReadBuffer int + LocalAddr net.Addr + Logger logging.LeveledLogger + WriteBuffer int + AliveDuration time.Duration } func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn { @@ -110,9 +112,24 @@ func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn { closedChan: make(chan struct{}), } + if params.AliveDuration > 0 { + p.aliveTimer = time.AfterFunc(params.AliveDuration, func() { + p.params.Logger.Warn("close tcp packet conn by alive timeout") + _ = p.Close() + }) + } + return p } +func (t *tcpPacketConn) ClearAliveTimer() { + t.mu.Lock() + if t.aliveTimer != nil { + t.aliveTimer.Stop() + } + t.mu.Unlock() +} + func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error { t.params.Logger.Infof("Added connection: %s remote %s to local %s", conn.RemoteAddr().Network(), conn.RemoteAddr(), conn.LocalAddr()) @@ -261,6 +278,9 @@ func (t *tcpPacketConn) Close() error { t.closeOnce.Do(func() { close(t.closedChan) shouldCloseRecvChan = true + if 
t.aliveTimer != nil { + t.aliveTimer.Stop() + } }) for _, conn := range t.conns {