diff --git a/.github/workflows/continuous.yml b/.github/workflows/continuous.yml index 2abbcd5da7a..3b91b8168e9 100644 --- a/.github/workflows/continuous.yml +++ b/.github/workflows/continuous.yml @@ -93,7 +93,7 @@ jobs: uses: actions/upload-artifact@v3 if: always() with: - name: reports + name: reports-${{ matrix.operating-system }} path: | ${{ github.workspace }}/**/build/reports/**/* !${{ github.workspace }}/**/build/reports/jacoco/* diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java index 7590244e4a4..7664b4f7e73 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java @@ -118,9 +118,12 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException { // Termination Criteria long[] outputIdsArray = searchState.getNextInputIds().toLongArray(); - for (int i = 0; i < outputIdsArray.length; i++) { - if (outputIdsArray[i] == config.getEosTokenId()) { - endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1; + for (int i = 0; i < endPosition.length; ++i) { + for (long tokenId : outputIdsArray) { + if (tokenId == config.getEosTokenId()) { + endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1; + break; + } } } if (searchState.getPastOutputIds().getShape().get(1) + 1 >= config.getMaxSeqLength()) { @@ -240,9 +243,12 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException { // Termination Criteria long[] outputIdsArray = searchState.getNextInputIds().toLongArray(); - for (int i = 0; i < outputIdsArray.length; i++) { - if (outputIdsArray[i] == config.getEosTokenId()) { - endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1; + for (int i = 0; i < endPosition.length; ++i) { + for (long tokenId : outputIdsArray) { + if (tokenId == config.getEosTokenId()) { + endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1; + break; + } } } if (searchState.getPastOutputIds().getShape().getLastDimension() + 1 @@ -366,9 +372,12 @@ public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException { // Termination Criteria long[] outputIdsArray = searchState.getPastOutputIds().toLongArray(); - for (int i = 0; i < outputIdsArray.length; i++) { - if (outputIdsArray[i] == config.getEosTokenId()) { - endPosition[i] = searchState.getPastOutputIds().getShape().get(1); + for (int i = 0; i < endPosition.length; ++i) { + for (long tokenId : outputIdsArray) { + if (tokenId == config.getEosTokenId()) { + endPosition[i] = searchState.getPastOutputIds().getShape().get(1); + break; + } } } if (searchState.getPastOutputIds().getShape().get(1) >= config.getMaxSeqLength()) { diff --git a/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java index 29557174503..8ef4b7ade0c 100644 --- a/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java @@ -28,19 +28,6 @@ public void testTextGeneration() throws TranslateException, ModelException, IOEx TestRequirements.linux(); TestRequirements.weekly(); - // Beam with Ort - String[] output0 = TextGeneration.generateTextWithOnnxRuntimeBeam(); - Assert.assertEquals( - output0[0], - "DeepMind Company is a global leader in the field of artificial intelligence and" - + " artificial intelligence research and development.\n" - + "\n" - + "Our mission is to provide the world with the best and brightest minds in the" - + " field of artificial intelligence and artificial intelligence research and" - + " development.\n" - + "\n" - + "Our mission is to provide the world with the best"); - // Greedy String expected = "DeepMind Company is a global leader in the field of artificial" @@ -82,11 +69,10 @@ public void testTextGeneration() throws TranslateException, ModelException, IOEx + " development.\n" + "\n" + "Our mission is to"); - Assert.assertEquals( - output2[3], - "Memories follow me left and right. I can't tell you how many times I've been told" - + " that I'm not a good person. I'm not a good person. I'm not a good person." - + " I'm not a good person. I'm not a good person. I'm not a"); + Assert.assertTrue( + output2[3].startsWith( + "Memories follow me left and right. I can't tell you how many times I've" + + " been told that I'm not a good person.")); } @Test @@ -119,4 +105,25 @@ public void testSeqBatchScheduler() throws TranslateException, ModelException, I + "But if you're lucky, you can escape from prison and live happily ever" + " after.\n"); } + + @Test + public void testTextGenerationWithOnnx() + throws TranslateException, ModelException, IOException { + TestRequirements.linux(); + TestRequirements.weekly(); + TestRequirements.engine("PyTorch"); + + // Beam with Ort + String[] output0 = TextGeneration.generateTextWithOnnxRuntimeBeam(); + Assert.assertEquals( + output0[0], + "DeepMind Company is a global leader in the field of artificial intelligence and" + + " artificial intelligence research and development.\n" + + "\n" + + "Our mission is to provide the world with the best and brightest minds in the" + + " field of artificial intelligence and artificial intelligence research and" + + " development.\n" + + "\n" + + "Our mission is to provide the world with the best"); + } }