Skip to content

Commit

Permalink
[api] Fixes QaServingTranslator output format and TokenClassification…
Browse files Browse the repository at this point in the history
… crash (#3500)
  • Loading branch information
frankfliu authored Oct 18, 2024
1 parent 9bd717d commit 187e440
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@
import com.google.gson.reflect.TypeToken;

import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
* A {@link Translator} that can handle generic question answering {@link Input} and {@link Output}.
*/
public class QaServingTranslator implements NoBatchifyTranslator<Input, Output> {

private static final Type LIST_TYPE = new TypeToken<List<String>>() {}.getType();
private static final Type LIST_TYPE = new TypeToken<List<QAInput>>() {}.getType();

private Translator<QAInput, String> translator;

Expand Down Expand Up @@ -116,13 +119,20 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception
Output output = new Output();
output.addProperty("Content-Type", "application/json");
if (ctx.getAttachment("batch") != null) {
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
List<String> answers = translator.batchProcessOutput(ctx, list);
List<Map<String, String>> ret = new ArrayList<>();
for (String answer : answers) {
ret.add(Collections.singletonMap("answer", answer));
}
output.add(BytesSupplier.wrapAsJson(ret));
} else {
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
list = batchifier.unbatchify(list)[0];
}
output.add(translator.processOutput(ctx, list));
String answer = translator.processOutput(ctx, list);
Map<String, String> ret = Collections.singletonMap("answer", answer);
output.add(BytesSupplier.wrapAsJson(ret));
}
return output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
TextPrompt prompt = TextPrompt.parseInput(input);
if (prompt.isBatch()) {
ctx.setAttachment("batch", Boolean.TRUE);
return translator.batchProcessInput(ctx, prompt.getBatch());
}

NDList ret = translator.processInput(ctx, prompt.getText());
Expand Down
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/ndarray/BytesSupplierImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class BytesSupplierImpl implements BytesSupplier {
public byte[] getAsBytes() {
if (buf == null) {
if (value == null) {
value = JsonUtils.toJson(obj) + '\n';
value = JsonUtils.toJson(obj);
}
buf = value.getBytes(StandardCharsets.UTF_8);
}
Expand All @@ -52,7 +52,7 @@ public byte[] getAsBytes() {
public String getAsString() {
if (value == null) {
if (obj != null) {
value = JsonUtils.toJson(obj) + '\n';
value = JsonUtils.toJson(obj);
} else {
value = new String(buf, StandardCharsets.UTF_8);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
Expand Down Expand Up @@ -232,6 +233,9 @@ public void close() {
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, boolean addSpecialTokens, boolean withOverflowingTokens) {
if (text == null) {
throw new NullPointerException("text cannot be null");
}
if (doLowerCase != null) {
text = text.toLowerCase(doLowerCase);
}
Expand Down Expand Up @@ -261,6 +265,10 @@ public Encoding encode(String text) {
*/
public Encoding encode(
String text, String textPair, boolean addSpecialTokens, boolean withOverflowingTokens) {
if (text == null || textPair == null) {
throw new NullPointerException("text/text_pair cannot be null");
}

if (doLowerCase != null) {
text = text.toLowerCase(doLowerCase);
textPair = textPair.toLowerCase(doLowerCase);
Expand Down Expand Up @@ -322,6 +330,8 @@ public Encoding encode(
for (int i = 0; i < inputs.length; ++i) {
inputs[i] = inputs[i].toLowerCase(doLowerCase);
}
} else if (Arrays.stream(inputs).anyMatch(Objects::isNull)) {
throw new NullPointerException("input text cannot be null");
}
long encoding = TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens);
return toEncoding(encoding, withOverflowingTokens);
Expand Down Expand Up @@ -377,6 +387,8 @@ public Encoding[] batchEncode(
for (int i = 0; i < inputs.length; ++i) {
inputs[i] = inputs[i].toLowerCase(doLowerCase);
}
} else if (Arrays.stream(inputs).anyMatch(Objects::isNull)) {
throw new NullPointerException("input text cannot be null");
}
long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens);
Encoding[] ret = new Encoding[encodings.length];
Expand Down Expand Up @@ -418,6 +430,13 @@ public Encoding[] batchEncode(
for (int i = 0; i < textPair.length; ++i) {
textPair[i] = textPair[i].toLowerCase(doLowerCase);
}
} else {
if (inputs.keys().stream().anyMatch(Objects::isNull)) {
throw new NullPointerException("text pair key cannot be null");
}
if (inputs.values().stream().anyMatch(Objects::isNull)) {
throw new NullPointerException("text pair value cannot be null");
}
}
long[] encodings =
TokenizersLibrary.LIB.batchEncodePair(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -102,6 +103,17 @@ public void testTokenizer() throws IOException {
Assert.assertEquals(charSpansExpected[i].getStart(), charSpansResult[i].getStart());
Assert.assertEquals(charSpansExpected[i].getEnd(), charSpansResult[i].getEnd());
}

Assert.assertThrows(() -> tokenizer.encode((String) null));
Assert.assertThrows(() -> tokenizer.encode(new String[] {null}));
Assert.assertThrows(() -> tokenizer.encode(null, null));
Assert.assertThrows(() -> tokenizer.encode("null", null));
Assert.assertThrows(() -> tokenizer.batchEncode(new String[] {null}));
List<String> empty = Collections.singletonList(null);
List<String> some = Collections.singletonList("null");

Assert.assertThrows(() -> tokenizer.batchEncode(new PairList<>(empty, some)));
Assert.assertThrows(() -> tokenizer.batchEncode(new PairList<>(some, empty)));
}

Map<String, String> options = new ConcurrentHashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public void testQATranslator() throws ModelException, IOException, TranslateExce
input.add("question", question);
input.add("paragraph", paragraph);
Output res = predictor.predict(input);
Assert.assertEquals(res.getAsString(0), "December 2004");
Assert.assertEquals(res.getAsString(0), "{\"answer\":\"December 2004\"}");

Assert.assertThrows(
"Input data is empty.",
Expand Down

0 comments on commit 187e440

Please sign in to comment.