Skip to content

Commit

Permalink
[api] Use encoder/decoder for Segment Anything 2 translator
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Sep 30, 2024
1 parent 2071a1c commit f703b93
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 11 deletions.
109 changes: 105 additions & 4 deletions api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
*/
package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
Expand All @@ -27,15 +30,21 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;

/** A {@link Translator} that handles the mask generation task. */
public class Sam2Translator implements NoBatchifyTranslator<Sam2Input, DetectedObjects> {
Expand All @@ -44,13 +53,38 @@ public class Sam2Translator implements NoBatchifyTranslator<Sam2Input, DetectedO
private static final float[] STD = {0.229f, 0.224f, 0.225f};

private Pipeline pipeline;
private Predictor<NDList, NDList> predictor;
private String encoderPath;

/**
 * Constructs a {@code Sam2Translator} instance.
 *
 * <p>Sets up the image pre-processing pipeline and records the encoder model path carried by
 * the builder.
 *
 * @param builder the builder whose {@code encoderPath} is copied into this translator
 */
public Sam2Translator() {
public Sam2Translator(Builder builder) {
// pre-processing steps, applied in this order: Resize(1024, 1024), ToTensor, Normalize(MEAN, STD)
pipeline = new Pipeline();
pipeline.add(new Resize(1024, 1024));
pipeline.add(new ToTensor());
pipeline.add(new Normalize(MEAN, STD));
// may be null; prepare() returns early in that case and no encoder is loaded
this.encoderPath = builder.encoderPath;
}

/**
 * {@inheritDoc}
 *
 * <p>When an encoder path has been configured, loads the encoder as a separate model and
 * creates the predictor that {@code processInput} uses to compute image embeddings. Does
 * nothing when no encoder path is set.
 *
 * @throws IOException if the encoder model file cannot be found or read
 * @throws ModelException if the encoder model fails to load
 */
@Override
public void prepare(TranslatorContext ctx) throws IOException, ModelException {
    if (encoderPath == null) {
        // no separate encoder configured
        return;
    }
    Model model = ctx.getModel();
    Path path = Paths.get(encoderPath);
    // a relative path that does not exist as given is retried relative to the model directory
    if (!path.isAbsolute() && Files.notExists(path)) {
        path = model.getModelPath().resolve(encoderPath);
    }
    if (!Files.exists(path)) {
        throw new IOException("encoder model not found: " + encoderPath);
    }
    NDManager manager = ctx.getNDManager();
    Model encoder = manager.getEngine().newModel("encoder", manager.getDevice());
    try {
        encoder.load(path);
        predictor = encoder.newPredictor(new NoopTranslator(null));
    } catch (IOException | ModelException | RuntimeException e) {
        // the encoder is not yet attached to any manager, so release it here before failing
        encoder.close();
        throw e;
    }
    // NOTE(review): attaching to the main model's manager presumably ties the lifetime of the
    // encoder and its predictor to the main model - confirm against NDManager.attachInternal
    model.getNDManager().attachInternal(UUID.randomUUID().toString(), predictor);
    model.getNDManager().attachInternal(UUID.randomUUID().toString(), encoder);
}

/** {@inheritDoc} */
Expand All @@ -72,7 +106,21 @@ public NDList processInput(TranslatorContext ctx, Sam2Input input) throws Except
NDArray locations = manager.create(buf, new Shape(1, numPoints, 2));
NDArray labels = manager.create(input.getLabels());

return new NDList(array, locations, labels);
if (predictor == null) {
return new NDList(array, locations, labels);
}

NDList embeddings = predictor.predict(new NDList(array));
NDArray mask = manager.zeros(new Shape(1, 1, 256, 256));
NDArray hasMask = manager.zeros(new Shape(1));
return new NDList(
embeddings.get(2),
embeddings.get(0),
embeddings.get(1),
locations,
labels,
mask,
hasMask);
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -101,6 +149,55 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws
return new DetectedObjects(classes, probabilities, boxes);
}

/**
 * Returns a fresh {@link Builder} for a {@code Sam2Translator}, created from an empty
 * argument map.
 *
 * @return a new builder
 */
public static Builder builder() {
    return new Builder(Collections.emptyMap());
}

/**
 * Returns a fresh {@link Builder} for a {@code Sam2Translator}, initialized from the given
 * arguments.
 *
 * @param arguments arguments to specify builder options
 * @return a new builder
 */
public static Builder builder(Map<String, ?> arguments) {
    Builder configured = new Builder(arguments);
    return configured;
}

/** Collects the options used to construct a {@code Sam2Translator}. */
public static class Builder {

    /** The encoder model path; read from the arguments or set via the builder. */
    String encoderPath;

    /**
     * Creates a builder, taking the initial encoder path from the {@code "encoder"} argument.
     *
     * @param arguments arguments to specify builder options
     */
    Builder(Map<String, ?> arguments) {
        this.encoderPath = ArgumentsUtil.stringValue(arguments, "encoder");
    }

    /**
     * Overrides the encoder model path.
     *
     * @param encoderPath the encoder model path
     * @return this builder, to allow call chaining
     */
    public Builder optEncoderPath(String encoderPath) {
        this.encoderPath = encoderPath;
        return this;
    }

    /**
     * Creates a translator from the options collected so far.
     *
     * @return the new translator
     */
    public Sam2Translator build() {
        Sam2Translator translator = new Sam2Translator(this);
        return translator;
    }
}

/** A class that represents the segment anything input. */
public static final class Sam2Input {

Expand Down Expand Up @@ -149,8 +246,12 @@ float[] toLocationArray(int width, int height) {
return ret;
}

/**
 * Returns the point labels as a two-dimensional array with a single row, one column per
 * label.
 *
 * @return the labels in a {@code 1 x labels.size()} array
 */
int[][] getLabels() {
return new int[][] {labels.stream().mapToInt(Integer::intValue).toArray()};
float[][] getLabels() {
float[][] buf = new float[1][labels.size()];
for (int i = 0; i < labels.size(); ++i) {
// each integer label is stored as a float in row 0
buf[0][i] = labels.get(i);
}
return buf;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.CategoryMask;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input;
import ai.djl.translate.Translator;
Expand Down Expand Up @@ -43,8 +41,8 @@ public class Sam2TranslatorFactory implements TranslatorFactory, Serializable {
@SuppressWarnings("unchecked")
public <I, O> Translator<I, O> newInstance(
Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
if (input == Image.class && output == CategoryMask.class) {
return (Translator<I, O>) new Sam2Translator();
// only the Sam2Input -> DetectedObjects pairing is supported; the translator is built from
// the supplied arguments (the unchecked cast is guarded by the type check)
if (input == Sam2Input.class && output == DetectedObjects.class) {
return (Translator<I, O>) Sam2Translator.builder(arguments).build();
}
// any other input/output combination is rejected
throw new IllegalArgumentException("Unsupported input/output types.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.translator.Sam2Translator;
import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input;
import ai.djl.modality.cv.translator.Sam2TranslatorFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
Expand Down Expand Up @@ -52,9 +52,8 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
Criteria.builder()
.setTypes(Sam2Input.class, DetectedObjects.class)
.optModelUrls("djl://ai.djl.pytorch/sam2-hiera-tiny")
.optEngine("PyTorch")
.optDevice(Device.cpu()) // use sam2-hiera-tiny-gpu for GPU
.optTranslator(new Sam2Translator())
.optTranslatorFactory(new Sam2TranslatorFactory())
.optProgress(new ProgressBar())
.build();

Expand Down

0 comments on commit f703b93

Please sign in to comment.