diff --git a/client.go b/client.go index c20db6ee9..43a27e94f 100644 --- a/client.go +++ b/client.go @@ -85,7 +85,7 @@ func NewClient(eventHandler EventHandler, opts ...Option) (cli *Client, err erro options.ReadBufferCap = toolkit.CeilToPowerOfTwo(rbc) } el.buffer = make([]byte, options.ReadBufferCap) - el.udpSockets = make(map[int]*conn) + el.clientUDPSockets = make(map[int]*conn) el.connections = make(map[int]*conn) el.eventHandler = eventHandler cli.el = el diff --git a/connection_unix.go b/connection_unix.go index 58d928445..6db2aa26d 100644 --- a/connection_unix.go +++ b/connection_unix.go @@ -34,8 +34,8 @@ import ( type conn struct { fd int // file descriptor - sa unix.Sockaddr // remote socket address ctx interface{} // user-defined context + peer unix.Sockaddr // remote socket address loop *eventloop // connected event-loop codec ICodec // codec for TCP opened bool // connection opened event fired @@ -50,7 +50,7 @@ type conn struct { func newTCPConn(fd int, el *eventloop, sa unix.Sockaddr, codec ICodec, localAddr, remoteAddr net.Addr) (c *conn) { c = &conn{ fd: fd, - sa: sa, + peer: sa, loop: el, codec: codec, localAddr: localAddr, @@ -65,7 +65,7 @@ func newTCPConn(fd int, el *eventloop, sa unix.Sockaddr, codec ICodec, localAddr func (c *conn) releaseTCP() { c.opened = false - c.sa = nil + c.peer = nil c.ctx = nil c.localAddr = nil c.remoteAddr = nil @@ -79,13 +79,13 @@ func (c *conn) releaseTCP() { func newUDPConn(fd int, el *eventloop, localAddr net.Addr, sa unix.Sockaddr, connected bool) (c *conn) { c = &conn{ fd: fd, - sa: sa, + peer: sa, loop: el, localAddr: localAddr, remoteAddr: socket.SockaddrToUDPAddr(sa), } if connected { - c.sa = nil + c.peer = nil } return } @@ -166,10 +166,10 @@ func (c *conn) asyncWrite(itf interface{}) error { func (c *conn) sendTo(buf []byte) error { c.loop.eventHandler.PreWrite(c) defer c.loop.eventHandler.AfterWrite(c, buf) - if c.sa == nil { + if c.peer == nil { return unix.Send(c.fd, buf, 0) } - return unix.Sendto(c.fd, buf, 0, c.sa) + return unix.Sendto(c.fd, buf, 0, c.peer) } // ================================== Non-concurrency-safe API's ================================== diff --git a/eventloop_unix.go b/eventloop_unix.go index e79ffb224..3885c6fd2 100644 --- a/eventloop_unix.go +++ b/eventloop_unix.go @@ -35,15 +35,16 @@ import ( ) type eventloop struct { - ln *listener // listener - idx int // loop index in the server loops list - svr *server // server in loop - poller *netpoll.Poller // epoll or kqueue - buffer []byte // read packet buffer whose capacity is set by user, default value is 64KB - connCount int32 // number of active connections in event-loop - udpSockets map[int]*conn // UDP socket map: fd -> conn - connections map[int]*conn // TCP connection map: fd -> conn - eventHandler EventHandler // user eventHandler + ln *listener // listener + idx int // loop index in the server loops list + svr *server // server in loop + poller *netpoll.Poller // epoll or kqueue + buffer []byte // read packet buffer whose capacity is set by user, default value is 64KB + connCount int32 // number of active connections in event-loop + connections map[int]*conn // TCP connection map: fd -> conn + eventHandler EventHandler // user eventHandler + clientUDPSockets map[int]*conn // client-side UDP socket map: fd -> conn + serverUDPSockets map[unix.Sockaddr]*conn // server-side UDP socket map: Sockaddr -> conn } func (el *eventloop) getLogger() logging.Logger { @@ -58,11 +59,17 @@ func (el *eventloop) loadConn() int32 { return atomic.LoadInt32(&el.connCount) } -func (el *eventloop) closeAllConns() { +func (el *eventloop) closeAllSockets() { // Close loops and all outstanding connections for _, c := range el.connections { _ = el.loopCloseConn(c, nil) } + for _, c := range el.clientUDPSockets { + c.releaseUDP() + } + for _, c := range el.serverUDPSockets { + c.releaseUDP() + } } func (el *eventloop) loopRegister(itf interface{}) error { @@ -76,7 +83,7 @@ func (el *eventloop) loopRegister(itf interface{}) error { c.releaseUDP() return err } - el.udpSockets[c.fd] = c + el.clientUDPSockets[c.fd] = c return nil } if err := el.poller.AddRead(c.pollAttachment); err != nil { @@ -277,11 +284,15 @@ func (el *eventloop) loopReadUDP(fd int) error { return fmt.Errorf("failed to read UDP packet from fd=%d in event-loop(%d), %v", fd, el.idx, os.NewSyscallError("recvfrom", err)) } - c := el.udpSockets[fd] - var oneOff bool - if c == nil { - c = newUDPConn(fd, el, el.ln.lnaddr, sa, false) - oneOff = true + var c *conn + if fd == el.ln.fd { + c = el.serverUDPSockets[sa] + if c == nil { + c = newUDPConn(fd, el, el.ln.lnaddr, sa, false) + el.serverUDPSockets[sa] = c + } + } else { + c = el.clientUDPSockets[fd] } out, action := el.eventHandler.React(el.buffer[:n], c) if out != nil { @@ -290,9 +301,5 @@ func (el *eventloop) loopReadUDP(fd int) error { if action == Shutdown { return gerrors.ErrServerShutdown } - if oneOff { - c.releaseUDP() - } - return nil } diff --git a/reactor_default_bsd.go b/reactor_default_bsd.go index 9f751409c..7b691a1ac 100644 --- a/reactor_default_bsd.go +++ b/reactor_default_bsd.go @@ -48,7 +48,7 @@ func (el *eventloop) activateSubReactor(lockOSThread bool) { } defer func() { - el.closeAllConns() + el.closeAllSockets() el.svr.signalShutdown() }() @@ -81,7 +81,7 @@ func (el *eventloop) loopRun(lockOSThread bool) { } defer func() { - el.closeAllConns() + el.closeAllSockets() el.ln.close() el.svr.signalShutdown() }() diff --git a/reactor_default_linux.go b/reactor_default_linux.go index 160d4e5db..666e232ae 100644 --- a/reactor_default_linux.go +++ b/reactor_default_linux.go @@ -47,7 +47,7 @@ func (el *eventloop) activateSubReactor(lockOSThread bool) { } defer func() { - el.closeAllConns() + el.closeAllSockets() el.svr.signalShutdown() }() @@ -96,7 +96,7 @@ func (el *eventloop) loopRun(lockOSThread bool) { } defer func() { - el.closeAllConns() + el.closeAllSockets() el.ln.close() el.svr.signalShutdown() }() diff --git a/reactor_optimized_bsd.go b/reactor_optimized_bsd.go index f76f7ecd2..95a92a33e 100644 --- a/reactor_optimized_bsd.go +++ b/reactor_optimized_bsd.go @@ -47,7 +47,7 @@ func (el *eventloop) activateSubReactor(lockOSThread bool) { } defer func() { - el.closeAllConns() + el.closeAllSockets() el.svr.signalShutdown() }() @@ -66,7 +66,7 @@ func (el *eventloop) loopRun(lockOSThread bool) { } defer func() { - el.closeAllConns() + el.closeAllSockets() el.ln.close() el.svr.signalShutdown() }() diff --git a/reactor_optimized_linux.go b/reactor_optimized_linux.go index 9547c0ed0..a6cc6950b 100644 --- a/reactor_optimized_linux.go +++ b/reactor_optimized_linux.go @@ -46,7 +46,7 @@ func (el *eventloop) activateSubReactor(lockOSThread bool) { } defer func() { - el.closeAllConns() + el.closeAllSockets() el.svr.signalShutdown() }() @@ -65,7 +65,7 @@ func (el *eventloop) loopRun(lockOSThread bool) { } defer func() { - el.closeAllConns() + el.closeAllSockets() el.ln.close() el.svr.signalShutdown() }() diff --git a/server_unix.go b/server_unix.go index 38ac874c8..1e2107eba 100644 --- a/server_unix.go +++ b/server_unix.go @@ -24,6 +24,8 @@ import ( "sync" "sync/atomic" + "golang.org/x/sys/unix" + "github.com/panjf2000/gnet/internal/netpoll" "github.com/panjf2000/gnet/pkg/errors" ) @@ -109,6 +111,7 @@ func (svr *server) activateEventLoops(numEventLoop int) (err error) { el.svr = svr el.poller = p el.buffer = make([]byte, svr.opts.ReadBufferCap) + el.serverUDPSockets = make(map[unix.Sockaddr]*conn) el.connections = make(map[int]*conn) el.eventHandler = svr.eventHandler if err = el.poller.AddRead(el.ln.packPollAttachment(el.loopAccept)); err != nil { @@ -141,6 +144,7 @@ func (svr *server) activateReactors(numEventLoop int) error { el.svr = svr el.poller = p el.buffer = make([]byte, svr.opts.ReadBufferCap) + el.serverUDPSockets = make(map[unix.Sockaddr]*conn) el.connections = make(map[int]*conn) el.eventHandler = svr.eventHandler svr.lb.register(el)