Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelsavara committed Aug 29, 2024
1 parent ab8016c commit 7330d07
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, IL
int startIndex = bufferIndex, startOffset = offset;

int maxBuffers = buffers.Count - startIndex;
#if TARGET_WASI // WASI doesn't have iovecs and recvmsg in preview2
maxBuffers = Math.Max(maxBuffers, 1);
#endif
bool allocOnStack = maxBuffers <= IovStackThreshold;
Span<GCHandle> handles = allocOnStack ? stackalloc GCHandle[IovStackThreshold] : new GCHandle[maxBuffers];
Span<Interop.Sys.IOVector> iovecs = allocOnStack ? stackalloc Interop.Sys.IOVector[IovStackThreshold] : new Interop.Sys.IOVector[maxBuffers];
Expand Down Expand Up @@ -376,6 +379,9 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags,
Debug.Assert(socket.IsSocket);

int maxBuffers = buffers.Count;
#if TARGET_WASI // WASI doesn't have iovecs and recvmsg in preview2
maxBuffers = Math.Max(maxBuffers, 1);
#endif
bool allocOnStack = maxBuffers <= IovStackThreshold;

// When there are many buffers, reduce the number of pinned buffers based on available bytes.
Expand Down Expand Up @@ -532,6 +538,10 @@ private static unsafe int SysReceiveMessageFrom(
Debug.Assert(socket.IsSocket);

int buffersCount = buffers.Count;
#if TARGET_WASI // WASI doesn't have iovecs and sendmsg in preview2
buffersCount = Math.Max(buffersCount, 1);
#endif

bool allocOnStack = buffersCount <= IovStackThreshold;
Span<GCHandle> handles = allocOnStack ? stackalloc GCHandle[IovStackThreshold] : new GCHandle[buffersCount];
Span<Interop.Sys.IOVector> iovecs = allocOnStack ? stackalloc Interop.Sys.IOVector[IovStackThreshold] : new Interop.Sys.IOVector[buffersCount];
Expand All @@ -554,17 +564,22 @@ private static unsafe int SysReceiveMessageFrom(
fixed (byte* sockAddr = socketAddress)
fixed (Interop.Sys.IOVector* iov = iovecs)
{
#if !TARGET_WASI // WASI doesn't have msg_control and sendmsg in preview2

int cmsgBufferLen = Interop.Sys.GetControlMessageBufferSize(Convert.ToInt32(isIPv4), Convert.ToInt32(isIPv6));
byte* cmsgBuffer = stackalloc byte[cmsgBufferLen];
#endif

var messageHeader = new Interop.Sys.MessageHeader
{
SocketAddress = sockAddr,
SocketAddressLen = socketAddress.Length,
IOVectors = iov,
IOVectorCount = iovCount,
#if !TARGET_WASI // WASI doesn't have msg_control and sendmsg in preview2
ControlBuffer = cmsgBuffer,
ControlBufferLen = cmsgBufferLen
#endif
};

long received = 0;
Expand Down
28 changes: 15 additions & 13 deletions src/native/libs/System.Native/pal_networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,7 @@ SystemNative_SetIPv6Address(uint8_t* socketAddress, int32_t socketAddressLen, ui
return Error_SUCCESS;
}

#if !defined(TARGET_WASI)
static int8_t IsStreamSocket(int socket)
{
int type;
Expand All @@ -939,6 +940,7 @@ static void ConvertMessageHeaderToMsghdr(struct msghdr* header, const MessageHea
header->msg_controllen = (uint32_t)messageHeader->ControlBufferLen;
header->msg_flags = 0;
}
#endif // !TARGET_WASI

int32_t SystemNative_GetControlMessageBufferSize(int32_t isIPv4, int32_t isIPv6)
{
Expand Down Expand Up @@ -1392,11 +1394,9 @@ static int8_t ConvertSocketFlagsPalToPlatform(int32_t palFlags, int* platformFla
return true;
}

#if !defined(TARGET_WASI)
static int32_t ConvertSocketFlagsPlatformToPal(int platformFlags)
{
#if defined(TARGET_WASI)
return 0;
#else // TARGET_WASI
const int SupportedFlagsMask = MSG_OOB | MSG_DONTROUTE | MSG_TRUNC | MSG_CTRUNC;

platformFlags &= SupportedFlagsMask;
Expand All @@ -1405,8 +1405,8 @@ static int32_t ConvertSocketFlagsPlatformToPal(int platformFlags)
((platformFlags & MSG_DONTROUTE) == 0 ? 0 : SocketFlags_MSG_DONTROUTE) |
((platformFlags & MSG_TRUNC) == 0 ? 0 : SocketFlags_MSG_TRUNC) |
((platformFlags & MSG_CTRUNC) == 0 ? 0 : SocketFlags_MSG_CTRUNC);
#endif // !TARGET_WASI
}
#endif // !TARGET_WASI

int32_t SystemNative_Receive(intptr_t socket, void* buffer, int32_t bufferLen, int32_t flags, int32_t* received)
{
Expand Down Expand Up @@ -1506,18 +1506,19 @@ int32_t SystemNative_ReceiveMessage(intptr_t socket, MessageHeader* messageHeade
return Error_ENOTSUP;
}

ssize_t res;
#if !defined(TARGET_WASI)
struct msghdr header;
ConvertMessageHeaderToMsghdr(&header, messageHeader, fd);

ssize_t res;
#if defined(TARGET_WASI)
// ssize_t recvfrom (int sockfd, void *buf, size_t len, int flags, struct sockaddr *src_addr, socklen_t *addrlen);
// TODO
while ((res = recvfrom(fd, NULL, 0, socketFlags, NULL, 0)) < 0 && errno == EINTR);
#else
while ((res = recvmsg(fd, &header, socketFlags)) < 0 && errno == EINTR);
#else
// we will only use 0th buffer
struct iovec* msg_iov = (struct iovec*)messageHeader->IOVectors;
while ((res = recvfrom(fd, msg_iov[0].iov_base, msg_iov[0].iov_len, socketFlags, (sockaddr *)messageHeader->SocketAddress, (socklen_t*) &(messageHeader->SocketAddressLen))) < 0 && errno == EINTR);
#endif // !TARGET_WASI

#if !defined(TARGET_WASI)
assert(header.msg_name == messageHeader->SocketAddress); // should still be the same location as set in ConvertMessageHeaderToMsghdr
assert(header.msg_control == messageHeader->ControlBuffer);

Expand All @@ -1528,6 +1529,7 @@ int32_t SystemNative_ReceiveMessage(intptr_t socket, MessageHeader* messageHeade
messageHeader->ControlBufferLen = Min((int32_t)header.msg_controllen, messageHeader->ControlBufferLen);

messageHeader->Flags = ConvertSocketFlagsPlatformToPal(header.msg_flags);
#endif // !TARGET_WASI

if (res != -1)
{
Expand Down Expand Up @@ -1605,9 +1607,9 @@ int32_t SystemNative_SendMessage(intptr_t socket, MessageHeader* messageHeader,
while ((res = sendmsg(fd, &header, socketFlags)) < 0 && errno == EINTR);
#endif
#else // TARGET_WASI
// ssize_t sendto(int sockfd, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen);
// TODO
while ((res = sendto(fd, NULL, 0, socketFlags, NULL, 0)) < 0 && errno == EINTR);
// we will only use 0th buffer
struct iovec* msg_iov = (struct iovec*)messageHeader->IOVectors;
while ((res = sendto(fd, msg_iov[0].iov_base, msg_iov[0].iov_len, socketFlags, (sockaddr *)messageHeader->SocketAddress, (socklen_t)messageHeader->SocketAddressLen)) < 0 && errno == EINTR);
#endif // !TARGET_WASI

if (res != -1)
Expand Down

0 comments on commit 7330d07

Please sign in to comment.