Skip to content

Commit

Permalink
[tokenizer] Adds option to disable sigmoid for CrossEncoderTranslator (
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Apr 15, 2024
1 parent a920d21 commit 6be4fd6
Showing 1 changed file with 24 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,17 @@ public class CrossEncoderTranslator implements Translator<StringPair, float[]> {

// Tokenizer supplied through the constructor.
private HuggingFaceTokenizer tokenizer;
// Option supplied through the constructor; the builder exposes it as optIncludeTokenTypes.
private boolean includeTokenTypes;
// When true, processOutput applies a sigmoid to the model output before returning it.
private boolean sigmoid;
// Batchifier supplied through the constructor.
private Batchifier batchifier;

/**
 * Constructs a translator from the options collected by the {@code Builder}.
 *
 * @param tokenizer the tokenizer stored on this translator
 * @param includeTokenTypes the include-token-types option stored on this translator
 * @param sigmoid true to apply sigmoid to the output in {@code processOutput}
 * @param batchifier the batchifier stored on this translator
 */
CrossEncoderTranslator(
        HuggingFaceTokenizer tokenizer,
        boolean includeTokenTypes,
        boolean sigmoid,
        Batchifier batchifier) {
    this.tokenizer = tokenizer;
    this.includeTokenTypes = includeTokenTypes;
    this.sigmoid = sigmoid;
    this.batchifier = batchifier;
}

Expand All @@ -57,8 +62,10 @@ public NDList processInput(TranslatorContext ctx, StringPair input) {
@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
    // The first array of the output list is treated as the logits.
    NDArray logits = list.get(0);
    if (sigmoid) {
        // Sigmoid is optional: with the flag off, the unmodified logits are returned.
        logits = logits.getNDArrayInternal().sigmoid();
    }
    return logits.toFloatArray();
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -97,6 +104,7 @@ public static final class Builder {

// Passed to the CrossEncoderTranslator constructor by build().
private HuggingFaceTokenizer tokenizer;
// False unless changed through optIncludeTokenTypes (configure also sets it).
private boolean includeTokenTypes;
// Defaults to true so a builder used without configure() agrees with configure()'s
// default for the "sigmoid" argument; callers opt out with optSigmoid(false).
private boolean sigmoid = true;
// Stacking is the default batchifier; override with optBatchifier.
private Batchifier batchifier = Batchifier.STACK;

Builder(HuggingFaceTokenizer tokenizer) {
Expand All @@ -114,6 +122,17 @@ public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
return this;
}

/**
 * Sets whether the {@link Translator} applies sigmoid to the model output.
 *
 * @param sigmoid true to apply sigmoid, false to return the output unchanged
 * @return this builder
 */
public Builder optSigmoid(boolean sigmoid) {
    this.sigmoid = sigmoid;
    return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
Expand All @@ -132,6 +151,7 @@ public Builder optBatchifier(Batchifier batchifier) {
*/
public void configure(Map<String, ?> arguments) {
    // Each option is read and applied in turn, in the same order as the builder setters.
    boolean tokenTypes = ArgumentsUtil.booleanValue(arguments, "includeTokenTypes");
    optIncludeTokenTypes(tokenTypes);

    // Sigmoid stays enabled when the "sigmoid" argument is absent.
    boolean applySigmoid = ArgumentsUtil.booleanValue(arguments, "sigmoid", true);
    optSigmoid(applySigmoid);

    // "stack" is the fallback batchifier name.
    String batchifierName = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
    Batchifier parsed = Batchifier.fromString(batchifierName);
    optBatchifier(parsed);
}
Expand All @@ -143,7 +163,7 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public CrossEncoderTranslator build() throws IOException {
    // Forward every collected option, including the sigmoid flag, to the translator.
    return new CrossEncoderTranslator(tokenizer, includeTokenTypes, sigmoid, batchifier);
}
}
}

0 comments on commit 6be4fd6

Please sign in to comment.