Skip to content

Commit

Permalink
Adds TextEmbeddingTranslator (deepjavalibrary#1953)
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanktliu authored and patins1 committed Aug 26, 2022
1 parent 07efdd4 commit 02b5536
Show file tree
Hide file tree
Showing 7 changed files with 378 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.translator;

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;

/**
 * A {@link Translator} that can handle generic text embedding {@link Input} and {@link Output}.
 *
 * <p>This class is a thin adapter: it extracts the request payload as a string, delegates all
 * tokenization and embedding work to the wrapped {@code Translator<String, float[]>}, and
 * serializes the resulting embedding vector as JSON into the {@link Output}.
 */
public class TextEmbeddingServingTranslator implements Translator<Input, Output> {

    // Delegate that performs the actual text-embedding translation; made final since it is
    // assigned once in the constructor and never changed.
    private final Translator<String, float[]> translator;

    /**
     * Constructs a {@code TextEmbeddingServingTranslator} instance.
     *
     * @param translator a {@code Translator} that processes text embedding input
     */
    public TextEmbeddingServingTranslator(Translator<String, float[]> translator) {
        this.translator = translator;
    }

    /** {@inheritDoc} */
    @Override
    public Batchifier getBatchifier() {
        return translator.getBatchifier();
    }

    /** {@inheritDoc} */
    @Override
    public void prepare(TranslatorContext ctx) throws Exception {
        translator.prepare(ctx);
    }

    /** {@inheritDoc} */
    @Override
    public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
        // Interpret the raw request body as a plain string and delegate to the wrapped translator.
        String text = input.getData().getAsString();
        return translator.processInput(ctx, text);
    }

    /** {@inheritDoc} */
    @Override
    public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
        float[] ret = translator.processOutput(ctx, list);
        Output output = new Output();
        // Serialize the embedding vector as pretty-printed JSON for the serving response.
        output.add(JsonUtils.GSON_PRETTY.toJson(ret));
        return output;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import java.io.IOException;
import java.util.Map;

/** The translator for Huggingface text embedding model. */
public class TextEmbeddingTranslator implements Translator<String, float[]> {

    // Pooling axis: reductions below run over the token (sequence) dimension.
    private static final int[] AXIS = {0};

    // Assigned once in the constructor; final keeps the translator immutable after construction.
    private final HuggingFaceTokenizer tokenizer;
    private final Batchifier batchifier;

    TextEmbeddingTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
        this.tokenizer = tokenizer;
        this.batchifier = batchifier;
    }

    /** {@inheritDoc} */
    @Override
    public Batchifier getBatchifier() {
        return batchifier;
    }

    /** {@inheritDoc} */
    @Override
    public NDList processInput(TranslatorContext ctx, String input) {
        NDManager manager = ctx.getNDManager();
        Encoding encoding = tokenizer.encode(input);
        // Stash the encoding so processOutput can reuse its attention mask for pooling.
        ctx.setAttachment("encoding", encoding);
        long[] indices = encoding.getIds();
        long[] attentionMask = encoding.getAttentionMask();
        NDList ndList = new NDList(2);
        ndList.add(manager.create(indices));
        ndList.add(manager.create(attentionMask));
        return ndList;
    }

    /** {@inheritDoc} */
    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {
        NDArray embeddings = list.get("last_hidden_state");
        Encoding encoding = (Encoding) ctx.getAttachment("encoding");
        long[] attentionMask = encoding.getAttentionMask();
        NDManager manager = ctx.getNDManager();
        // Masked mean pooling: expand the attention mask to the embeddings' shape by tiling it
        // across the hidden dimension, so padded tokens contribute nothing to the average.
        NDArray inputAttentionMask = manager.create(attentionMask).toType(DataType.FLOAT32, true);
        long[] shape = embeddings.getShape().getShape();
        inputAttentionMask = inputAttentionMask.tile(shape[shape.length - 1]);
        inputAttentionMask = inputAttentionMask.reshape(embeddings.getShape());
        NDArray inputAttentionMaskSum = inputAttentionMask.sum(AXIS);
        // Clamp the denominator away from zero to avoid division by zero on fully-masked input.
        NDArray clamp = inputAttentionMaskSum.clip(1e-9, 1e12);
        NDArray prod = embeddings.mul(inputAttentionMask);
        NDArray sum = prod.sum(AXIS);
        // Average the unmasked token embeddings, then L2-normalize the pooled vector.
        embeddings = sum.div(clamp).normalize(2, 0);

        return embeddings.toFloatArray();
    }

    /**
     * Creates a builder to build a {@code TextEmbeddingTranslator}.
     *
     * @param tokenizer the tokenizer
     * @return a new builder
     */
    public static Builder builder(HuggingFaceTokenizer tokenizer) {
        return new Builder(tokenizer);
    }

    /**
     * Creates a builder to build a {@code TextEmbeddingTranslator}.
     *
     * @param tokenizer the tokenizer
     * @param arguments the models' arguments
     * @return a new builder
     */
    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder builder = builder(tokenizer);
        builder.configure(arguments);

        return builder;
    }

    /** The builder for {@code TextEmbeddingTranslator}. */
    public static final class Builder {

        private final HuggingFaceTokenizer tokenizer;
        private Batchifier batchifier = Batchifier.STACK;

        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        /**
         * Sets the {@link Batchifier} for the {@link Translator}.
         *
         * @param batchifier the {@link Batchifier} to be set
         * @return this builder
         */
        public Builder optBatchifier(Batchifier batchifier) {
            this.batchifier = batchifier;
            return this;
        }

        /**
         * Configures the builder with the model arguments.
         *
         * @param arguments the model arguments
         */
        public void configure(Map<String, ?> arguments) {
            String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
            optBatchifier(Batchifier.fromString(batchifierStr));
        }

        /**
         * Builds the translator.
         *
         * @return the new translator
         * @throws IOException if I/O error occurs
         */
        public TextEmbeddingTranslator build() throws IOException {
            return new TextEmbeddingTranslator(tokenizer, batchifier);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.translator;

import ai.djl.Model;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.nlp.translator.TextEmbeddingServingTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;

import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** A {@link TranslatorFactory} that creates a {@link TextEmbeddingTranslator} instance. */
public class TextEmbeddingTranslatorFactory implements TranslatorFactory {

    private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();

    static {
        SUPPORTED_TYPES.add(new Pair<>(String.class, float[].class));
        SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
    }

    /** {@inheritDoc} */
    @Override
    public Set<Pair<Type, Type>> getSupportedTypes() {
        return SUPPORTED_TYPES;
    }

    /** {@inheritDoc} */
    @Override
    public Translator<?, ?> newInstance(
            Class<?> input, Class<?> output, Model model, Map<String, ?> arguments)
            throws TranslateException {
        Path modelDir = model.getModelPath();
        try {
            // Load the tokenizer from the model directory and build the base translator.
            HuggingFaceTokenizer tokenizer =
                    HuggingFaceTokenizer.builder(arguments)
                            .optTokenizerPath(modelDir)
                            .optManager(model.getNDManager())
                            .build();
            TextEmbeddingTranslator base =
                    TextEmbeddingTranslator.builder(tokenizer, arguments).build();
            if (input == String.class && output == float[].class) {
                return base;
            }
            if (input == Input.class && output == Output.class) {
                // Wrap the base translator for the generic serving Input/Output types.
                return new TextEmbeddingServingTranslator(base);
            }
            throw new IllegalArgumentException("Unsupported input/output types.");
        } catch (IOException e) {
            throw new TranslateException("Failed to load tokenizer.", e);
        }
    }
}
3 changes: 3 additions & 0 deletions extensions/tokenizers/src/main/python/fill_mask_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,6 @@ def verify_jit_output(self, hf_pipeline, encoding, out):
def encode_inputs(self, tokenizer):
    # Substitute the literal "[MASK]" placeholder with this tokenizer's own mask
    # token before encoding, then tokenize to PyTorch tensors.
    # NOTE(review): assumes self.inputs holds the sample input text set by the
    # base converter -- confirm against the enclosing class.
    text = self.inputs.replace("[MASK]", tokenizer.mask_token)
    return tokenizer.encode_plus(text, return_tensors='pt')

def get_extra_arguments(self, hf_pipeline) -> dict:
    # Expose the model-specific mask token so the converter can write it into
    # serving.properties as a "maskToken=<token>" entry.
    return {"maskToken": hf_pipeline.tokenizer.mask_token}
12 changes: 8 additions & 4 deletions extensions/tokenizers/src/main/python/huggingface_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def save_model(self, model_id: str, output_dir: str, temp_dir: str):
return False, reason, -1

size = self.save_to_model_zoo(model_id, output_dir, temp_dir,
hf_pipeline.tokenizer.mask_token)
hf_pipeline)

return True, None, size

Expand Down Expand Up @@ -100,7 +100,7 @@ def jit_trace_model(self, hf_pipeline, model_id: str, temp_dir: str):
return model_file

def save_to_model_zoo(self, model_id: str, output_dir: str, temp_dir: str,
mask_token: str):
hf_pipeline):
artifact_ids = model_id.split("/")
model_name = artifact_ids[-1]

