diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java
new file mode 100644
index 00000000000..b83624fb8d4
--- /dev/null
+++ b/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.modality.nlp.translator;
+
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.Batchifier;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorContext;
+import ai.djl.util.JsonUtils;
+
+/** A {@link Translator} that can handle generic text embedding {@link Input} and {@link Output}. */
+public class TextEmbeddingServingTranslator implements Translator<Input, Output> {
+
+    private Translator<String, float[]> translator;
+
+    /**
+     * Constructs a {@code TextEmbeddingServingTranslator} instance.
+     *
+     * @param translator a {@code Translator} that processes text embedding input
+     */
+    public TextEmbeddingServingTranslator(Translator<String, float[]> translator) {
+        this.translator = translator;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Batchifier getBatchifier() {
+        return translator.getBatchifier();
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void prepare(TranslatorContext ctx) throws Exception {
+        translator.prepare(ctx);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
+        String text = input.getData().getAsString();
+        return translator.processInput(ctx, text);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
+        float[] ret = translator.processOutput(ctx, list);
+        Output output = new Output();
+        output.add(JsonUtils.GSON_PRETTY.toJson(ret));
+        return output;
+    }
+}
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java
new file mode 100644
index 00000000000..0187ae3fa90
--- /dev/null
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java
@@ -0,0 +1,147 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.huggingface.translator;
+
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.translate.ArgumentsUtil;
+import ai.djl.translate.Batchifier;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorContext;
+
+import java.io.IOException;
+import java.util.Map;
+
+/** The translator for HuggingFace text embedding models. */
+public class TextEmbeddingTranslator implements Translator<String, float[]> {
+
+    private static final int[] AXIS = {0};
+
+    private HuggingFaceTokenizer tokenizer;
+    private Batchifier batchifier;
+
+    TextEmbeddingTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
+        this.tokenizer = tokenizer;
+        this.batchifier = batchifier;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Batchifier getBatchifier() {
+        return batchifier;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public NDList processInput(TranslatorContext ctx, String input) {
+        NDManager manager = ctx.getNDManager();
+        Encoding encoding = tokenizer.encode(input);
+        ctx.setAttachment("encoding", encoding);
+        long[] indices = encoding.getIds();
+        long[] attentionMask = encoding.getAttentionMask();
+        NDList ndList = new NDList(2);
+        ndList.add(manager.create(indices));
+        ndList.add(manager.create(attentionMask));
+        return ndList;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public float[] processOutput(TranslatorContext ctx, NDList list) {
+        NDArray embeddings = list.get("last_hidden_state");
+        Encoding encoding = (Encoding) ctx.getAttachment("encoding");
+        long[] attentionMask = encoding.getAttentionMask();
+        NDManager manager = ctx.getNDManager();
+        NDArray inputAttentionMask = manager.create(attentionMask).toType(DataType.FLOAT32, true);
+        long[] shape = embeddings.getShape().getShape();
+        inputAttentionMask = inputAttentionMask.tile(shape[shape.length - 1]);
+        inputAttentionMask = inputAttentionMask.reshape(embeddings.getShape());
+        NDArray inputAttentionMaskSum = inputAttentionMask.sum(AXIS);
+        NDArray clamp = inputAttentionMaskSum.clip(1e-9, 1e12);
+        NDArray prod = embeddings.mul(inputAttentionMask);
+        NDArray sum = prod.sum(AXIS);
+        embeddings = sum.div(clamp).normalize(2, 0);
+
+        return embeddings.toFloatArray();
+    }
+
+    /**
+     * Creates a builder to build a {@code TextEmbeddingTranslator}.
+     *
+     * @param tokenizer the tokenizer
+     * @return a new builder
+     */
+    public static Builder builder(HuggingFaceTokenizer tokenizer) {
+        return new Builder(tokenizer);
+    }
+
+    /**
+     * Creates a builder to build a {@code TextEmbeddingTranslator}.
+     *
+     * @param tokenizer the tokenizer
+     * @param arguments the model's arguments
+     * @return a new builder
+     */
+    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
+        Builder builder = builder(tokenizer);
+        builder.configure(arguments);
+
+        return builder;
+    }
+
+    /** The builder for text embedding translator. */
+    public static final class Builder {
+
+        private HuggingFaceTokenizer tokenizer;
+        private Batchifier batchifier = Batchifier.STACK;
+
+        Builder(HuggingFaceTokenizer tokenizer) {
+            this.tokenizer = tokenizer;
+        }
+
+        /**
+         * Sets the {@link Batchifier} for the {@link Translator}.
+         *
+         * @param batchifier the {@link Batchifier} to be set
+         * @return this builder
+         */
+        public TextEmbeddingTranslator.Builder optBatchifier(Batchifier batchifier) {
+            this.batchifier = batchifier;
+            return this;
+        }
+
+        /**
+         * Configures the builder with the model arguments.
+         *
+         * @param arguments the model arguments
+         */
+        public void configure(Map<String, ?> arguments) {
+            String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
+            optBatchifier(Batchifier.fromString(batchifierStr));
+        }
+
+        /**
+         * Builds the translator.
+         *
+         * @return the new translator
+         * @throws IOException if an I/O error occurs
+         */
+        public TextEmbeddingTranslator build() throws IOException {
+            return new TextEmbeddingTranslator(tokenizer, batchifier);
+        }
+    }
+}
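Reviewer note, not part of the patch: processOutput above performs masked mean pooling over the token axis followed by L2 normalization. The attention mask is tiled out to the hidden size, padding positions are zeroed, the sum over tokens is divided by the clamped token count, and the result is normalized. A plain-array sketch of the same computation (class and method names are hypothetical) for anyone checking the NDArray steps:

/** Reviewer sketch of the pooling in processOutput; hypothetical helper, not part of the patch. */
final class MeanPoolSketch {

    private MeanPoolSketch() {}

    /**
     * Masked mean pooling over tokens followed by L2 normalization.
     *
     * @param hidden the last_hidden_state for one sentence, shape [tokens][dim]
     * @param mask the tokenizer's attention mask, 1 for real tokens and 0 for padding
     * @return the sentence embedding of length dim
     */
    static float[] meanPool(float[][] hidden, long[] mask) {
        int dim = hidden[0].length;
        float[] pooled = new float[dim];
        float count = 0;
        for (int t = 0; t < hidden.length; t++) {
            if (mask[t] == 0) {
                continue; // padding contributes nothing, same as multiplying by the tiled mask
            }
            count++;
            for (int d = 0; d < dim; d++) {
                pooled[d] += hidden[t][d];
            }
        }
        count = Math.max(count, 1e-9f); // mirrors clip(1e-9, 1e12) on the mask sum
        double norm = 0;
        for (int d = 0; d < dim; d++) {
            pooled[d] /= count; // mean over real tokens
            norm += pooled[d] * pooled[d];
        }
        norm = Math.sqrt(norm);
        for (int d = 0; d < dim; d++) {
            pooled[d] /= norm; // L2 normalization, matching normalize(2, 0)
        }
        return pooled;
    }
}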
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslatorFactory.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslatorFactory.java
new file mode 100644
index 00000000000..87d692abe57
--- /dev/null
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslatorFactory.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.huggingface.translator;
+
+import ai.djl.Model;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.modality.nlp.translator.TextEmbeddingServingTranslator;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorFactory;
+import ai.djl.util.Pair;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.nio.file.Path;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/** A {@link TranslatorFactory} that creates a {@link TextEmbeddingTranslator} instance. */
+public class TextEmbeddingTranslatorFactory implements TranslatorFactory {
+
+    private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();
+
+    static {
+        SUPPORTED_TYPES.add(new Pair<>(String.class, float[].class));
+        SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Set<Pair<Type, Type>> getSupportedTypes() {
+        return SUPPORTED_TYPES;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public <I, O> Translator<I, O> newInstance(
+            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments)
+            throws TranslateException {
+        Path modelPath = model.getModelPath();
+        try {
+            HuggingFaceTokenizer tokenizer =
+                    HuggingFaceTokenizer.builder(arguments)
+                            .optTokenizerPath(modelPath)
+                            .optManager(model.getNDManager())
+                            .build();
+            TextEmbeddingTranslator translator =
+                    TextEmbeddingTranslator.builder(tokenizer, arguments).build();
+            if (input == String.class && output == float[].class) {
+                return (Translator<I, O>) translator;
+            } else if (input == Input.class && output == Output.class) {
+                return (Translator<I, O>) new TextEmbeddingServingTranslator(translator);
+            }
+            throw new IllegalArgumentException("Unsupported input/output types.");
+        } catch (IOException e) {
+            throw new TranslateException("Failed to load tokenizer.", e);
+        }
+    }
+}
diff --git a/extensions/tokenizers/src/main/python/fill_mask_converter.py b/extensions/tokenizers/src/main/python/fill_mask_converter.py
index f9b1dadbdcb..7b06552ff51 100644
--- a/extensions/tokenizers/src/main/python/fill_mask_converter.py
+++ b/extensions/tokenizers/src/main/python/fill_mask_converter.py
@@ -56,3 +56,6 @@ def verify_jit_output(self, hf_pipeline, encoding, out):
     def encode_inputs(self, tokenizer):
         text = self.inputs.replace("[MASK]", tokenizer.mask_token)
         return tokenizer.encode_plus(text, return_tensors='pt')
+
+    def get_extra_arguments(self, hf_pipeline) -> dict:
+        return {"maskToken": hf_pipeline.tokenizer.mask_token}
diff --git a/extensions/tokenizers/src/main/python/huggingface_converter.py b/extensions/tokenizers/src/main/python/huggingface_converter.py
index f965a76528a..6fdd53f22fc 100644
--- a/extensions/tokenizers/src/main/python/huggingface_converter.py
+++ b/extensions/tokenizers/src/main/python/huggingface_converter.py
@@ -57,7 +57,7 @@ def save_model(self, model_id: str, output_dir: str, temp_dir: str):
             return False, reason, -1
 
         size = self.save_to_model_zoo(model_id, output_dir, temp_dir,
-                                      hf_pipeline.tokenizer.mask_token)
+                                      hf_pipeline)
 
         return True, None, size
 
@@ -100,7 +100,7 @@ def jit_trace_model(self, hf_pipeline, model_id: str, temp_dir: str):
         return model_file
 
     def save_to_model_zoo(self, model_id: str, output_dir: str, temp_dir: str,
-                          mask_token: str):
+                          hf_pipeline):
         artifact_ids = model_id.split("/")
         model_name = artifact_ids[-1]
 
@@ -111,13 +111,14 @@ def save_to_model_zoo(self, model_id: str, output_dir: str, temp_dir: str,
 
         # Save serving.properties
         serving_file = os.path.join(temp_dir, "serving.properties")
+        arguments = self.get_extra_arguments(hf_pipeline)
         with open(serving_file, 'w') as f:
             f.write(f"engine=PyTorch\n"
                     f"option.modelName={model_name}\n"
                     f"option.mapLocation=true\n"
                     f"translatorFactory={self.translator}\n")
-            if mask_token:
-                f.write(f"maskToken={mask_token}\n")
+            for k, v in arguments.items():
+                f.write(f"{k}={v}\n")
 
         # Save model as .zip file
         logging.info(f"Saving DJL model as zip: {model_name}.zip ...")
@@ -157,6 +158,9 @@ def verify_jit_model(self, hf_pipeline, model_file: str):
 
         return self.verify_jit_output(hf_pipeline, encoding, out)
 
+    def get_extra_arguments(self, hf_pipeline) -> dict:
+        return {}
+
     def verify_jit_output(self, hf_pipeline, encoding, out):
         if not hasattr(out, "last_hidden_layer"):
             return False, f"Unexpected inference result: {out}"
diff --git a/extensions/tokenizers/src/main/python/sentence_similarity_converter.py b/extensions/tokenizers/src/main/python/sentence_similarity_converter.py
index baa0382a1c8..dbcb587a2e0 100644
--- a/extensions/tokenizers/src/main/python/sentence_similarity_converter.py
+++ b/extensions/tokenizers/src/main/python/sentence_similarity_converter.py
@@ -51,3 +51,6 @@ def verify_jit_output(self, hf_pipeline, encoding, out):
             return False, f"Unexpected inference result: {last_hidden_state}"
 
         return True, None
+
+    def get_extra_arguments(self, hf_pipeline) -> dict:
+        return {"padding": "true"}
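Reviewer note, not part of the patch: with the converter changes above, each task-specific converter now contributes extra key/value pairs through get_extra_arguments(), and save_to_model_zoo() writes them verbatim into the generated serving.properties. For a sentence-similarity model the generated file would look roughly like the sketch below; the model name is a placeholder, and the translatorFactory value assumes the converter points at the new TextEmbeddingTranslatorFactory.

engine=PyTorch
option.modelName=my-embedding-model
option.mapLocation=true
translatorFactory=ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory
padding=true

A fill-mask model would instead get a maskToken entry (e.g. maskToken=[MASK]) from fill_mask_converter.py.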
diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TranslatorTest.java
index 2cad1353bec..88c72d41c3d 100644
--- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TranslatorTest.java
+++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TranslatorTest.java
@@ -16,6 +16,7 @@
 import ai.djl.ModelException;
 import ai.djl.huggingface.translator.FillMaskTranslatorFactory;
 import ai.djl.huggingface.translator.QuestionAnsweringTranslatorFactory;
+import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
 import ai.djl.huggingface.translator.TokenClassificationTranslatorFactory;
 import ai.djl.inference.Predictor;
 import ai.djl.modality.Classifications;
@@ -26,6 +27,7 @@
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.NDList;
 import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.Shape;
 import ai.djl.nn.Block;
 import ai.djl.nn.LambdaBlock;
 import ai.djl.repository.zoo.Criteria;
@@ -325,4 +327,83 @@ public void testTokenClassificationTranslator()
                     () -> factory.newInstance(String.class, Integer.class, model, arguments));
         }
     }
+
+    @Test
+    public void testTextEmbeddingTranslator()
+            throws ModelException, IOException, TranslateException {
+        TestRequirements.notArm();
+
+        String text = "This is an example sentence";
+
+        Block block =
+                new LambdaBlock(
+                        a -> {
+                            NDManager manager = a.getManager();
+                            NDArray arr = manager.ones(new Shape(1, 7, 384));
+                            arr.setName("last_hidden_state");
+                            return new NDList(arr);
+                        },
+                        "model");
+        Path modelDir = Paths.get("build/model");
+        Files.createDirectories(modelDir);
+
+        Criteria<String, float[]> criteria =
+                Criteria.builder()
+                        .setTypes(String.class, float[].class)
+                        .optModelPath(modelDir)
+                        .optBlock(block)
+                        .optEngine("PyTorch")
+                        .optArgument("tokenizer", "bert-base-uncased")
+                        .optOption("hasParameter", "false")
+                        .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
+                        .build();
+
+        try (ZooModel<String, float[]> model = criteria.loadModel();
+                Predictor<String, float[]> predictor = model.newPredictor()) {
+            float[] res = predictor.predict(text);
+            Assert.assertEquals(res.length, 384);
+            Assertions.assertAlmostEquals(res[0], 0.05103);
+        }
+
+        Criteria<Input, Output> criteria2 =
+                Criteria.builder()
+                        .setTypes(Input.class, Output.class)
+                        .optModelPath(modelDir)
+                        .optBlock(block)
+                        .optEngine("PyTorch")
+                        .optArgument("tokenizer", "bert-base-uncased")
+                        .optOption("hasParameter", "false")
+                        .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
+                        .build();
+
+        try (ZooModel<Input, Output> model = criteria2.loadModel();
+                Predictor<Input, Output> predictor = model.newPredictor()) {
+            Input input = new Input();
+            input.add(text);
+            Output out = predictor.predict(input);
+            float[] res = JsonUtils.GSON.fromJson(out.getAsString(0), float[].class);
+            Assert.assertEquals(res.length, 384);
+            Assertions.assertAlmostEquals(res[0], 0.05103);
+        }
+
+        try (Model model = Model.newInstance("test")) {
+            model.setBlock(block);
+            Map<String, String> options = new HashMap<>();
+            options.put("hasParameter", "false");
+            model.load(modelDir, "test", options);
+
+            TextEmbeddingTranslatorFactory factory = new TextEmbeddingTranslatorFactory();
+            Map<String, String> arguments = new HashMap<>();
+
+            Assert.assertThrows(
+                    TranslateException.class,
+                    () -> factory.newInstance(String.class, Integer.class, model, arguments));
+
+            arguments.put("tokenizer", "bert-base-uncased");
+
+            Assert.assertThrows(
+                    IllegalArgumentException.class,
+                    () -> factory.newInstance(String.class, Integer.class, model, arguments));
+        }
+    }
 }
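Reviewer note, not part of the patch: outside the unit test, end-to-end usage of the new factory looks roughly like the sketch below. The model directory is a placeholder for a traced PyTorch embedding model, the tokenizer argument mirrors how the test configures HuggingFaceTokenizer, and only APIs already exercised in TranslatorTest are used.

import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

import java.nio.file.Paths;

public final class TextEmbeddingExample {

    public static void main(String[] args) throws Exception {
        // Placeholder path: a directory containing a traced TorchScript embedding model.
        Criteria<String, float[]> criteria =
                Criteria.builder()
                        .setTypes(String.class, float[].class)
                        .optModelPath(Paths.get("build/model"))
                        .optEngine("PyTorch")
                        .optArgument("tokenizer", "bert-base-uncased")
                        .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
                        .build();

        try (ZooModel<String, float[]> model = criteria.loadModel();
                Predictor<String, float[]> predictor = model.newPredictor()) {
            float[] embedding = predictor.predict("This is an example sentence");
            System.out.println("embedding dimension: " + embedding.length);
        }
    }
}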