Skip to content

Commit

Permalink
[tokenizer] Add do_lower_case support (#3069)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Apr 26, 2024
1 parent 398cda6 commit dd2fa85
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand All @@ -45,6 +46,7 @@ public final class HuggingFaceTokenizer extends NativeResource<Long> implements

private boolean addSpecialTokens;
private boolean withOverflowingTokens;
private Locale doLowerCase;
private TruncationStrategy truncation;
private PaddingStrategy padding;
private int maxLength;
Expand Down Expand Up @@ -77,6 +79,12 @@ private HuggingFaceTokenizer(long handle, Map<String, String> options) {
maxLength = ArgumentsUtil.intValue(options, "maxLength", maxLength);
stride = ArgumentsUtil.intValue(options, "stride", stride);
padToMultipleOf = ArgumentsUtil.intValue(options, "padToMultipleOf", padToMultipleOf);
String lowerCase = options.getOrDefault("doLowerCase", "false");
if ("true".equals(lowerCase)) {
this.doLowerCase = Locale.getDefault();
} else if (!"false".equals(lowerCase)) {
this.doLowerCase = Locale.forLanguageTag(lowerCase);
}
} else {
addSpecialTokens = true;
modelMaxLength = 512;
Expand Down Expand Up @@ -210,6 +218,9 @@ public void close() {
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, boolean addSpecialTokens, boolean withOverflowingTokens) {
if (doLowerCase != null) {
text = text.toLowerCase(doLowerCase);
}
long encoding = TokenizersLibrary.LIB.encode(getHandle(), text, addSpecialTokens);
return toEncoding(encoding, withOverflowingTokens);
}
Expand All @@ -236,6 +247,11 @@ public Encoding encode(String text) {
*/
public Encoding encode(
String text, String textPair, boolean addSpecialTokens, boolean withOverflowingTokens) {
if (doLowerCase != null) {
text = text.toLowerCase(doLowerCase);
textPair = textPair.toLowerCase(doLowerCase);
}

long encoding =
TokenizersLibrary.LIB.encodeDual(getHandle(), text, textPair, addSpecialTokens);
return toEncoding(encoding, withOverflowingTokens);
Expand Down Expand Up @@ -288,6 +304,11 @@ public Encoding encode(List<String> inputs) {
*/
public Encoding encode(
String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
if (doLowerCase != null) {
for (int i = 0; i < inputs.length; ++i) {
inputs[i] = inputs[i].toLowerCase(doLowerCase);
}
}
long encoding = TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens);
return toEncoding(encoding, withOverflowingTokens);
}
Expand Down Expand Up @@ -338,6 +359,11 @@ public Encoding[] batchEncode(List<String> inputs) {
*/
public Encoding[] batchEncode(
String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
if (doLowerCase != null) {
for (int i = 0; i < inputs.length; ++i) {
inputs[i] = inputs[i].toLowerCase(doLowerCase);
}
}
long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens);
Encoding[] ret = new Encoding[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
Expand Down Expand Up @@ -371,6 +397,14 @@ public Encoding[] batchEncode(
boolean withOverflowingTokens) {
String[] text = inputs.keyArray(Utils.EMPTY_ARRAY);
String[] textPair = inputs.valueArray(Utils.EMPTY_ARRAY);
if (doLowerCase != null) {
for (int i = 0; i < text.length; ++i) {
text[i] = text[i].toLowerCase(doLowerCase);
}
for (int i = 0; i < textPair.length; ++i) {
textPair[i] = textPair[i].toLowerCase(doLowerCase);
}
}
long[] encodings =
TokenizersLibrary.LIB.batchEncodePair(
getHandle(), text, textPair, addSpecialTokens);
Expand Down Expand Up @@ -821,6 +855,28 @@ public Builder optStride(int stride) {
return this;
}

/**
* Sets the doLowerCase for the tokenizer.
*
* @param doLowerCase {@code true} to enable convert to lowercase
* @return this builder
*/
public Builder optDoLowerCase(boolean doLowerCase) {
options.put("doLowerCase", String.valueOf(doLowerCase));
return this;
}

/**
* Sets the doLowerCase for the tokenizer with specific locale.
*
* @param locale the locale to use when converting to lowercase
* @return this builder
*/
public Builder optDoLowerCase(String locale) {
options.put("doLowerCase", locale);
return this;
}

/**
* Configures the builder with the arguments.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand Down Expand Up @@ -129,6 +130,37 @@ public void testTokenizer() throws IOException {
});
}

@Test
public void testDoLowerCase() throws IOException {
String input = "Hello, y'all! How are you 😁 ?";
String[] inputs = {"Hello, y'all!", "How are you 😁 ?"};
try (HuggingFaceTokenizer tokenizer =
HuggingFaceTokenizer.builder()
.optTokenizerName("bert-base-cased")
.optAddSpecialTokens(false)
.optDoLowerCase(true)
.build()) {
Encoding encoding = tokenizer.encode(inputs);
String sentence = tokenizer.buildSentence(Arrays.asList(encoding.getTokens()));
Assert.assertEquals(sentence, "hello , y ' all ! how are you [UNK] ?");

encoding = tokenizer.encode(input);
Assert.assertEquals(encoding.getTokens().length, 11);

encoding = tokenizer.encode(input, "How are you my friend");
Assert.assertEquals(encoding.getTokens().length, 16);

Encoding[] encodings = tokenizer.batchEncode(inputs);
Assert.assertEquals(encodings.length, 2);

PairList<String, String> batch = new PairList<>(2);
batch.add("Hello", "How are you");
batch.add("Hi, you all", "I'm fine.");
encodings = tokenizer.batchEncode(batch);
Assert.assertEquals(encodings.length, 2);
}
}

@Test
public void testTokenizerDecoding() throws IOException {
long[][] testIds = {
Expand Down Expand Up @@ -383,6 +415,7 @@ public void testTruncationAndPaddingForPairInputs() throws IOException {
.optTokenizerName("bert-base-cased")
.optTruncateSecondOnly()
.optMaxLength(8)
.optDoLowerCase(Locale.ROOT.toLanguageTag())
.build()) {
Encoding encoding = tokenizer.encode(text, textPair);
Assert.assertEquals(encoding.getIds().length, 8);
Expand Down

0 comments on commit dd2fa85

Please sign in to comment.