Skip to content

Commit

Permalink
[text-embedding] Remove CrossEncoderTranslatorFactory in favor of TextEmbeddingTranslatorFactory (#3239)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Jun 5, 2024
1 parent f2a1c60 commit bd7f4ca
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 81 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.nlp.translator.CrossEncoderServingTranslator;
import ai.djl.modality.nlp.translator.TextEmbeddingServingTranslator;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.StringPair;

import java.io.IOException;
import java.io.Serializable;
Expand All @@ -39,6 +42,7 @@ public class TextEmbeddingTranslatorFactory implements TranslatorFactory, Serial

// Registers the (input class, output class) combinations this factory advertises.
static {
// Plain text in, float[] out.
SUPPORTED_TYPES.add(new Pair<>(String.class, float[].class));
// Text pair in, float[] out; newInstance returns this combination only when the
// "reranking" argument is true.
SUPPORTED_TYPES.add(new Pair<>(StringPair.class, float[].class));
// Generic Input/Output combination; newInstance wraps the translator in a serving
// translator for this case.
SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
}

Expand All @@ -61,6 +65,17 @@ public <I, O> Translator<I, O> newInstance(
.optTokenizerPath(modelPath)
.optManager(model.getNDManager())
.build();
if (ArgumentsUtil.booleanValue(arguments, "reranking")) {
CrossEncoderTranslator translator =
CrossEncoderTranslator.builder(tokenizer, arguments).build();
if (input == StringPair.class && output == float[].class) {
return (Translator<I, O>) translator;
} else if (input == Input.class && output == Output.class) {
return (Translator<I, O>) new CrossEncoderServingTranslator(translator);
}
throw new IllegalArgumentException("Unsupported input/output types.");
}

TextEmbeddingTranslator translator =
TextEmbeddingTranslator.builder(tokenizer, arguments).build();
if (input == String.class && output == float[].class) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.huggingface.translator.CrossEncoderTranslatorFactory;
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
Expand Down Expand Up @@ -65,8 +65,9 @@ public void testCrossEncoderTranslator()
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-cased")
.optArgument("tokenizerPath", modelDir)
.optArgument("reranking", true)
.optOption("hasParameter", "false")
.optTranslatorFactory(new CrossEncoderTranslatorFactory())
.optTranslatorFactory(new TextEmbeddingTranslatorFactory())
.build();

try (ZooModel<StringPair, float[]> model = criteria.loadModel();
Expand All @@ -83,8 +84,9 @@ public void testCrossEncoderTranslator()
.optBlock(block)
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-cased")
.optArgument("reranking", true)
.optOption("hasParameter", "false")
.optTranslatorFactory(new CrossEncoderTranslatorFactory())
.optTranslatorFactory(new TextEmbeddingTranslatorFactory())
.build();

try (ZooModel<Input, Output> model = criteria2.loadModel();
Expand Down Expand Up @@ -131,14 +133,15 @@ public void testCrossEncoderTranslator()
options.put("hasParameter", "false");
model.load(modelDir, "test", options);

CrossEncoderTranslatorFactory factory = new CrossEncoderTranslatorFactory();
TextEmbeddingTranslatorFactory factory = new TextEmbeddingTranslatorFactory();
Map<String, String> arguments = new HashMap<>();

Assert.assertThrows(
TranslateException.class,
() -> factory.newInstance(String.class, Integer.class, model, arguments));

arguments.put("tokenizer", "bert-base-cased");
arguments.put("reranking", "true");

Assert.assertThrows(
IllegalArgumentException.class,
Expand Down

0 comments on commit bd7f4ca

Please sign in to comment.