Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tokenizer] Adds option to disable sigmoid for CrossEncoderTranslator #3087

Merged
merged 1 commit into from
Apr 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,17 @@ public class CrossEncoderTranslator implements Translator<StringPair, float[]> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private boolean sigmoid;
private Batchifier batchifier;

CrossEncoderTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
HuggingFaceTokenizer tokenizer,
boolean includeTokenTypes,
boolean sigmoid,
Batchifier batchifier) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.sigmoid = sigmoid;
this.batchifier = batchifier;
}

Expand All @@ -57,8 +62,10 @@ public NDList processInput(TranslatorContext ctx, StringPair input) {
@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
NDArray logits = list.get(0);
NDArray result = logits.getNDArrayInternal().sigmoid();
return result.toFloatArray();
if (sigmoid) {
logits = logits.getNDArrayInternal().sigmoid();
}
return logits.toFloatArray();
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -97,6 +104,7 @@ public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private boolean sigmoid;
private Batchifier batchifier = Batchifier.STACK;

Builder(HuggingFaceTokenizer tokenizer) {
Expand All @@ -114,6 +122,17 @@ public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
return this;
}

/**
* Sets if apply sigmoid for the {@link Translator}.
*
* @param sigmoid true to apply sigmoid
* @return this builder
*/
public Builder optSigmoid(boolean sigmoid) {
this.sigmoid = sigmoid;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
Expand All @@ -132,6 +151,7 @@ public Builder optBatchifier(Batchifier batchifier) {
*/
public void configure(Map<String, ?> arguments) {
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
optSigmoid(ArgumentsUtil.booleanValue(arguments, "sigmoid", true));
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optBatchifier(Batchifier.fromString(batchifierStr));
}
Expand All @@ -143,7 +163,7 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public CrossEncoderTranslator build() throws IOException {
return new CrossEncoderTranslator(tokenizer, includeTokenTypes, batchifier);
return new CrossEncoderTranslator(tokenizer, includeTokenTypes, sigmoid, batchifier);
}
}
}
Loading