Skip to content

Commit

Permalink
WebSockets Next: fix OnOpen callback that returns Buffer/byte[]
Browse files Browse the repository at this point in the history
- this callback should send a binary message
  • Loading branch information
mkouba committed Jun 20, 2024
1 parent ed77ee2 commit 6258848
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,8 @@ private static ResultHandle uniOnFailureDoOnError(ResultHandle endpointThis, Byt
private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator method, Callback callback,
GlobalErrorHandlersBuildItem globalErrorHandlers, WebSocketEndpointBuildItem endpoint,
ResultHandle value) {
if (callback.acceptsBinaryMessage()) {
if (callback.acceptsBinaryMessage()
|| isOnOpenWithBinaryReturnType(callback)) {
// ----------------------
// === Binary message ===
// ----------------------
Expand Down Expand Up @@ -1119,7 +1120,7 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre
value,
fun.getInstance());
} else {
// return sendBinary(buffer,broadcast);
// return sendBinary(encodeBuffer(b),broadcast);
ResultHandle buffer = encodeBuffer(method, callback.returnType(), value, endpointThis, callback);
return method.invokeVirtualMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class,
"sendBinary", Uni.class, Buffer.class, boolean.class), endpointThis, buffer,
Expand Down Expand Up @@ -1407,4 +1408,16 @@ static boolean isByteArray(Type type) {
static String methodToString(MethodInfo method) {
return method.declaringClass().name() + "#" + method.name() + "()";
}

private static boolean isOnOpenWithBinaryReturnType(Callback callback) {
if (callback.isOnOpen()) {
Type returnType = callback.returnType();
if (callback.isReturnTypeUni() || callback.isReturnTypeMulti()) {
returnType = callback.returnType().asParameterizedType().arguments().get(0);
}
return WebSocketDotNames.BUFFER.equals(returnType.name())
|| (returnType.kind() == Kind.ARRAY && PrimitiveType.BYTE.equals(returnType.asArrayType().constituent()));
}
return false;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package io.quarkus.websockets.next.test.onopenreturntypes;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.net.URI;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.quarkus.websockets.next.test.utils.WSClient.ReceiverMode;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;

public class OnOpenReturnTypesTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(EndpointText.class, EndpointBinary.class, WSClient.class);
});

@Inject
Vertx vertx;

@TestHTTPResource("end-text")
URI endText;

@TestHTTPResource("end-binary")
URI endBinary;

@Test
void testReturnTypes() throws Exception {
try (WSClient textClient = WSClient.create(vertx, ReceiverMode.TEXT).connect(endText)) {
textClient.waitForMessages(1);
assertEquals("/end-text", textClient.getMessages().get(0).toString());
}
try (WSClient binaryClient = WSClient.create(vertx, ReceiverMode.BINARY).connect(endBinary)) {
binaryClient.waitForMessages(1);
assertEquals("/end-binary", binaryClient.getMessages().get(0).toString());
}
}

@WebSocket(path = "/end-text")
public static class EndpointText {

@OnOpen
String open(WebSocketConnection connection) {
return connection.handshakeRequest().path();
}

}

@WebSocket(path = "/end-binary")
public static class EndpointBinary {

@OnOpen
Buffer open(WebSocketConnection connection) {
return Buffer.buffer(connection.handshakeRequest().path());
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,26 @@

public class WSClient implements AutoCloseable {

public static WSClient create(Vertx vertx) {
return new WSClient(vertx);
}

public static WSClient create(Vertx vertx, ReceiverMode mode) {
return new WSClient(vertx, mode);
}

private final WebSocketClient client;
private AtomicReference<WebSocket> socket = new AtomicReference<>();
private List<Buffer> messages = new CopyOnWriteArrayList<>();
private final ReceiverMode mode;

public WSClient(Vertx vertx) {
public WSClient(Vertx vertx, ReceiverMode mode) {
this.client = vertx.createWebSocketClient();
this.mode = mode;
}

public static WSClient create(Vertx vertx) {
return new WSClient(vertx);
public WSClient(Vertx vertx) {
this(vertx, ReceiverMode.ALL);
}

public static URI toWS(URI uri, String path) {
Expand All @@ -52,7 +62,19 @@ public WSClient connect(WebSocketConnectOptions options, URI url) {
uri.append("?").append(url.getQuery());
}
ClientWebSocket webSocket = client.webSocket();
webSocket.handler(b -> messages.add(b));
switch (mode) {
case ALL:
webSocket.handler(b -> messages.add(b));
break;
case BINARY:
webSocket.binaryMessageHandler(b -> messages.add(b));
break;
case TEXT:
webSocket.textMessageHandler(b -> messages.add(Buffer.buffer(b)));
break;
default:
throw new IllegalStateException();
}
await(webSocket.connect(options.setPort(url.getPort()).setHost(url.getHost()).setURI(uri.toString())));
var prev = socket.getAndSet(webSocket);
if (prev != null) {
Expand Down Expand Up @@ -135,4 +157,10 @@ public void close() {
disconnect();
}

public enum ReceiverMode {
BINARY,
TEXT,
ALL
}

}

0 comments on commit 6258848

Please sign in to comment.