diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java deleted file mode 100644 index b3b765bbf1e..00000000000 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.huggingface.translator; - -import ai.djl.Model; -import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import ai.djl.modality.Input; -import ai.djl.modality.Output; -import ai.djl.modality.nlp.translator.CrossEncoderServingTranslator; -import ai.djl.translate.TranslateException; -import ai.djl.translate.Translator; -import ai.djl.translate.TranslatorFactory; -import ai.djl.util.Pair; -import ai.djl.util.StringPair; - -import java.io.IOException; -import java.io.Serializable; -import java.lang.reflect.Type; -import java.nio.file.Path; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -/** A {@link TranslatorFactory} that creates a {@link CrossEncoderTranslatorFactory} instance. */ -public class CrossEncoderTranslatorFactory implements TranslatorFactory, Serializable { - - private static final long serialVersionUID = 1L; - - private static final Set> SUPPORTED_TYPES = new HashSet<>(); - - static { - SUPPORTED_TYPES.add(new Pair<>(StringPair.class, float[].class)); - SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); - } - - /** {@inheritDoc} */ - @Override - public Set> getSupportedTypes() { - return SUPPORTED_TYPES; - } - - /** {@inheritDoc} */ - @Override - @SuppressWarnings("unchecked") - public Translator newInstance( - Class input, Class output, Model model, Map arguments) - throws TranslateException { - Path modelPath = model.getModelPath(); - try { - HuggingFaceTokenizer tokenizer = - HuggingFaceTokenizer.builder(arguments) - .optTokenizerPath(modelPath) - .optManager(model.getNDManager()) - .build(); - CrossEncoderTranslator translator = - CrossEncoderTranslator.builder(tokenizer, arguments).build(); - if (input == StringPair.class && output == float[].class) { - return (Translator) translator; - } else if (input == Input.class && output == Output.class) { - return (Translator) new CrossEncoderServingTranslator(translator); - } - throw new IllegalArgumentException("Unsupported input/output types."); - } catch (IOException e) { - throw new TranslateException("Failed to load tokenizer.", e); - } - } -} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslatorFactory.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslatorFactory.java index 1f8dd5a5164..d91c35cf2ea 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslatorFactory.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslatorFactory.java @@ -16,11 +16,14 @@ import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.modality.nlp.translator.CrossEncoderServingTranslator; import ai.djl.modality.nlp.translator.TextEmbeddingServingTranslator; +import ai.djl.translate.ArgumentsUtil; import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorFactory; import ai.djl.util.Pair; +import ai.djl.util.StringPair; import java.io.IOException; import java.io.Serializable; @@ -39,6 +42,7 @@ public class TextEmbeddingTranslatorFactory implements TranslatorFactory, Serial static { SUPPORTED_TYPES.add(new Pair<>(String.class, float[].class)); + SUPPORTED_TYPES.add(new Pair<>(StringPair.class, float[].class)); SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); } @@ -61,6 +65,17 @@ public Translator newInstance( .optTokenizerPath(modelPath) .optManager(model.getNDManager()) .build(); + if (ArgumentsUtil.booleanValue(arguments, "reranking")) { + CrossEncoderTranslator translator = + CrossEncoderTranslator.builder(tokenizer, arguments).build(); + if (input == StringPair.class && output == float[].class) { + return (Translator) translator; + } else if (input == Input.class && output == Output.class) { + return (Translator) new CrossEncoderServingTranslator(translator); + } + throw new IllegalArgumentException("Unsupported input/output types."); + } + TextEmbeddingTranslator translator = TextEmbeddingTranslator.builder(tokenizer, arguments).build(); if (input == String.class && output == float[].class) { diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java index 2a98f63db65..33cbd9bd560 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java @@ -14,7 +14,7 @@ import ai.djl.Model; import ai.djl.ModelException; -import ai.djl.huggingface.translator.CrossEncoderTranslatorFactory; +import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory; import ai.djl.inference.Predictor; import ai.djl.modality.Input; import ai.djl.modality.Output; @@ -65,8 +65,9 @@ public void testCrossEncoderTranslator() .optEngine("PyTorch") .optArgument("tokenizer", "bert-base-cased") .optArgument("tokenizerPath", modelDir) + .optArgument("reranking", true) .optOption("hasParameter", "false") - .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .optTranslatorFactory(new TextEmbeddingTranslatorFactory()) .build(); try (ZooModel model = criteria.loadModel(); @@ -83,8 +84,9 @@ public void testCrossEncoderTranslator() .optBlock(block) .optEngine("PyTorch") .optArgument("tokenizer", "bert-base-cased") + .optArgument("reranking", true) .optOption("hasParameter", "false") - .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .optTranslatorFactory(new TextEmbeddingTranslatorFactory()) .build(); try (ZooModel model = criteria2.loadModel(); @@ -131,7 +133,7 @@ public void testCrossEncoderTranslator() options.put("hasParameter", "false"); model.load(modelDir, "test", options); - CrossEncoderTranslatorFactory factory = new CrossEncoderTranslatorFactory(); + TextEmbeddingTranslatorFactory factory = new TextEmbeddingTranslatorFactory(); Map arguments = new HashMap<>(); Assert.assertThrows( @@ -139,6 +141,7 @@ public void testCrossEncoderTranslator() () -> factory.newInstance(String.class, Integer.class, model, arguments)); arguments.put("tokenizer", "bert-base-cased"); + arguments.put("reranking", "true"); Assert.assertThrows( IllegalArgumentException.class,