From 6be4fd6012600f8bc5d64d2e4c7cc72c73bb7c70 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 15 Apr 2024 09:32:34 -0700 Subject: [PATCH] [tokenizer] Adds option to disable sigmoid for CrossEncoderTranslator (#3087) --- .../translator/CrossEncoderTranslator.java | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java index b88347bc60e..c3f4db0cc17 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java @@ -30,12 +30,17 @@ public class CrossEncoderTranslator implements Translator { private HuggingFaceTokenizer tokenizer; private boolean includeTokenTypes; + private boolean sigmoid; private Batchifier batchifier; CrossEncoderTranslator( - HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { + HuggingFaceTokenizer tokenizer, + boolean includeTokenTypes, + boolean sigmoid, + Batchifier batchifier) { this.tokenizer = tokenizer; this.includeTokenTypes = includeTokenTypes; + this.sigmoid = sigmoid; this.batchifier = batchifier; } @@ -57,8 +62,10 @@ public NDList processInput(TranslatorContext ctx, StringPair input) { @Override public float[] processOutput(TranslatorContext ctx, NDList list) { NDArray logits = list.get(0); - NDArray result = logits.getNDArrayInternal().sigmoid(); - return result.toFloatArray(); + if (sigmoid) { + logits = logits.getNDArrayInternal().sigmoid(); + } + return logits.toFloatArray(); } /** {@inheritDoc} */ @@ -97,6 +104,7 @@ public static final class Builder { private HuggingFaceTokenizer tokenizer; private boolean includeTokenTypes; + private boolean sigmoid; private Batchifier batchifier = Batchifier.STACK; Builder(HuggingFaceTokenizer tokenizer) { @@ -114,6 +122,17 @@ public Builder optIncludeTokenTypes(boolean includeTokenTypes) { return this; } + /** + * Sets if apply sigmoid for the {@link Translator}. + * + * @param sigmoid true to apply sigmoid + * @return this builder + */ + public Builder optSigmoid(boolean sigmoid) { + this.sigmoid = sigmoid; + return this; + } + /** * Sets the {@link Batchifier} for the {@link Translator}. * @@ -132,6 +151,7 @@ public Builder optBatchifier(Batchifier batchifier) { */ public void configure(Map arguments) { optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); + optSigmoid(ArgumentsUtil.booleanValue(arguments, "sigmoid", true)); String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); optBatchifier(Batchifier.fromString(batchifierStr)); } @@ -143,7 +163,7 @@ public void configure(Map arguments) { * @throws IOException if I/O error occurs */ public CrossEncoderTranslator build() throws IOException { - return new CrossEncoderTranslator(tokenizer, includeTokenTypes, batchifier); + return new CrossEncoderTranslator(tokenizer, includeTokenTypes, sigmoid, batchifier); } } }