[huggingface] allow creating BPE huggingface tokenizers #2550

Merged
30 changes: 29 additions & 1 deletion extensions/tokenizers/rust/src/lib.rs
@@ -20,6 +20,7 @@ use tk::utils::padding::{PaddingParams, PaddingStrategy};
use tk::utils::truncation::{TruncationParams, TruncationStrategy};
use tk::Tokenizer;
use tk::{FromPretrainedParameters, Offsets};
use tk::models::bpe::BPE;

use jni::objects::{JClass, JMethodID, JObject, JString, JValue, ReleaseMode};
use jni::sys::{jboolean, jint, jlong, jlongArray, jobjectArray, jsize, jstring, JNI_TRUE};
@@ -69,6 +70,33 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
}
}

// Create a tokenizer backed by a BPE model loaded from vocabulary and merges files
#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_createBpeTokenizer(
env: JNIEnv,
_: JObject,
vocabulary: JString,
merges: JString,
) -> jlong {
let vocabulary: String = env
.get_string(vocabulary)
.expect("Couldn't get java string!")
.into();

let merges: String = env
.get_string(merges)
.expect("Couldn't get java string!")
.into();

match BPE::from_file(&vocabulary, &merges).build() {
Ok(model) => to_handle(Tokenizer::new(model)),
Err(err) => {
env.throw(err.to_string()).unwrap();
0
}
}
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_deleteTokenizer(
_env: JNIEnv,
@@ -702,6 +730,6 @@ fn cast_handle<T>(handle: jlong) -> &'static mut T {

fn drop_handle<T: 'static>(handle: jlong) {
unsafe {
Box::from_raw(handle as *mut T);
let _ = Box::from_raw(handle as *mut T);
}
}
HuggingFaceTokenizer.java
@@ -136,6 +136,26 @@ public static HuggingFaceTokenizer newInstance(Path modelPath, Map<String, String> options)
}
}

/**
* Create a pre-trained BPE {@code HuggingFaceTokenizer} instance from existing vocabulary and merges files.
*
* @param vocab the BPE vocabulary file
* @param merges the BPE merges file
* @param options tokenizer options
* @return a {@code HuggingFaceTokenizer} instance
* @throws IOException when IO operation fails in loading a resource
*/
public static HuggingFaceTokenizer newInstance(
Path vocab, Path merges, Map<String, String> options) throws IOException {
Ec2Utils.callHome("Huggingface");
LibUtils.checkStatus();

String vocabFile = vocab.toAbsolutePath().toString();
String mergesFile = merges.toAbsolutePath().toString();
long handle = TokenizersLibrary.LIB.createBpeTokenizer(vocabFile, mergesFile);
return new HuggingFaceTokenizer(handle, options);
}
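
For reference, a minimal usage sketch of the new overload (not part of the diff; the file locations, sample text, and empty options map are illustrative):

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;

public class BpeNewInstanceExample {

    public static void main(String[] args) throws IOException {
        // Illustrative locations of a BPE vocabulary and merges file
        Path vocab = Paths.get("build/BPE/vocab.json");
        Path merges = Paths.get("build/BPE/merges.txt");
        try (HuggingFaceTokenizer tokenizer =
                HuggingFaceTokenizer.newInstance(vocab, merges, Collections.emptyMap())) {
            Encoding encoding = tokenizer.encode("Hello DJL");
            System.out.println(String.join(" ", encoding.getTokens()));
        }
    }
}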

/**
* Create a pre-trained {@code HuggingFaceTokenizer} instance from {@code InputStream}.
*
@@ -756,6 +776,20 @@ public HuggingFaceTokenizer build() throws IOException {
if (tokenizerPath == null) {
throw new IllegalArgumentException("Missing tokenizer path.");
}
if (Files.isDirectory(tokenizerPath)) {
Path tokenizerFile = tokenizerPath.resolve("tokenizer.json");
if (Files.exists(tokenizerFile)) {
return managed(HuggingFaceTokenizer.newInstance(tokenizerPath, options));
}
Path vocab = tokenizerPath.resolve("vocab.json");
Path merges = tokenizerPath.resolve("merges.txt");
if (Files.exists(vocab) && Files.exists(merges)) {
return managed(HuggingFaceTokenizer.newInstance(vocab, merges, options));
}
throw new IOException("tokenizer.json file not found.");
} else if (!Files.exists(tokenizerPath)) {
    throw new IOException("Tokenizer file does not exist: " + tokenizerPath);
}
return managed(HuggingFaceTokenizer.newInstance(tokenizerPath, options));
}
}
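
With this change, a directory passed to optTokenizerPath may contain either tokenizer.json or a BPE vocab.json plus merges.txt pair. A minimal builder sketch under that assumption (the directory path is illustrative):

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

import java.io.IOException;
import java.nio.file.Paths;

public class BpeBuilderExample {

    public static void main(String[] args) throws IOException {
        // build/BPE is illustrative; it must hold tokenizer.json, or vocab.json + merges.txt
        try (HuggingFaceTokenizer tokenizer =
                HuggingFaceTokenizer.builder()
                        .optTokenizerPath(Paths.get("build/BPE"))
                        .optAddSpecialTokens(false)
                        .build()) {
            System.out.println(tokenizer.encode("test").getTokens().length);
        }
    }
}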
TokenizersLibrary.java
@@ -24,6 +24,8 @@ private TokenizersLibrary() {}

public native long createTokenizerFromString(String json);

public native long createBpeTokenizer(String vocabulary, String merges);

public native void deleteTokenizer(long handle);

public native long encode(long tokenizer, String text, boolean addSpecialTokens);
BpeTokenizerBuilderTest.java (new file)
@@ -0,0 +1,47 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.tokenizers;

import ai.djl.training.util.DownloadUtils;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;

public class BpeTokenizerBuilderTest {

@Test
public void testTokenizerCreation() throws IOException {
Path bpe = Paths.get("build/BPE");
Path vocab = bpe.resolve("vocab.json");
Path merges = bpe.resolve("merges.txt");

DownloadUtils.download(
new URL("https://huggingface.co/microsoft/layoutlmv3-base/raw/main/vocab.json"),
vocab,
null);
DownloadUtils.download(
new URL("https://huggingface.co/microsoft/layoutlmv3-base/raw/main/merges.txt"),
merges,
null);

try (HuggingFaceTokenizer tokenizer =
HuggingFaceTokenizer.builder().optTokenizerPath(bpe).build()) {
Assert.assertNotNull(tokenizer);
}
}
}