Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[examples] Fixes TextGeneration EOS bug #3177

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/continuous.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
uses: actions/upload-artifact@v3
if: always()
with:
name: reports
name: reports-${{ matrix.operating-system }}
path: |
${{ github.workspace }}/**/build/reports/**/*
!${{ github.workspace }}/**/build/reports/jacoco/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,12 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {

// Termination Criteria
long[] outputIdsArray = searchState.getNextInputIds().toLongArray();
for (int i = 0; i < outputIdsArray.length; i++) {
if (outputIdsArray[i] == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1;
for (int i = 0; i < endPosition.length; ++i) {
for (long tokenId : outputIdsArray) {
if (tokenId == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1;
break;
}
}
}
if (searchState.getPastOutputIds().getShape().get(1) + 1 >= config.getMaxSeqLength()) {
Expand Down Expand Up @@ -240,9 +243,12 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException {

// Termination Criteria
long[] outputIdsArray = searchState.getNextInputIds().toLongArray();
for (int i = 0; i < outputIdsArray.length; i++) {
if (outputIdsArray[i] == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1;
for (int i = 0; i < endPosition.length; ++i) {
for (long tokenId : outputIdsArray) {
if (tokenId == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1;
break;
}
}
}
if (searchState.getPastOutputIds().getShape().getLastDimension() + 1
Expand Down Expand Up @@ -366,9 +372,12 @@ public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException {

// Termination Criteria
long[] outputIdsArray = searchState.getPastOutputIds().toLongArray();
for (int i = 0; i < outputIdsArray.length; i++) {
if (outputIdsArray[i] == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1);
for (int i = 0; i < endPosition.length; ++i) {
for (long tokenId : outputIdsArray) {
if (tokenId == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1);
break;
}
}
}
if (searchState.getPastOutputIds().getShape().get(1) >= config.getMaxSeqLength()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,6 @@ public void testTextGeneration() throws TranslateException, ModelException, IOEx
TestRequirements.linux();
TestRequirements.weekly();

// Beam with Ort
String[] output0 = TextGeneration.generateTextWithOnnxRuntimeBeam();
Assert.assertEquals(
output0[0],
"DeepMind Company is a global leader in the field of artificial intelligence and"
+ " artificial intelligence research and development.\n"
+ "\n"
+ "Our mission is to provide the world with the best and brightest minds in the"
+ " field of artificial intelligence and artificial intelligence research and"
+ " development.\n"
+ "\n"
+ "Our mission is to provide the world with the best");

// Greedy
String expected =
"DeepMind Company is a global leader in the field of artificial"
Expand Down Expand Up @@ -82,11 +69,10 @@ public void testTextGeneration() throws TranslateException, ModelException, IOEx
+ " development.\n"
+ "\n"
+ "Our mission is to");
Assert.assertEquals(
output2[3],
"Memories follow me left and right. I can't tell you how many times I've been told"
+ " that I'm not a good person. I'm not a good person. I'm not a good person."
+ " I'm not a good person. I'm not a good person. I'm not a");
Assert.assertTrue(
output2[3].startsWith(
"Memories follow me left and right. I can't tell you how many times I've"
+ " been told that I'm not a good person."));
}

@Test
Expand Down Expand Up @@ -119,4 +105,25 @@ public void testSeqBatchScheduler() throws TranslateException, ModelException, I
+ "But if you're lucky, you can escape from prison and live happily ever"
+ " after.\n");
}

@Test
public void testTextGenerationWithOnnx()
throws TranslateException, ModelException, IOException {
TestRequirements.linux();
TestRequirements.weekly();
TestRequirements.engine("PyTorch");

// Beam with Ort
String[] output0 = TextGeneration.generateTextWithOnnxRuntimeBeam();
Assert.assertEquals(
output0[0],
"DeepMind Company is a global leader in the field of artificial intelligence and"
+ " artificial intelligence research and development.\n"
+ "\n"
+ "Our mission is to provide the world with the best and brightest minds in the"
+ " field of artificial intelligence and artificial intelligence research and"
+ " development.\n"
+ "\n"
+ "Our mission is to provide the world with the best");
}
}
Loading