From 81bb8f1f0f070c906f079784983dfb4d9f9156a1 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Mon, 31 Jul 2023 11:32:33 -0700 Subject: [PATCH] enable EOS token --- .../modality/nlp/generate/TextGenerator.java | 50 +++++++++++++++---- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java index b54f3ec0fad..d1d57baff40 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java @@ -22,6 +22,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.translate.TranslateException; +import java.util.Arrays; import java.util.function.Function; import java.util.stream.Collectors; @@ -40,6 +41,8 @@ public class TextGenerator { private NDArray positionOffset; + private long[] endPosition; + /** * Constructs a new {@code TextGenerator} instance. * @@ -60,11 +63,15 @@ public TextGenerator( * Executes greedy search. * * @param inputIds the input token ids. - * @return the output token ids stored as NDArray + * @return the output token ids stored as NDArray and the endPosition of each sentence * @throws TranslateException if forward fails */ @SuppressWarnings("try") public NDArray greedySearch(NDArray inputIds) throws TranslateException { + // Initialize the end position of each sentence + endPosition = new long[Math.toIntExact(inputIds.getShape().get(0))]; + Arrays.fill(endPosition, config.getMaxSeqLength()); + NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config); NDManager manager = inputIds.getManager(); GreedyBatchTensorList searchState = @@ -110,7 +117,12 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException { } // Termination Criteria - // TODO: , delete the sentence and add it to result. + long[] outputIdsArray = searchState.getNextInputIds().toLongArray(); + for (int i = 0; i < outputIdsArray.length; i++) { + if (outputIdsArray[i] == config.getEosTokenId()) { + endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1; + } + } if (searchState.getPastOutputIds().getShape().get(1) + 1 >= config.getMaxSeqLength()) { break; } @@ -123,11 +135,15 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException { * * @param inputIds input tokens ids * @see Beam Search - * @return output tensor + * @return the output token ids stored as NDArray and the endPosition of each sentence * @throws TranslateException if failed run forward */ @SuppressWarnings("try") public NDArray beamSearch(NDArray inputIds) throws TranslateException { + // Initialize the end position of each sentence + endPosition = new long[Math.toIntExact(inputIds.getShape().get(0))]; + Arrays.fill(endPosition, config.getMaxSeqLength()); + NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config); NDManager manager = inputIds.getManager(); long numBeam = config.getBeam(); @@ -223,7 +239,12 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException { } // Termination Criteria - // TODO: , delete the sentence and add it to result. + long[] outputIdsArray = searchState.getNextInputIds().toLongArray(); + for (int i = 0; i < outputIdsArray.length; i++) { + if (outputIdsArray[i] == config.getEosTokenId()) { + endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1; + } + } if (searchState.getPastOutputIds().getShape().getLastDimension() + 1 >= config.getMaxSeqLength()) { break; @@ -241,7 +262,7 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException { * * @param inputIds input token ids * @see Contrastive Search - * @return the generated {@code NDArray} + * @return the output token ids stored as NDArray * @throws TranslateException if forward failed */ @SuppressWarnings("try") @@ -249,6 +270,10 @@ public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException { // inputIds: [batchSize, seqLength: t_init] // attentionMask: [batchSize, pastSeq]. seq-dim-size = |past_seq| + |inputIds|. + // Initialize the end position of each sentence + endPosition = new long[Math.toIntExact(inputIds.getShape().get(0))]; + Arrays.fill(endPosition, config.getMaxSeqLength()); + NDManager manager = inputIds.getManager(); NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config); ContrastiveBatchTensorList searchState = new ContrastiveBatchTensorList(); @@ -339,13 +364,11 @@ public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException { NDScope.unregister(searchState.getPastKeyValues()); } - // TODO: , delete the sentence and add it to result. + // Termination Criteria long[] outputIdsArray = searchState.getPastOutputIds().toLongArray(); for (int i = 0; i < outputIdsArray.length; i++) { if (outputIdsArray[i] == config.getEosTokenId()) { - if (!exitIndexEndPosition.containsKey((long) i)) { - exitIndexEndPosition.put((long) i, seqLength); - } + endPosition[i] = searchState.getPastOutputIds().getShape().get(1); } } if (searchState.getPastOutputIds().getShape().get(1) >= config.getMaxSeqLength()) { @@ -572,4 +595,13 @@ public NDArray generate(NDArray inputIds) throws TranslateException { public NDArray getPositionOffset() { return positionOffset; } + + /** + * Gets the end position of each sentence induced by EOS tokenId or reaching maxSeqLength. + * + * @return the end position of each sentence + */ + public long[] getEndPosition() { + return endPosition; + } }