Skip to content

Commit

Permalink
[api] Visualize sam2 output for Sam2ServingTranslator (#3494)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Oct 4, 2024
1 parent f88f2d6 commit d3c69e2
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDList;
Expand All @@ -23,7 +25,13 @@
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import org.apache.commons.codec.binary.Base64OutputStream;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

/** A {@link Translator} that can serve SAM2 model. */
public class Sam2ServingTranslator implements Translator<Input, Output> {
Expand All @@ -47,11 +55,29 @@ public Batchifier getBatchifier() {

/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
public Output processOutput(TranslatorContext ctx, NDList list) throws IOException {
Output output = new Output();
Sam2Input sam2 = (Sam2Input) ctx.getAttachment("input");
output.addProperty("Content-Type", "application/json");
DetectedObjects obj = translator.processOutput(ctx, list);
output.add(BytesSupplier.wrapAsJson(obj));
DetectedObjects detection = translator.processOutput(ctx, list);
Map<String, Object> ret = new LinkedHashMap<>(); // NOPMD
ret.put("result", detection);
if (sam2.isVisualize()) {
Image img = sam2.getImage();
img.drawBoundingBoxes(detection, 0.8f);
img.drawMarks(sam2.getPoints());
for (Rectangle rect : sam2.getBoxes()) {
img.drawRectangle(rect, 0xff0000, 6);
}
ByteArrayOutputStream os = new ByteArrayOutputStream();
os.write("data:image/png;base64,".getBytes(StandardCharsets.UTF_8));
Base64OutputStream bos = new Base64OutputStream(os, true, 0, null);
img.save(bos, "png");
bos.close();
os.close();
ret.put("image", os.toString(StandardCharsets.UTF_8.name()));
}
output.add(BytesSupplier.wrapAsJson(ret));
return output;
}

Expand All @@ -64,6 +90,7 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
throw new TranslateException("Input data is empty.");
}
Sam2Input sam2 = Sam2Input.fromJson(data.getAsString());
ctx.setAttachment("input", sam2);
return translator.processInput(ctx, sam2);
} catch (IOException e) {
throw new TranslateException("Input is not an Image data type", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ public static final class Sam2Input {
private Image image;
private Point[] points;
private int[] labels;
private boolean visualize;

/**
* Constructs a {@code Sam2Input} instance.
Expand All @@ -214,9 +215,22 @@ public static final class Sam2Input {
* @param labels the labels for the locations (0: background, 1: foreground)
*/
public Sam2Input(Image image, Point[] points, int[] labels) {
this(image, points, labels, false);
}

/**
* Constructs a {@code Sam2Input} instance.
*
* @param image the image
* @param points the locations on the image
* @param labels the labels for the locations (0: background, 1: foreground)
* @param visualize true if output visualized image
*/
public Sam2Input(Image image, Point[] points, int[] labels, boolean visualize) {
this.image = image;
this.points = points;
this.labels = labels;
this.visualize = visualize;
}

/**
Expand All @@ -228,6 +242,15 @@ public Image getImage() {
return image;
}

/**
* Returns {@code true} if output visualized image.
*
* @return {@code true} if output visualized image
*/
public boolean isVisualize() {
return visualize;
}

/**
* Returns the locations.
*
Expand Down Expand Up @@ -288,13 +311,16 @@ float[][] getLabels() {
public static Sam2Input fromJson(String input) throws IOException {
Prompt prompt = JsonUtils.GSON.fromJson(input, Prompt.class);
if (prompt.image == null) {
throw new IllegalArgumentException("Missing url value");
throw new IllegalArgumentException("Missing image value");
}
if (prompt.prompt == null || prompt.prompt.length == 0) {
throw new IllegalArgumentException("Missing prompt value");
}
Image image = ImageFactory.getInstance().fromUrl(prompt.image);
Builder builder = builder(image);
if (prompt.visualize) {
builder.visualize();
}
for (Location location : prompt.prompt) {
int[] data = location.data;
if ("point".equals(location.type)) {
Expand Down Expand Up @@ -322,6 +348,7 @@ public static final class Builder {
private Image image;
private List<Point> points;
private List<Integer> labels;
private boolean visualize;

Builder(Image image) {
this.image = image;
Expand Down Expand Up @@ -380,6 +407,16 @@ public Builder addBox(int x, int y, int right, int bottom) {
return this;
}

/**
* Sets the visualize for the {@code Sam2Input}.
*
* @return the builder
*/
public Builder visualize() {
visualize = true;
return this;
}

/**
* Builds the {@code Sam2Input}.
*
Expand All @@ -388,7 +425,7 @@ public Builder addBox(int x, int y, int right, int bottom) {
public Sam2Input build() {
Point[] location = points.toArray(new Point[0]);
int[] array = labels.stream().mapToInt(Integer::intValue).toArray();
return new Sam2Input(image, location, array);
return new Sam2Input(image, location, array, visualize);
}
}

Expand All @@ -413,6 +450,7 @@ public void setLabel(int label) {
private static final class Prompt {
String image;
Location[] prompt;
boolean visualize;

public void setImage(String image) {
this.image = image;
Expand All @@ -421,6 +459,10 @@ public void setImage(String image) {
public void setPrompt(Location[] prompt) {
this.prompt = prompt;
}

public void setVisualize(boolean visualize) {
this.visualize = visualize;
}
}
}
}
6 changes: 1 addition & 5 deletions api/src/main/java/ai/djl/util/JsonSerializable.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@ default ByteBuffer toByteBuffer() {
return ByteBuffer.wrap(toJson().getBytes(StandardCharsets.UTF_8));
}

/**
* Serializes the object to the {@code JsonElement}.
*
* @return the {@code JsonElement}
*/
/** {@inheritDoc} */
JsonElement serialize();

/** A customized Gson serializer to serialize the {@code Segmentation} object. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ public void test() throws IOException {
"{\"image\": \""
+ file.toUri().toURL()
+ "\",\n"
+ "\"visualize\": true,\n"
+ "\"prompt\": [\n"
+ " {\"type\": \"point\", \"data\": [575, 750], \"label\": 0},\n"
+ " {\"type\": \"rectangle\", \"data\": [425, 600, 700, 875]}\n"
+ "]}";
Sam2Input input = Sam2Input.fromJson(json);
Assert.assertTrue(input.isVisualize());
Assert.assertEquals(input.getPoints().size(), 1);
Assert.assertEquals(input.getBoxes().size(), 1);

input = Sam2Input.builder(img).addPoint(0, 1).addBox(0, 0, 1, 1).build();
input = Sam2Input.builder(img).visualize().addPoint(0, 1).addBox(0, 0, 1, 1).build();
Assert.assertEquals(input.getPoints().size(), 1);
Assert.assertEquals(input.getBoxes().size(), 1);
}
Expand Down

0 comments on commit d3c69e2

Please sign in to comment.