diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java new file mode 100644 index 00000000000..e62167a34b2 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.translator; + +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.BytesSupplier; +import ai.djl.ndarray.NDList; +import ai.djl.translate.Batchifier; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.JsonUtils; +import ai.djl.util.PairList; +import ai.djl.util.StringPair; + +import com.google.gson.JsonElement; +import com.google.gson.JsonParseException; + +/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */ +public class CrossEncoderServingTranslator implements NoBatchifyTranslator { + + private Translator translator; + private Translator batchTranslator; + + /** + * Constructs a {@code CrossEncoderServingTranslator} instance. + * + * @param translator a {@code Translator} processes question answering input + */ + public CrossEncoderServingTranslator(Translator translator) { + this.translator = translator; + this.batchTranslator = translator.toBatchTranslator(); + } + + /** {@inheritDoc} */ + @Override + public void prepare(TranslatorContext ctx) throws Exception { + translator.prepare(ctx); + batchTranslator.prepare(ctx); + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, Input input) throws Exception { + PairList content = input.getContent(); + if (content.isEmpty()) { + throw new TranslateException("Input data is empty."); + } + + String contentType = input.getProperty("Content-Type", null); + StringPair pair; + if ("application/json".equals(contentType)) { + String json = input.getData().getAsString(); + try { + JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class); + if (element.isJsonArray()) { + ctx.setAttachment("batch", Boolean.TRUE); + StringPair[] inputs = JsonUtils.GSON.fromJson(json, StringPair[].class); + return batchTranslator.processInput(ctx, inputs); + } + + pair = JsonUtils.GSON.fromJson(json, StringPair.class); + if (pair.getKey() == null || pair.getValue() == null) { + throw new TranslateException("Missing key or value in json."); + } + } catch (JsonParseException e) { + throw new TranslateException("Input is not a valid json.", e); + } + } else { + String key = input.getAsString("key"); + String value = input.getAsString("value"); + if (key == null || value == null) { + throw new TranslateException("Missing key or value in input."); + } + pair = new StringPair(key, value); + } + + NDList ret = translator.processInput(ctx, pair); + Batchifier batchifier = translator.getBatchifier(); + if (batchifier != null) { + NDList[] batch = {ret}; + return batchifier.batchify(batch); + } + return ret; + } + + /** {@inheritDoc} */ + @Override + public Output processOutput(TranslatorContext ctx, NDList list) throws Exception { + Output output = new Output(); + output.addProperty("Content-Type", "application/json"); + if (ctx.getAttachment("batch") != null) { + output.add(BytesSupplier.wrapAsJson(batchTranslator.processOutput(ctx, list))); + } else { + Batchifier batchifier = translator.getBatchifier(); + if (batchifier != null) { + list = batchifier.unbatchify(list)[0]; + } + output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list))); + } + return output; + } +} diff --git a/api/src/main/java/ai/djl/util/StringPair.java b/api/src/main/java/ai/djl/util/StringPair.java new file mode 100644 index 00000000000..a42e739614b --- /dev/null +++ b/api/src/main/java/ai/djl/util/StringPair.java @@ -0,0 +1,27 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util; + +/** A class containing the string key-value pair. */ +public class StringPair extends Pair { + + /** + * Constructs a {@code Pair} instance with key and value. + * + * @param key the key + * @param value the value + */ + public StringPair(String key, String value) { + super(key, value); + } +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java new file mode 100644 index 00000000000..6f43c7cb480 --- /dev/null +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.translator; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.translate.Batchifier; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslateException; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.PairList; +import ai.djl.util.StringPair; + +import java.util.Arrays; + +/** The translator for Huggingface cross encoder model. */ +public class CrossEncoderBatchTranslator implements NoBatchifyTranslator { + + private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; + private Batchifier batchifier; + + CrossEncoderBatchTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { + this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; + this.batchifier = batchifier; + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, StringPair[] inputs) + throws TranslateException { + NDManager manager = ctx.getNDManager(); + PairList list = new PairList<>(Arrays.asList(inputs)); + Encoding[] encodings = tokenizer.batchEncode(list); + NDList[] batch = new NDList[encodings.length]; + for (int i = 0; i < encodings.length; ++i) { + batch[i] = encodings[i].toNDList(manager, includeTokenTypes); + } + return batchifier.batchify(batch); + } + + /** {@inheritDoc} */ + @Override + public float[][] processOutput(TranslatorContext ctx, NDList list) { + NDList[] batch = batchifier.unbatchify(list); + float[][] ret = new float[batch.length][]; + for (int i = 0; i < batch.length; ++i) { + NDArray logits = list.get(0); + NDArray result = logits.getNDArrayInternal().sigmoid(); + ret[i] = result.toFloatArray(); + } + return ret; + } +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java new file mode 100644 index 00000000000..b88347bc60e --- /dev/null +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java @@ -0,0 +1,149 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.translator; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.translate.ArgumentsUtil; +import ai.djl.translate.Batchifier; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.StringPair; + +import java.io.IOException; +import java.util.Map; + +/** The translator for Huggingface cross encoder model. */ +public class CrossEncoderTranslator implements Translator { + + private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; + private Batchifier batchifier; + + CrossEncoderTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { + this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; + this.batchifier = batchifier; + } + + /** {@inheritDoc} */ + @Override + public Batchifier getBatchifier() { + return batchifier; + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, StringPair input) { + Encoding encoding = tokenizer.encode(input.getKey(), input.getValue()); + ctx.setAttachment("encoding", encoding); + return encoding.toNDList(ctx.getNDManager(), includeTokenTypes); + } + + /** {@inheritDoc} */ + @Override + public float[] processOutput(TranslatorContext ctx, NDList list) { + NDArray logits = list.get(0); + NDArray result = logits.getNDArrayInternal().sigmoid(); + return result.toFloatArray(); + } + + /** {@inheritDoc} */ + @Override + public CrossEncoderBatchTranslator toBatchTranslator(Batchifier batchifier) { + tokenizer.enableBatch(); + return new CrossEncoderBatchTranslator(tokenizer, includeTokenTypes, batchifier); + } + + /** + * Creates a builder to build a {@code CrossEncoderTranslator}. + * + * @param tokenizer the tokenizer + * @return a new builder + */ + public static Builder builder(HuggingFaceTokenizer tokenizer) { + return new Builder(tokenizer); + } + + /** + * Creates a builder to build a {@code CrossEncoderTranslator}. + * + * @param tokenizer the tokenizer + * @param arguments the models' arguments + * @return a new builder + */ + public static Builder builder(HuggingFaceTokenizer tokenizer, Map arguments) { + Builder builder = builder(tokenizer); + builder.configure(arguments); + + return builder; + } + + /** The builder for question answering translator. */ + public static final class Builder { + + private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; + private Batchifier batchifier = Batchifier.STACK; + + Builder(HuggingFaceTokenizer tokenizer) { + this.tokenizer = tokenizer; + } + + /** + * Sets if include token types for the {@link Translator}. + * + * @param includeTokenTypes true to include token types + * @return this builder + */ + public Builder optIncludeTokenTypes(boolean includeTokenTypes) { + this.includeTokenTypes = includeTokenTypes; + return this; + } + + /** + * Sets the {@link Batchifier} for the {@link Translator}. + * + * @param batchifier true to include token types + * @return this builder + */ + public Builder optBatchifier(Batchifier batchifier) { + this.batchifier = batchifier; + return this; + } + + /** + * Configures the builder with the model arguments. + * + * @param arguments the model arguments + */ + public void configure(Map arguments) { + optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); + String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); + optBatchifier(Batchifier.fromString(batchifierStr)); + } + + /** + * Builds the translator. + * + * @return the new translator + * @throws IOException if I/O error occurs + */ + public CrossEncoderTranslator build() throws IOException { + return new CrossEncoderTranslator(tokenizer, includeTokenTypes, batchifier); + } + } +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java new file mode 100644 index 00000000000..f4f9af02c4b --- /dev/null +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.translator; + +import ai.djl.Model; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.nlp.translator.CrossEncoderServingTranslator; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import ai.djl.util.Pair; +import ai.djl.util.StringPair; + +import java.io.IOException; +import java.io.Serializable; +import java.lang.reflect.Type; +import java.nio.file.Path; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** A {@link TranslatorFactory} that creates a {@link CrossEncoderTranslatorFactory} instance. */ +public class CrossEncoderTranslatorFactory implements TranslatorFactory, Serializable { + + private static final long serialVersionUID = 1L; + + private static final Set> SUPPORTED_TYPES = new HashSet<>(); + + static { + SUPPORTED_TYPES.add(new Pair<>(StringPair.class, float[].class)); + SUPPORTED_TYPES.add(new Pair<>(StringPair[].class, float[][].class)); + SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); + } + + /** {@inheritDoc} */ + @Override + public Set> getSupportedTypes() { + return SUPPORTED_TYPES; + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public Translator newInstance( + Class input, Class output, Model model, Map arguments) + throws TranslateException { + Path modelPath = model.getModelPath(); + try { + HuggingFaceTokenizer tokenizer = + HuggingFaceTokenizer.builder(arguments) + .optTokenizerPath(modelPath) + .optManager(model.getNDManager()) + .build(); + CrossEncoderTranslator translator = + CrossEncoderTranslator.builder(tokenizer, arguments).build(); + if (input == StringPair.class && output == float[].class) { + return (Translator) translator; + } else if (input == StringPair[].class && output == float[][].class) { + return (Translator) translator.toBatchTranslator(); + } else if (input == Input.class && output == Output.class) { + return (Translator) new CrossEncoderServingTranslator(translator); + } + throw new IllegalArgumentException("Unsupported input/output types."); + } catch (IOException e) { + throw new TranslateException("Failed to load tokenizer.", e); + } + } +} diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java new file mode 100644 index 00000000000..f3ee102e325 --- /dev/null +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java @@ -0,0 +1,204 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.tokenizers; + +import ai.djl.Model; +import ai.djl.ModelException; +import ai.djl.huggingface.translator.CrossEncoderTranslatorFactory; +import ai.djl.inference.Predictor; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.nn.Block; +import ai.djl.nn.LambdaBlock; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.TranslateException; +import ai.djl.util.JsonUtils; +import ai.djl.util.StringPair; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +public class CrossEncoderTranslatorTest { + + @Test + public void testCrossEncoderTranslator() + throws ModelException, IOException, TranslateException { + String text1 = "Sentence 1"; + String text2 = "Sentence 2"; + Block block = + new LambdaBlock( + a -> { + NDManager manager = a.getManager(); + NDArray array = manager.create(new float[] {-0.7329f}); + return new NDList(array); + }, + "model"); + Path modelDir = Paths.get("build/model"); + Files.createDirectories(modelDir); + + Criteria criteria = + Criteria.builder() + .setTypes(StringPair.class, float[].class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + StringPair input = new StringPair(text1, text2); + float[] res = predictor.predict(input); + Assert.assertEquals(res[0], 0.32456556f, 0.0001); + } + + Criteria criteria2 = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria2.loadModel(); + Predictor predictor = model.newPredictor()) { + Input input = new Input(); + input.add("key", text1); + input.add("value", text2); + Output res = predictor.predict(input); + float[] buf = (float[]) res.getData().getAsObject(); + Assert.assertEquals(buf[0], 0.32455865, 0.0001); + + Assert.assertThrows(TranslateException.class, () -> predictor.predict(new Input())); + + Assert.assertThrows( + TranslateException.class, + () -> { + Input req = new Input(); + req.add("something", "false"); + predictor.predict(req); + }); + + Assert.assertThrows( + TranslateException.class, + () -> { + Input req = new Input(); + req.addProperty("Content-Type", "application/json"); + req.add("Invalid json"); + predictor.predict(req); + }); + + Assert.assertThrows( + TranslateException.class, + () -> { + Input req = new Input(); + req.addProperty("Content-Type", "application/json"); + req.add(JsonUtils.GSON.toJson(new StringPair(text1, null))); + predictor.predict(req); + }); + } + + try (Model model = Model.newInstance("test")) { + model.setBlock(block); + Map options = new HashMap<>(); + options.put("hasParameter", "false"); + model.load(modelDir, "test", options); + + CrossEncoderTranslatorFactory factory = new CrossEncoderTranslatorFactory(); + Map arguments = new HashMap<>(); + + Assert.assertThrows( + TranslateException.class, + () -> factory.newInstance(String.class, Integer.class, model, arguments)); + + arguments.put("tokenizer", "bert-base-cased"); + + Assert.assertThrows( + IllegalArgumentException.class, + () -> factory.newInstance(String.class, Integer.class, model, arguments)); + } + } + + @Test + public void testCrossEncoderBatchTranslator() + throws ModelException, IOException, TranslateException { + StringPair pair1 = new StringPair("Sentence 1", "Sentence 2"); + StringPair pair2 = new StringPair("Sentence 3", "Sentence 4"); + + Block block = + new LambdaBlock( + a -> { + NDManager manager = a.getManager(); + NDArray array = manager.create(new float[][] {{-0.7329f}, {-0.7329f}}); + return new NDList(array); + }, + "model"); + Path modelDir = Paths.get("build/model"); + Files.createDirectories(modelDir); + + Criteria criteria = + Criteria.builder() + .setTypes(StringPair[].class, float[][].class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + StringPair[] inputs = {pair1, pair2}; + float[][] res = predictor.predict(inputs); + Assert.assertEquals(res[1][0], 0.32455865, 0.0001); + } + + Criteria criteria2 = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria2.loadModel(); + Predictor predictor = model.newPredictor()) { + Input input = new Input(); + input.add(JsonUtils.GSON.toJson(new StringPair[] {pair1, pair2})); + input.addProperty("Content-Type", "application/json"); + Output out = predictor.predict(input); + float[][] buf = (float[][]) out.getData().getAsObject(); + Assert.assertEquals(buf[0][0], 0.32455865, 0.0001); + } + } +}