diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java
new file mode 100644
index 00000000000..b83624fb8d4
--- /dev/null
+++ b/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.modality.nlp.translator;
+
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.Batchifier;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorContext;
+import ai.djl.util.JsonUtils;
+
+/** A {@link Translator} that can handle generic text embedding {@link Input} and {@link Output}. */
+public class TextEmbeddingServingTranslator implements Translator<Input, Output> {
+
+    private Translator<String, float[]> translator;
+
+ /**
+ * Constructs a {@code TextEmbeddingServingTranslator} instance.
+ *
+     * @param translator a {@code Translator} that processes text embedding input
+ */
+    public TextEmbeddingServingTranslator(Translator<String, float[]> translator) {
+ this.translator = translator;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Batchifier getBatchifier() {
+ return translator.getBatchifier();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void prepare(TranslatorContext ctx) throws Exception {
+ translator.prepare(ctx);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
+ String text = input.getData().getAsString();
+ return translator.processInput(ctx, text);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
+ float[] ret = translator.processOutput(ctx, list);
+ Output output = new Output();
+ output.add(JsonUtils.GSON_PRETTY.toJson(ret));
+ return output;
+ }
+}
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java
new file mode 100644
index 00000000000..0187ae3fa90
--- /dev/null
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java
@@ -0,0 +1,147 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.huggingface.translator;
+
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.translate.ArgumentsUtil;
+import ai.djl.translate.Batchifier;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorContext;
+
+import java.io.IOException;
+import java.util.Map;
+
+/** The translator for Huggingface text embedding model. */
+public class TextEmbeddingTranslator implements Translator<String, float[]> {
+
+ private static final int[] AXIS = {0};
+
+ private HuggingFaceTokenizer tokenizer;
+ private Batchifier batchifier;
+
+ TextEmbeddingTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
+ this.tokenizer = tokenizer;
+ this.batchifier = batchifier;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Batchifier getBatchifier() {
+ return batchifier;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDList processInput(TranslatorContext ctx, String input) {
+ NDManager manager = ctx.getNDManager();
+ Encoding encoding = tokenizer.encode(input);
+ ctx.setAttachment("encoding", encoding);
+ long[] indices = encoding.getIds();
+ long[] attentionMask = encoding.getAttentionMask();
+ NDList ndList = new NDList(2);
+ ndList.add(manager.create(indices));
+ ndList.add(manager.create(attentionMask));
+ return ndList;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public float[] processOutput(TranslatorContext ctx, NDList list) {
+ NDArray embeddings = list.get("last_hidden_state");
+ Encoding encoding = (Encoding) ctx.getAttachment("encoding");
+ long[] attentionMask = encoding.getAttentionMask();
+ NDManager manager = ctx.getNDManager();
+ NDArray inputAttentionMask = manager.create(attentionMask).toType(DataType.FLOAT32, true);
+ long[] shape = embeddings.getShape().getShape();
+ inputAttentionMask = inputAttentionMask.tile(shape[shape.length - 1]);
+ inputAttentionMask = inputAttentionMask.reshape(embeddings.getShape());
+ NDArray inputAttentionMaskSum = inputAttentionMask.sum(AXIS);
+ NDArray clamp = inputAttentionMaskSum.clip(1e-9, 1e12);
+ NDArray prod = embeddings.mul(inputAttentionMask);
+ NDArray sum = prod.sum(AXIS);
+ embeddings = sum.div(clamp).normalize(2, 0);
+
+ return embeddings.toFloatArray();
+ }
+
+ /**
+ * Creates a builder to build a {@code TextEmbeddingTranslator}.
+ *
+ * @param tokenizer the tokenizer
+ * @return a new builder
+ */
+ public static Builder builder(HuggingFaceTokenizer tokenizer) {
+ return new Builder(tokenizer);
+ }
+
+ /**
+ * Creates a builder to build a {@code TextEmbeddingTranslator}.
+ *
+ * @param tokenizer the tokenizer
+ * @param arguments the models' arguments
+ * @return a new builder
+ */
+    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
+ Builder builder = builder(tokenizer);
+ builder.configure(arguments);
+
+ return builder;
+ }
+
+    /** The builder for text embedding translator. */
+ public static final class Builder {
+
+ private HuggingFaceTokenizer tokenizer;
+ private Batchifier batchifier = Batchifier.STACK;
+
+ Builder(HuggingFaceTokenizer tokenizer) {
+ this.tokenizer = tokenizer;
+ }
+
+ /**
+ * Sets the {@link Batchifier} for the {@link Translator}.
+ *
+         * @param batchifier the {@link Batchifier} to be set
+ * @return this builder
+ */
+ public TextEmbeddingTranslator.Builder optBatchifier(Batchifier batchifier) {
+ this.batchifier = batchifier;
+ return this;
+ }
+
+ /**
+ * Configures the builder with the model arguments.
+ *
+ * @param arguments the model arguments
+ */
+        public void configure(Map<String, ?> arguments) {
+ String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
+ optBatchifier(Batchifier.fromString(batchifierStr));
+ }
+
+ /**
+ * Builds the translator.
+ *
+ * @return the new translator
+ * @throws IOException if I/O error occurs
+ */
+ public TextEmbeddingTranslator build() throws IOException {
+ return new TextEmbeddingTranslator(tokenizer, batchifier);
+ }
+ }
+}
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslatorFactory.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslatorFactory.java
new file mode 100644
index 00000000000..87d692abe57
--- /dev/null
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslatorFactory.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.huggingface.translator;
+
+import ai.djl.Model;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.modality.nlp.translator.TextEmbeddingServingTranslator;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorFactory;
+import ai.djl.util.Pair;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.nio.file.Path;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/** A {@link TranslatorFactory} that creates a {@link TextEmbeddingTranslator} instance. */
+public class TextEmbeddingTranslatorFactory implements TranslatorFactory {
+
+    private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();
+
+ static {
+ SUPPORTED_TYPES.add(new Pair<>(String.class, float[].class));
+ SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
+ }
+
+ /** {@inheritDoc} */
+ @Override
+    public Set<Pair<Type, Type>> getSupportedTypes() {
+ return SUPPORTED_TYPES;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+    public Translator<?, ?> newInstance(
+            Class<?> input, Class<?> output, Model model, Map<String, ?> arguments)
+ throws TranslateException {
+ Path modelPath = model.getModelPath();
+ try {
+ HuggingFaceTokenizer tokenizer =
+ HuggingFaceTokenizer.builder(arguments)
+ .optTokenizerPath(modelPath)
+ .optManager(model.getNDManager())
+ .build();
+ TextEmbeddingTranslator translator =
+ TextEmbeddingTranslator.builder(tokenizer, arguments).build();
+ if (input == String.class && output == float[].class) {
+ return translator;
+ } else if (input == Input.class && output == Output.class) {
+ return new TextEmbeddingServingTranslator(translator);
+ }
+ throw new IllegalArgumentException("Unsupported input/output types.");
+ } catch (IOException e) {
+ throw new TranslateException("Failed to load tokenizer.", e);
+ }
+ }
+}
diff --git a/extensions/tokenizers/src/main/python/fill_mask_converter.py b/extensions/tokenizers/src/main/python/fill_mask_converter.py
index f9b1dadbdcb..7b06552ff51 100644
--- a/extensions/tokenizers/src/main/python/fill_mask_converter.py
+++ b/extensions/tokenizers/src/main/python/fill_mask_converter.py
@@ -56,3 +56,6 @@ def verify_jit_output(self, hf_pipeline, encoding, out):
def encode_inputs(self, tokenizer):
text = self.inputs.replace("[MASK]", tokenizer.mask_token)
return tokenizer.encode_plus(text, return_tensors='pt')
+
+ def get_extra_arguments(self, hf_pipeline) -> dict:
+ return {"maskToken": hf_pipeline.tokenizer.mask_token}
diff --git a/extensions/tokenizers/src/main/python/huggingface_converter.py b/extensions/tokenizers/src/main/python/huggingface_converter.py
index f965a76528a..6fdd53f22fc 100644
--- a/extensions/tokenizers/src/main/python/huggingface_converter.py
+++ b/extensions/tokenizers/src/main/python/huggingface_converter.py
@@ -57,7 +57,7 @@ def save_model(self, model_id: str, output_dir: str, temp_dir: str):
return False, reason, -1
size = self.save_to_model_zoo(model_id, output_dir, temp_dir,
- hf_pipeline.tokenizer.mask_token)
+ hf_pipeline)
return True, None, size
@@ -100,7 +100,7 @@ def jit_trace_model(self, hf_pipeline, model_id: str, temp_dir: str):
return model_file
def save_to_model_zoo(self, model_id: str, output_dir: str, temp_dir: str,
- mask_token: str):
+ hf_pipeline):
artifact_ids = model_id.split("/")
model_name = artifact_ids[-1]
@@ -111,13 +111,14 @@ def save_to_model_zoo(self, model_id: str, output_dir: str, temp_dir: str,
# Save serving.properties
serving_file = os.path.join(temp_dir, "serving.properties")
+ arguments = self.get_extra_arguments(hf_pipeline)
with open(serving_file, 'w') as f:
f.write(f"engine=PyTorch\n"
f"option.modelName={model_name}\n"
f"option.mapLocation=true\n"
f"translatorFactory={self.translator}\n")
- if mask_token:
- f.write(f"maskToken={mask_token}\n")
+ for k, v in arguments.items():
+ f.write(f"{k}={v}\n")
# Save model as .zip file
logging.info(f"Saving DJL model as zip: {model_name}.zip ...")
@@ -157,6 +158,9 @@ def verify_jit_model(self, hf_pipeline, model_file: str):
return self.verify_jit_output(hf_pipeline, encoding, out)
+ def get_extra_arguments(self, hf_pipeline) -> dict:
+ return {}
+
def verify_jit_output(self, hf_pipeline, encoding, out):
if not hasattr(out, "last_hidden_layer"):
return False, f"Unexpected inference result: {out}"
diff --git a/extensions/tokenizers/src/main/python/sentence_similarity_converter.py b/extensions/tokenizers/src/main/python/sentence_similarity_converter.py
index baa0382a1c8..dbcb587a2e0 100644
--- a/extensions/tokenizers/src/main/python/sentence_similarity_converter.py
+++ b/extensions/tokenizers/src/main/python/sentence_similarity_converter.py
@@ -51,3 +51,6 @@ def verify_jit_output(self, hf_pipeline, encoding, out):
return False, f"Unexpected inference result: {last_hidden_state}"
return True, None
+
+ def get_extra_arguments(self, hf_pipeline) -> dict:
+ return {"padding": "true"}
diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TranslatorTest.java
index 2cad1353bec..88c72d41c3d 100644
--- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TranslatorTest.java
+++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TranslatorTest.java
@@ -16,6 +16,7 @@
import ai.djl.ModelException;
import ai.djl.huggingface.translator.FillMaskTranslatorFactory;
import ai.djl.huggingface.translator.QuestionAnsweringTranslatorFactory;
+import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.huggingface.translator.TokenClassificationTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
@@ -26,6 +27,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.repository.zoo.Criteria;
@@ -325,4 +327,83 @@ public void testTokenClassificationTranslator()
() -> factory.newInstance(String.class, Integer.class, model, arguments));
}
}
+
+ @Test
+ public void testTextEmbeddingTranslator()
+ throws ModelException, IOException, TranslateException {
+ TestRequirements.notArm();
+
+ String text = "This is an example sentence";
+
+ Block block =
+ new LambdaBlock(
+ a -> {
+ NDManager manager = a.getManager();
+ NDArray arr = manager.ones(new Shape(1, 7, 384));
+ arr.setName("last_hidden_state");
+ return new NDList(arr);
+ },
+ "model");
+ Path modelDir = Paths.get("build/model");
+ Files.createDirectories(modelDir);
+
+        Criteria<String, float[]> criteria =
+ Criteria.builder()
+ .setTypes(String.class, float[].class)
+ .optModelPath(modelDir)
+ .optBlock(block)
+ .optEngine("PyTorch")
+ .optArgument("tokenizer", "bert-base-uncased")
+ .optOption("hasParameter", "false")
+ .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
+ .build();
+
+        try (ZooModel<String, float[]> model = criteria.loadModel();
+                Predictor<String, float[]> predictor = model.newPredictor()) {
+ float[] res = predictor.predict(text);
+ Assert.assertEquals(res.length, 384);
+ Assertions.assertAlmostEquals(res[0], 0.05103);
+ }
+
+        Criteria<Input, Output> criteria2 =
+ Criteria.builder()
+ .setTypes(Input.class, Output.class)
+ .optModelPath(modelDir)
+ .optBlock(block)
+ .optEngine("PyTorch")
+ .optArgument("tokenizer", "bert-base-uncased")
+ .optOption("hasParameter", "false")
+ .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
+ .build();
+
+        try (ZooModel<Input, Output> model = criteria2.loadModel();
+                Predictor<Input, Output> predictor = model.newPredictor()) {
+ Input input = new Input();
+ input.add(text);
+ Output out = predictor.predict(input);
+ float[] res = JsonUtils.GSON.fromJson(out.getAsString(0), float[].class);
+ Assert.assertEquals(res.length, 384);
+ Assertions.assertAlmostEquals(res[0], 0.05103);
+ }
+
+ try (Model model = Model.newInstance("test")) {
+ model.setBlock(block);
+            Map<String, String> options = new HashMap<>();
+ options.put("hasParameter", "false");
+ model.load(modelDir, "test", options);
+
+ TextEmbeddingTranslatorFactory factory = new TextEmbeddingTranslatorFactory();
+            Map<String, String> arguments = new HashMap<>();
+
+ Assert.assertThrows(
+ TranslateException.class,
+ () -> factory.newInstance(String.class, Integer.class, model, arguments));
+
+ arguments.put("tokenizer", "bert-base-uncased");
+
+ Assert.assertThrows(
+ IllegalArgumentException.class,
+ () -> factory.newInstance(String.class, Integer.class, model, arguments));
+ }
+ }
}