Skip to content

Commit

Permalink
[huggingface] allow creating BPE huggingface tokenizers
Browse files Browse the repository at this point in the history
  • Loading branch information
larochef committed Apr 19, 2023
1 parent ff94632 commit c783242
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 0 deletions.
28 changes: 28 additions & 0 deletions extensions/tokenizers/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use tk::utils::padding::{PaddingParams, PaddingStrategy};
use tk::utils::truncation::{TruncationParams, TruncationStrategy};
use tk::Tokenizer;
use tk::{FromPretrainedParameters, Offsets};
use tk::models::bpe::BPE;

use jni::objects::{JClass, JMethodID, JObject, JString, JValue, ReleaseMode};
use jni::sys::{jboolean, jint, jlong, jlongArray, jobjectArray, jsize, jstring, JNI_TRUE};
Expand Down Expand Up @@ -693,6 +694,33 @@ fn to_handle<T: 'static>(val: T) -> jlong {
handle
}

/// JNI entry point backing `TokenizersLibrary.createBPETokenizer`: creates a
/// tokenizer that wraps a BPE model loaded from files on disk.
///
/// * `vocabulary` - Java string holding the path of the vocabulary file.
/// * `merges` - Java string holding the path of the merges file.
///
/// Returns the handle produced by `to_handle` for the new tokenizer. When the
/// BPE model cannot be built, a Java exception carrying the error message is
/// thrown and `0` is returned.
///
/// # Panics
///
/// Panics if either Java string cannot be read, or if throwing the Java
/// exception itself fails.
#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_createBPETokenizer(
    env: JNIEnv,
    _: JObject,
    vocabulary: JString,
    merges: JString,
) -> jlong {
    let vocabulary: String = env
        .get_string(vocabulary)
        .expect("Couldn't get java string!")
        .into();

    let merges: String = env
        .get_string(merges)
        .expect("Couldn't get java string!")
        .into();

    // `BPE::from_file` expects the vocabulary path first and the merges path
    // second; passing them the other way round loads each file in the wrong role.
    match BPE::from_file(&vocabulary, &merges).build() {
        Ok(model) => to_handle(Tokenizer::new(model)),
        Err(err) => {
            env.throw(err.to_string()).unwrap();
            // 0 is the "no tokenizer" value returned alongside the pending exception.
            0
        }
    }
}

fn cast_handle<T>(handle: jlong) -> &'static mut T {
assert_ne!(handle, 0, "Invalid handle value");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.tokenizers;

import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary;

import java.nio.file.Path;

/**
 * Creates a native tokenizer from a BPE vocabulary file and a merges file.
 *
 * <p>An instance is meant to be handed to a HuggingFace tokenizer builder, which wraps the
 * resulting BPE-based native tokenizer.
 */
public final class BPETokenizerBuilder implements NativeTokenizerBuilder {

    final Path vocabulary;
    final Path merges;

    private BPETokenizerBuilder(Path vocabulary, Path merges) {
        this.vocabulary = vocabulary;
        this.merges = merges;
    }

    /**
     * Creates a new builder.
     *
     * <p>This factory is public because the constructor is private: without it, code outside
     * this package would have no way to obtain an instance.
     *
     * @param vocabulary the vocab.json file
     * @param merges the merges.txt file
     * @return the builder, ready to be built
     */
    public static BPETokenizerBuilder newBuilder(Path vocabulary, Path merges) {
        return new BPETokenizerBuilder(vocabulary, merges);
    }

    /**
     * Builds the native tokenizer.
     *
     * <p>Both paths are converted to absolute path strings before being passed to the native
     * library.
     *
     * @return the handle to the native tokenizer
     */
    @Override
    public long build() {
        return TokenizersLibrary.LIB.createBPETokenizer(
                vocabulary.toAbsolutePath().toString(), merges.toAbsolutePath().toString());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,8 @@ public static final class Builder {
private NDManager manager;
private Map<String, String> options;

private NativeTokenizerBuilder nativeTokenizerBuilder;

Builder() {
options = new ConcurrentHashMap<>();
options.put("addSpecialTokens", "true");
Expand Down Expand Up @@ -729,6 +731,18 @@ public void configure(Map<String, ?> arguments) {
}
}

/**
* Configures this builder to create the tokenizer from an alternate native tokenizer builder.
*
* <p>If a tokenizerName was set, this parameter will be ignored.
*
* @param nativeTokenizerBuilder the builder to use to create the native tokenizer
* @return this builder, to allow call chaining
*/
public Builder optNativeTokenizerBuilder(NativeTokenizerBuilder nativeTokenizerBuilder) {
this.nativeTokenizerBuilder = nativeTokenizerBuilder;
return this;
}

/**
* Utility to make a tokenizer managed by the builder manager (if one is specified).
*
Expand All @@ -753,6 +767,10 @@ public HuggingFaceTokenizer build() throws IOException {
if (tokenizerName != null) {
return managed(HuggingFaceTokenizer.newInstance(tokenizerName, options));
}
if(nativeTokenizerBuilder != null) {
long tokenizerHandle = nativeTokenizerBuilder.build();
return managed(new HuggingFaceTokenizer(tokenizerHandle, options));
}
if (tokenizerPath == null) {
throw new IllegalArgumentException("Missing tokenizer path.");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.tokenizers;

/**
* This interface is meant to build tokenizers using different models than the default one.
* Implementations can be plugged into the HuggingFace tokenizer builder.
*/
public interface NativeTokenizerBuilder {
/**
* Builds the native tokenizer.
*
* @return the handle to the native tokenizer to wrap
*/
long build();
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,6 @@ public native void setPadding(

public native void setTruncation(
long tokenizer, int maxLength, String truncationStrategy, int stride);

/**
* Creates a native tokenizer backed by a BPE model.
*
* @param vocabulary the path of the vocabulary file
* @param merges the path of the merges file
* @return the handle to the native tokenizer
*/
public native long createBPETokenizer(String vocabulary, String merges);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.tokenizers;

import ai.djl.training.util.DownloadUtils;
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.file.Paths;

/** Checks that a tokenizer can be created through {@link BPETokenizerBuilder}. */
public class BPETokenizerBuilderTest {

    /**
     * Downloads a BPE vocabulary and merges file, builds a tokenizer from them and closes it.
     *
     * @throws IOException if a download or the tokenizer creation fails
     */
    @Test
    public void testTokenizerCreation() throws IOException {
        // Fetch the vocabulary first, then the merges, into the build directory.
        String vocabFile = "build/tokenizer/BPE/vocab.json";
        DownloadUtils.download(
                "https://huggingface.co/flaubert/flaubert_base_uncased/raw/main/vocab.json",
                vocabFile);

        String mergesFile = "build/tokenizer/BPE/merges.txt";
        DownloadUtils.download(
                "https://huggingface.co/flaubert/flaubert_base_uncased/raw/main/merges.txt",
                mergesFile);

        NativeTokenizerBuilder bpeBuilder =
                BPETokenizerBuilder.newBuilder(Paths.get(vocabFile), Paths.get(mergesFile));

        HuggingFaceTokenizer created =
                HuggingFaceTokenizer.builder().optNativeTokenizerBuilder(bpeBuilder).build();
        // TODO: assert on the created tokenizer; for now only creation and release are exercised
        created.close();
    }
}

0 comments on commit c783242

Please sign in to comment.