From 60ccf6af26e6574f4a66ba1e0784ef40c91159b5 Mon Sep 17 00:00:00 2001 From: Damian Peckett Date: Wed, 21 Feb 2024 22:20:09 +0100 Subject: [PATCH] refactor: use synchronous io in order to improve perf --- benchmark/main.go | 177 ++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 7 ++ go.sum | 17 +++++ homa.go | 16 ++--- socket.go | 157 +++++++++++----------------------------- socket_test.go | 29 +++----- util.go | 68 ++++++++++++++++-- 7 files changed, 322 insertions(+), 149 deletions(-) create mode 100644 benchmark/main.go diff --git a/benchmark/main.go b/benchmark/main.go new file mode 100644 index 0000000..c1e6b43 --- /dev/null +++ b/benchmark/main.go @@ -0,0 +1,177 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (c) 2019-2024 Stanford University + * Copyright (c) 2024 Damian Peckett + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +package main + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "net" + "runtime" + + "github.com/cheggaaa/pb/v3" + "github.com/dpeckett/go-homa" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" +) + +func main() { + serverAddr, err := net.ResolveUDPAddr("udp", "localhost:0") + if err != nil { + log.Fatalf("could not resolve server address: %v", err) + } + + serverSock, err := homa.NewSocket(serverAddr) + if err != nil { + log.Fatalf("could not create server socket: %v", err) + } + defer serverSock.Close() + + var serverGroup errgroup.Group + + nCPUs := runtime.GOMAXPROCS(0) + + nReceivers := nCPUs / 2 + for i := 0; i < nReceivers; i++ { + serverGroup.Go(func() error { + for { + msg, err := serverSock.Recv() + if err != nil { + return fmt.Errorf("could not receive message: %w", err) + } + + data, err := io.ReadAll(msg) + if err != nil { + return fmt.Errorf("could not read message: %w", err) + } + + if string(data) != "PING" { + return fmt.Errorf("unexpected message: %s", data) + } + + if err := msg.Close(); err != nil { + return fmt.Errorf("could not close message: %w", err) + } + + err = serverSock.Reply(msg.PeerAddr(), msg.ID(), []byte("PONG")) + if err != nil { + return fmt.Errorf("could not send reply: %w", err) + } + } + }) + } + + go func() { + if err := serverGroup.Wait(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Fatalf("error: %v", err) + } + }() + + var senderGroup errgroup.Group + + const ( + totalMessages = 1000000 + maxOutstandingMessages = 100 + ) + sem := semaphore.NewWeighted(int64(maxOutstandingMessages)) + + bar := pb.StartNew(totalMessages) + + nSenders := nCPUs / 2 + for i := 0; i < nSenders; i++ { + senderGroup.Go(func() error { + senderAddr, err := net.ResolveUDPAddr("udp", "localhost:0") + if err != nil { + return fmt.Errorf("could not resolve sender address: %w", err) + } + + senderSock, err := homa.NewSocket(senderAddr) + if err != nil { + return fmt.Errorf("could not create sender socket: %w", err) + } + + var g errgroup.Group + + nMessages := totalMessages / nSenders + + g.Go(func() error { + defer senderSock.Close() + + for i := 0; i < nMessages; i++ { + msg, err := senderSock.Recv() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + + return fmt.Errorf("could not receive reply: %w", err) + } + + data, err := io.ReadAll(msg) + if err != nil { + return fmt.Errorf("could not read reply: %w", err) + } + + if string(data) != "PONG" { + return fmt.Errorf("unexpected reply: %s", data) + } + + if err := msg.Close(); err != nil { + return fmt.Errorf("could not close reply: %w", err) + } + + sem.Release(1) + + bar.Increment() + } + + return nil + }) + + g.Go(func() error { + for i := 0; i < nMessages; i++ { + if err := sem.Acquire(context.Background(), 1); err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) + } + + _, err := senderSock.Send(serverSock.LocalAddr(), []byte("PING"), 0) + if err != nil { + sem.Release(1) + + return fmt.Errorf("could not send message: %w", err) + } + } + + return nil + }) + + return g.Wait() + }) + } + + if err := senderGroup.Wait(); err != nil { + log.Fatalf("error: %v", err) + } + + if err := serverSock.Close(); err != nil { + log.Fatalf("could not close server socket: %v", err) + } +} diff --git a/go.mod b/go.mod index d3c5237..d1b565c 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/dpeckett/go-homa go 1.21.0 require ( + github.com/cheggaaa/pb/v3 v3.1.5 github.com/daedaluz/goioctl v0.0.0-20220112121310-eef48b7845b0 github.com/stretchr/testify v1.8.4 golang.org/x/sync v0.6.0 @@ -10,7 +11,13 @@ require ( ) require ( + github.com/VividCortex/ewma v1.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fatih/color v1.15.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rivo/uniseg v0.2.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f5d0584..bf7dd00 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,30 @@ +github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1ow= +github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= +github.com/cheggaaa/pb/v3 v3.1.5 h1:QuuUzeM2WsAqG2gMqtzaWithDJv0i+i6UlnwSCI4QLk= +github.com/cheggaaa/pb/v3 v3.1.5/go.mod h1:CrxkeghYTXi1lQBEI7jSn+3svI3cuc19haAj6jM60XI= github.com/daedaluz/goioctl v0.0.0-20220112121310-eef48b7845b0 h1:tLanypj7anfk592ujzQ4RrZrvy4KrQQ+ozdZ2usuloM= github.com/daedaluz/goioctl v0.0.0-20220112121310-eef48b7845b0/go.mod h1:NbO2vzbi679q0yLroSQd/T4/NnTvpQgHWs7tgZCjiv8= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= +github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/homa.go b/homa.go index bf6eac4..ae2fb0e 100644 --- a/homa.go +++ b/homa.go @@ -58,10 +58,10 @@ type SendmsgArgs struct { CompletionCookie uint64 } -// SendmsgArgsFromBytes deserializes a sendmsgArgs from a byte slice. +// sendmsgArgsFromBytes deserializes a sendmsgArgs from a byte slice. // We implement our own deserialization method here because the Go doesn't support packed structs // and binary.Read uses reflection, which is very slow. -func SendmsgArgsFromBytes(buf []byte) SendmsgArgs { +func sendmsgArgsFromBytes(buf []byte) SendmsgArgs { var args SendmsgArgs args.ID = binary.NativeEndian.Uint64(buf[0:8]) args.CompletionCookie = binary.NativeEndian.Uint64(buf[8:16]) @@ -69,10 +69,10 @@ func SendmsgArgsFromBytes(buf []byte) SendmsgArgs { return args } -// Bytes returns the byte representation of the sendmsgArgs, suitable for passing to the kernel. +// bytes returns the byte representation of the sendmsgArgs, suitable for passing to the kernel. // We implement our own serialization method here because the Go doesn't support packed structs // and binary.Write uses reflection, which is very slow. -func (s *SendmsgArgs) Bytes() []byte { +func (s *SendmsgArgs) bytes() []byte { var buf [16]byte binary.NativeEndian.PutUint64(buf[0:8], s.ID) binary.NativeEndian.PutUint64(buf[8:16], s.CompletionCookie) @@ -121,10 +121,10 @@ type RecvmsgArgs struct { BPageOffsets [HOMA_MAX_BPAGES]uint32 } -// RecvmsgArgsFromBytes deserializes a recvmsgArgs from a byte slice. +// recvmsgArgsFromBytes deserializes a recvmsgArgs from a byte slice. // We implement our own deserialization method here because the Go doesn't support packed structs // and binary.Read uses reflection, which is very slow. -func RecvmsgArgsFromBytes(buf []byte) RecvmsgArgs { +func recvmsgArgsFromBytes(buf []byte) RecvmsgArgs { var args RecvmsgArgs args.ID = binary.NativeEndian.Uint64(buf[0:8]) args.CompletionCookie = binary.NativeEndian.Uint64(buf[8:16]) @@ -138,10 +138,10 @@ func RecvmsgArgsFromBytes(buf []byte) RecvmsgArgs { return args } -// Bytes returns the byte representation of the recvmsgArgs, suitable for passing to the kernel. +// bytes returns the byte representation of the recvmsgArgs, suitable for passing to the kernel. // We implement our own serialization method here because the Go doesn't support packed structs // and binary.Write uses reflection, which is very slow. -func (r *RecvmsgArgs) Bytes() []byte { +func (r *RecvmsgArgs) bytes() []byte { var buf [120]byte binary.NativeEndian.PutUint64(buf[0:8], r.ID) binary.NativeEndian.PutUint64(buf[8:16], r.CompletionCookie) diff --git a/socket.go b/socket.go index 4530d56..a83daf6 100644 --- a/socket.go +++ b/socket.go @@ -19,7 +19,6 @@ package homa import ( - "context" "errors" "fmt" "net" @@ -30,9 +29,8 @@ import ( ) type Socket struct { - fd int - bp *BufferPool - dataAvailable chan struct{} + fd int + bp *BufferPool } func NewSocket(listenAddr net.Addr) (*Socket, error) { @@ -41,23 +39,21 @@ func NewSocket(listenAddr net.Addr) (*Socket, error) { return nil, fmt.Errorf("could not open homa socket: %w", err) } - var rawListenAddr unix.Sockaddr - { - udpAddr, ok := listenAddr.(*net.UDPAddr) - if !ok { - return nil, fmt.Errorf("unsupported address type") - } - - if ipv4 := udpAddr.IP.To4(); ipv4 != nil { - rawListenAddr = &unix.SockaddrInet4{Port: udpAddr.Port, Addr: [4]byte(ipv4)} - } else if ipv6 := udpAddr.IP.To16(); ipv6 != nil { - rawListenAddr = &unix.SockaddrInet6{Port: udpAddr.Port, Addr: [16]byte(ipv6)} + var rawSockAddr unix.Sockaddr + switch listenAddr := listenAddr.(type) { + case *net.UDPAddr: + if ipv4 := listenAddr.IP.To4(); ipv4 != nil { + rawSockAddr = &unix.SockaddrInet4{Port: listenAddr.Port, Addr: [4]byte(ipv4)} + } else if ipv6 := listenAddr.IP.To16(); ipv6 != nil { + rawSockAddr = &unix.SockaddrInet6{Port: listenAddr.Port, Addr: [16]byte(ipv6)} } else { return nil, fmt.Errorf("unsupported address family") } + default: + return nil, fmt.Errorf("unsupported address type") } - err = unix.Bind(fd, rawListenAddr) + err = unix.Bind(fd, rawSockAddr) if err != nil { _ = unix.Close(fd) @@ -81,47 +77,9 @@ func NewSocket(listenAddr net.Addr) (*Socket, error) { return nil, fmt.Errorf("could not set homa buffer: %w", err) } - epfd, err := unix.EpollCreate1(0) - if err != nil { - _ = unix.Close(fd) - - return nil, fmt.Errorf("could not create epoll: %w", err) - } - - event := &unix.EpollEvent{ - Events: unix.EPOLLIN, - Fd: int32(fd), - } - if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, fd, event); err != nil { - _ = unix.Close(fd) - _ = unix.Close(epfd) - - return nil, fmt.Errorf("could not add epoll event: %w", err) - } - - dataAvailable := make(chan struct{}, 1) - - go func() { - defer close(dataAvailable) - defer unix.Close(epfd) - - events := make([]unix.EpollEvent, 1) - for { - n, err := unix.EpollWait(epfd, events, 10) - if err != nil && !errors.Is(err, unix.EINTR) { - return - } - - if n > 0 { - dataAvailable <- struct{}{} - } - } - }() - return &Socket{ - fd: fd, - bp: bp, - dataAvailable: dataAvailable, + fd: fd, + bp: bp, }, nil } @@ -161,38 +119,33 @@ func (s *Socket) LocalAddr() net.Addr { } // Recv waits for an incoming RPC and returns a message containing the RPC's data. -// The flags argument specifies the type of RPC to receive. It returns a message -// containing the RPC's data, or an error if the operation failed. -func (s *Socket) Recv(ctx context.Context) (*Message, error) { +// It returns a message containing the RPC's data, or an error if the operation failed. +func (s *Socket) Recv() (*Message, error) { args := RecvmsgArgs{ - Flags: HOMA_RECVMSG_REQUEST | HOMA_RECVMSG_RESPONSE | HOMA_RECVMSG_NONBLOCKING, + Flags: HOMA_RECVMSG_REQUEST | HOMA_RECVMSG_RESPONSE, } unusedBuffers := s.bp.getUnusedBuffers() args.NumBPages = uint32(len(unusedBuffers)) copy(args.BPageOffsets[:], unusedBuffers) - argsBytes := args.Bytes() - - length := -1 - for length == -1 { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-s.dataAvailable: - var err error - length, _, _, _, err = unix.Recvmsg(s.fd, nil, argsBytes, 0) - if err != nil { - if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) { - continue - } - - return nil, fmt.Errorf("could not receive message: %w", err) - } + argsBytes := args.bytes() + + recvHdr := unix.Msghdr{ + Control: &argsBytes[0], + Controllen: uint64(len(argsBytes)), + } + + length, err := recvmsg(s.fd, &recvHdr, 0) + if err != nil { + if errors.Is(err, unix.EBADF) { + return nil, net.ErrClosed } + + return nil, fmt.Errorf("could not receive message: %w", err) } - return NewMessage(s.bp, RecvmsgArgsFromBytes(argsBytes), int64(length)), nil + return NewMessage(s.bp, recvmsgArgsFromBytes(argsBytes), int64(length)), nil } // Send initiates an RPC by sending a request message to a server. @@ -203,16 +156,16 @@ func (s *Socket) Send(dstAddr net.Addr, message []byte, completionCookie uint64) CompletionCookie: completionCookie, } - argsBytes := args.Bytes() + argsBytes := args.bytes() - name, nameLen, err := toRawSockAddr(dstAddr) + rawSockAddr, rawSockAddrLen, err := toRawSockAddr(dstAddr) if err != nil { return 0, fmt.Errorf("could not convert address: %w", err) } hdr := &unix.Msghdr{ - Name: name, - Namelen: uint32(nameLen), + Name: (*byte)(rawSockAddr), + Namelen: rawSockAddrLen, Iov: &unix.Iovec{Base: &message[0], Len: uint64(len(message))}, Iovlen: 1, Control: &argsBytes[0], @@ -227,7 +180,7 @@ func (s *Socket) Send(dstAddr net.Addr, message []byte, completionCookie uint64) return 0, fmt.Errorf("could not send message: %w", err) } - args = SendmsgArgsFromBytes(argsBytes) + args = sendmsgArgsFromBytes(argsBytes) return args.ID, nil } @@ -240,16 +193,16 @@ func (s *Socket) Reply(dstAddr net.Addr, id uint64, message []byte) error { ID: id, } - argsBytes := args.Bytes() + argsBytes := args.bytes() - name, nameLen, err := toRawSockAddr(dstAddr) + rawSockAddr, rawSockAddrLen, err := toRawSockAddr(dstAddr) if err != nil { return fmt.Errorf("could not convert address: %w", err) } hdr := &unix.Msghdr{ - Name: name, - Namelen: uint32(nameLen), + Name: (*byte)(rawSockAddr), + Namelen: rawSockAddrLen, Iov: &unix.Iovec{Base: &message[0], Len: uint64(len(message))}, Iovlen: 1, Control: &argsBytes[0], @@ -279,33 +232,3 @@ func (s *Socket) Abort(id uint64, errorCode int32) error { return ioctl.Ioctl(uintptr(s.fd), HOMAIOCABORT, uintptr(unsafe.Pointer(&args))) } - -func setsockoptHomaBuf(fd int, args SetBufArgs) error { - _, _, errno := unix.Syscall6(unix.SYS_SETSOCKOPT, uintptr(fd), IPPROTO_HOMA, SO_HOMA_SET_BUF, uintptr(unsafe.Pointer(&args)), unsafe.Sizeof(args), 0) - if errno != 0 { - return errno - } - - return nil -} - -func toRawSockAddr(addr net.Addr) (*byte, int64, error) { - switch addr := addr.(type) { - case *net.UDPAddr: - if ipv4 := addr.IP.To4(); ipv4 != nil { - return (*byte)(unsafe.Pointer(&unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Port: htons(uint16(addr.Port)), - Addr: [4]byte(ipv4), - })), 16, nil - } - - return (*byte)(unsafe.Pointer(&unix.RawSockaddrInet6{ - Family: unix.AF_INET6, - Port: htons(uint16(addr.Port)), - Addr: [16]byte(addr.IP), - })), 28, nil - default: - return nil, 0, fmt.Errorf("unsupported address type: %T", addr) - } -} diff --git a/socket_test.go b/socket_test.go index 1f5bca0..56a90ec 100644 --- a/socket_test.go +++ b/socket_test.go @@ -20,7 +20,6 @@ package homa_test import ( "bytes" - "context" "crypto/rand" "crypto/sha256" "errors" @@ -47,17 +46,19 @@ func TestHomaRPC(t *testing.T) { serverSock, err := homa.NewSocket(serverAddr) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) + clientAddr, err := net.ResolveUDPAddr("udp", "localhost:0") + require.NoError(t, err) - g, ctx := errgroup.WithContext(ctx) + clientSock, err := homa.NewSocket(clientAddr) + require.NoError(t, err) - g.Go(func() error { - defer serverSock.Close() + var g errgroup.Group + g.Go(func() error { for { - msg, err := serverSock.Recv(ctx) + msg, err := serverSock.Recv() if err != nil { - if errors.Is(err, context.Canceled) { + if errors.Is(err, net.ErrClosed) { return nil } @@ -80,18 +81,8 @@ func TestHomaRPC(t *testing.T) { }) g.Go(func() error { - defer cancel() - - clientAddr, err := net.ResolveUDPAddr("udp", "localhost:0") - if err != nil { - return err - } - - clientSock, err := homa.NewSocket(clientAddr) - if err != nil { - return err - } defer clientSock.Close() + defer serverSock.Close() for i := 0; i < 100; i++ { size, err := rand.Int(rand.Reader, big.NewInt(homa.HOMA_MAX_MESSAGE_LENGTH-1)) @@ -115,7 +106,7 @@ func TestHomaRPC(t *testing.T) { return fmt.Errorf("expected message id > 0, got %d", id) } - msg, err := clientSock.Recv(ctx) + msg, err := clientSock.Recv() if err != nil { return err } diff --git a/util.go b/util.go index 50d4127..42fd497 100644 --- a/util.go +++ b/util.go @@ -20,6 +20,9 @@ package homa import ( "encoding/binary" + "errors" + "fmt" + "net" "unsafe" "golang.org/x/sys/unix" @@ -37,14 +40,69 @@ func ntohs(net uint16) uint16 { return binary.BigEndian.Uint16(unsafe.Slice((*byte)(unsafe.Pointer(&net)), 2)) } +// recvmsg is a wrapper around the recvmsg system call, this is not natively exposed to go but +// we need to make some tweaks to the msghdr struct so we'll define our own. +func recvmsg(s int, msg *unix.Msghdr, flags int) (n int, err error) { + err = unix.EINTR + for errors.Is(err, unix.EINTR) { + r0, _, e1 := unix.Syscall(unix.SYS_RECVMSG, uintptr(s), uintptr(unsafe.Pointer(msg)), uintptr(flags)) + n = int(r0) + if e1 != 0 { + err = unix.Errno(e1) + } else { + err = nil + } + } + return +} + // sendmsg is a wrapper around the sendmsg system call, this is not natively exposed to go but // we need to make some tweaks to the msghdr struct so we'll define our own. func sendmsg(s int, msg *unix.Msghdr, flags int) (n int, err error) { - r0, _, e1 := unix.Syscall(unix.SYS_SENDMSG, uintptr(s), uintptr(unsafe.Pointer(msg)), uintptr(flags)) - n = int(r0) - if e1 != 0 { - err = unix.Errno(e1) - return + err = unix.EINTR + for errors.Is(err, unix.EINTR) { + r0, _, e1 := unix.Syscall(unix.SYS_SENDMSG, uintptr(s), uintptr(unsafe.Pointer(msg)), uintptr(flags)) + n = int(r0) + if e1 != 0 { + err = unix.Errno(e1) + } else { + err = nil + } } return } + +func setsockoptHomaBuf(fd int, args SetBufArgs) (err error) { + err = unix.EINTR + for errors.Is(err, unix.EINTR) { + _, _, e1 := unix.Syscall6(unix.SYS_SETSOCKOPT, uintptr(fd), IPPROTO_HOMA, SO_HOMA_SET_BUF, uintptr(unsafe.Pointer(&args)), unsafe.Sizeof(args), 0) + if e1 != 0 { + err = unix.Errno(e1) + } else { + err = nil + } + } + return +} + +// toRawSockAddr converts a net.Addr to a raw socket address. +func toRawSockAddr(addr net.Addr) (unsafe.Pointer, uint32, error) { + switch addr := addr.(type) { + case *net.UDPAddr: + if ipv4 := addr.IP.To4(); ipv4 != nil { + return unsafe.Pointer(&unix.RawSockaddrInet4{ + Family: unix.AF_INET, + Port: htons(uint16(addr.Port)), + Addr: [4]byte(ipv4), + }), unix.SizeofSockaddrInet4, nil + } + + return unsafe.Pointer(&unix.RawSockaddrInet6{ + Family: unix.AF_INET6, + Port: htons(uint16(addr.Port)), + Addr: [16]byte(addr.IP), + }), unix.SizeofSockaddrInet6, nil + default: + return nil, 0, fmt.Errorf("unsupported address type: %T", addr) + } +}