Skip to content

Commit

Permalink
[examples] Fixes TextGeneration EOS bug
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed May 11, 2024
1 parent c7ba16b commit 1cbe519
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/continuous.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
uses: actions/upload-artifact@v3
if: always()
with:
name: reports
name: reports-${{ matrix.operating-system }}
path: |
${{ github.workspace }}/**/build/reports/**/*
!${{ github.workspace }}/**/build/reports/jacoco/*
Expand Down
27 changes: 18 additions & 9 deletions api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,12 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {

// Termination Criteria
long[] outputIdsArray = searchState.getNextInputIds().toLongArray();
for (int i = 0; i < outputIdsArray.length; i++) {
if (outputIdsArray[i] == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1;
for (int i = 0; i < endPosition.length; ++i) {
for (long tokenId : outputIdsArray) {
if (tokenId == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1;
break;
}
}
}
if (searchState.getPastOutputIds().getShape().get(1) + 1 >= config.getMaxSeqLength()) {
Expand Down Expand Up @@ -240,9 +243,12 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException {

// Termination Criteria
long[] outputIdsArray = searchState.getNextInputIds().toLongArray();
for (int i = 0; i < outputIdsArray.length; i++) {
if (outputIdsArray[i] == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1;
for (int i = 0; i < endPosition.length; ++i) {
for (long tokenId : outputIdsArray) {
if (tokenId == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1;
break;
}
}
}
if (searchState.getPastOutputIds().getShape().getLastDimension() + 1
Expand Down Expand Up @@ -366,9 +372,12 @@ public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException {

// Termination Criteria
long[] outputIdsArray = searchState.getPastOutputIds().toLongArray();
for (int i = 0; i < outputIdsArray.length; i++) {
if (outputIdsArray[i] == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1);
for (int i = 0; i < endPosition.length; ++i) {
for (long tokenId : outputIdsArray) {
if (tokenId == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1);
break;
}
}
}
if (searchState.getPastOutputIds().getShape().get(1) >= config.getMaxSeqLength()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,6 @@ public void testTextGeneration() throws TranslateException, ModelException, IOEx
TestRequirements.linux();
TestRequirements.weekly();

// Beam with Ort
String[] output0 = TextGeneration.generateTextWithOnnxRuntimeBeam();
Assert.assertEquals(
output0[0],
"DeepMind Company is a global leader in the field of artificial intelligence and"
+ " artificial intelligence research and development.\n"
+ "\n"
+ "Our mission is to provide the world with the best and brightest minds in the"
+ " field of artificial intelligence and artificial intelligence research and"
+ " development.\n"
+ "\n"
+ "Our mission is to provide the world with the best");

// Greedy
String expected =
"DeepMind Company is a global leader in the field of artificial"
Expand Down Expand Up @@ -119,4 +106,25 @@ public void testSeqBatchScheduler() throws TranslateException, ModelException, I
+ "But if you're lucky, you can escape from prison and live happily ever"
+ " after.\n");
}

/**
 * Checks the first sequence produced by the ONNX Runtime beam-search example against a fixed
 * reference text.
 *
 * <p>NOTE(review): the three {@code TestRequirements} calls presumably skip the test unless it
 * runs on Linux, in the weekly run, with the "PyTorch" engine available -- confirm against
 * {@code TestRequirements}.
 *
 * @throws TranslateException declared by the test; presumably raised when generation fails
 * @throws ModelException declared by the test; presumably raised when the model cannot be loaded
 * @throws IOException declared by the test; presumably raised on I/O failure while loading
 */
@Test
public void testTextGenerationWithOnnx()
        throws TranslateException, ModelException, IOException {
    TestRequirements.linux();
    TestRequirements.weekly();
    TestRequirements.engine("PyTorch");

    // Reference text for the first beam-search result.
    String expectedText =
            "DeepMind Company is a global leader in the field of artificial intelligence and"
                    + " artificial intelligence research and development.\n"
                    + "\n"
                    + "Our mission is to provide the world with the best and brightest minds in the"
                    + " field of artificial intelligence and artificial intelligence research and"
                    + " development.\n"
                    + "\n"
                    + "Our mission is to provide the world with the best";

    // Beam with Ort
    String[] generated = TextGeneration.generateTextWithOnnxRuntimeBeam();
    Assert.assertEquals(generated[0], expectedText);
}
}

0 comments on commit 1cbe519

Please sign in to comment.