Expand All @@ -111,13 +111,14 @@ def save_to_model_zoo(self, model_id: str, output_dir: str, temp_dir: str,

# Save serving.properties
serving_file = os.path.join(temp_dir, "serving.properties")
arguments = self.get_extra_arguments(hf_pipeline)
with open(serving_file, 'w') as f:
f.write(f"engine=PyTorch\n"
f"option.modelName={model_name}\n"
f"option.mapLocation=true\n"
f"translatorFactory={self.translator}\n")
if mask_token:
f.write(f"maskToken={mask_token}\n")
for k, v in arguments.items():
f.write(f"{k}={v}\n")

# Save model as .zip file
logging.info(f"Saving DJL model as zip: {model_name}.zip ...")
Expand Down Expand Up @@ -157,6 +158,9 @@ def verify_jit_model(self, hf_pipeline, model_file: str):

return self.verify_jit_output(hf_pipeline, encoding, out)

def get_extra_arguments(self, hf_pipeline) -> dict:
    # Default: no extra serving.properties entries. Task-specific converter
    # subclasses override this to add model arguments (e.g. maskToken, padding).
    return {}

def verify_jit_output(self, hf_pipeline, encoding, out):
if not hasattr(out, "last_hidden_layer"):
return False, f"Unexpected inference result: {out}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,6 @@ def verify_jit_output(self, hf_pipeline, encoding, out):
return False, f"Unexpected inference result: {last_hidden_state}"

return True, None

def get_extra_arguments(self, hf_pipeline) -> dict:
    # Write "padding=true" into serving.properties for this model's translator.
    # NOTE(review): value is the string "true" (not a bool) because it is
    # serialized into a Java properties file.
    return {"padding": "true"}
Loading

0 comments on commit 02b5536

Please sign in to comment.