Skip to content

Commit

Permalink
enable EOS token
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jul 31, 2023
1 parent 2fb20f0 commit 81bb8f1
Showing 1 changed file with 41 additions and 9 deletions.
50 changes: 41 additions & 9 deletions api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;

import java.util.Arrays;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand All @@ -40,6 +41,8 @@ public class TextGenerator {

private NDArray positionOffset;

private long[] endPosition;

/**
* Constructs a new {@code TextGenerator} instance.
*
Expand All @@ -60,11 +63,15 @@ public TextGenerator(
* Executes greedy search.
*
* @param inputIds the input token ids.
* @return the output token ids stored as NDArray
* @return the output token ids stored as NDArray and the endPosition of each sentence
* @throws TranslateException if forward fails
*/
@SuppressWarnings("try")
public NDArray greedySearch(NDArray inputIds) throws TranslateException {
// Initialize the end position of each sentence
endPosition = new long[Math.toIntExact(inputIds.getShape().get(0))];
Arrays.fill(endPosition, config.getMaxSeqLength());

NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config);
NDManager manager = inputIds.getManager();
GreedyBatchTensorList searchState =
Expand Down Expand Up @@ -110,7 +117,12 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
}

// Termination Criteria
// TODO: <EOS>, delete the sentence and add it to result.
long[] outputIdsArray = searchState.getNextInputIds().toLongArray();
for (int i = 0; i < outputIdsArray.length; i++) {
if (outputIdsArray[i] == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1;
}
}
if (searchState.getPastOutputIds().getShape().get(1) + 1 >= config.getMaxSeqLength()) {
break;
}
Expand All @@ -123,11 +135,15 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
*
* @param inputIds input tokens ids
* @see <a href="https://huggingface.co/blog/how-to-generate">Beam Search</a>
* @return output tensor
* @return the output token ids stored as NDArray and the endPosition of each sentence
* @throws TranslateException if failed run forward
*/
@SuppressWarnings("try")
public NDArray beamSearch(NDArray inputIds) throws TranslateException {
// Initialize the end position of each sentence
endPosition = new long[Math.toIntExact(inputIds.getShape().get(0))];
Arrays.fill(endPosition, config.getMaxSeqLength());

NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config);
NDManager manager = inputIds.getManager();
long numBeam = config.getBeam();
Expand Down Expand Up @@ -223,7 +239,12 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException {
}

// Termination Criteria
// TODO: <EOS>, delete the sentence and add it to result.
long[] outputIdsArray = searchState.getNextInputIds().toLongArray();
for (int i = 0; i < outputIdsArray.length; i++) {
if (outputIdsArray[i] == config.getEosTokenId()) {
endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1;
}
}
if (searchState.getPastOutputIds().getShape().getLastDimension() + 1
>= config.getMaxSeqLength()) {
break;
Expand All @@ -241,14 +262,18 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException {
*
* @param inputIds input token ids
* @see <a href="https://huggingface.co/blog/introducing-csearch">Contrastive Search</a>
* @return the generated {@code NDArray}
* @return the output token ids stored as NDArray
* @throws TranslateException if forward failed
*/
@SuppressWarnings("try")
public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException {
// inputIds: [batchSize, seqLength: t_init]
// attentionMask: [batchSize, pastSeq]. seq-dim-size = |past_seq| + |inputIds|.

// Initialize the end position of each sentence
endPosition = new long[Math.toIntExact(inputIds.getShape().get(0))];
Arrays.fill(endPosition, config.getMaxSeqLength());

NDManager manager = inputIds.getManager();
NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config);
ContrastiveBatchTensorList searchState = new ContrastiveBatchTensorList();
Expand Down Expand Up @@ -339,13 +364,11 @@ public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException {
NDScope.unregister(searchState.getPastKeyValues());
}

// TODO: <EOS>, delete the sentence and add it to result.
// Termination Criteria
long[] outputIdsArray = searchState.getPastOutputIds().toLongArray();
for (int i = 0; i < outputIdsArray.length; i++) {
if (outputIdsArray[i] == config.getEosTokenId()) {
if (!exitIndexEndPosition.containsKey((long) i)) {
exitIndexEndPosition.put((long) i, seqLength);
}
endPosition[i] = searchState.getPastOutputIds().getShape().get(1);
}
}
if (searchState.getPastOutputIds().getShape().get(1) >= config.getMaxSeqLength()) {
Expand Down Expand Up @@ -572,4 +595,13 @@ public NDArray generate(NDArray inputIds) throws TranslateException {
public NDArray getPositionOffset() {
return positionOffset;
}

/**
* Gets the end position of each sentence induced by EOS tokenId or reaching maxSeqLength.
*
* @return the end position of each sentence
*/
public long[] getEndPosition() {
return endPosition;
}
}

0 comments on commit 81bb8f1

Please sign in to comment.