Skip to content

Commit

Permalink
[api] Adds serving support for some CV models (#3499)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Oct 6, 2024
1 parent 8be1c96 commit 281a780
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,26 @@
import ai.djl.modality.cv.output.CategoryMask;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;

import java.io.Serializable;
import java.lang.reflect.Type;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** A {@link TranslatorFactory} that creates a {@link SemanticSegmentationTranslator} instance. */
public class SemanticSegmentationTranslatorFactory implements TranslatorFactory, Serializable {
public class SemanticSegmentationTranslatorFactory extends BaseImageTranslatorFactory<CategoryMask>
implements Serializable {

private static final long serialVersionUID = 1L;

private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();

static {
SUPPORTED_TYPES.add(new Pair<>(Image.class, CategoryMask.class));
}

/** {@inheritDoc} */
@Override
@SuppressWarnings("unchecked")
public <I, O> Translator<I, O> newInstance(
Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
if (input == Image.class && output == CategoryMask.class) {
return (Translator<I, O>) SemanticSegmentationTranslator.builder(arguments).build();
}
throw new IllegalArgumentException("Unsupported input/output types.");
protected Translator<Image, CategoryMask> buildBaseTranslator(
Model model, Map<String, ?> arguments) {
return SemanticSegmentationTranslator.builder(arguments).build();
}

/** {@inheritDoc} */
@Override
public Set<Pair<Type, Type>> getSupportedTypes() {
return SUPPORTED_TYPES;
public Class<CategoryMask> getBaseOutputType() {
return CategoryMask.class;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,49 +13,29 @@
package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Joints;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;

import java.io.Serializable;
import java.lang.reflect.Type;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** An {@link TranslatorFactory} that creates a {@link SimplePoseTranslator} instance. */
public class SimplePoseTranslatorFactory implements TranslatorFactory, Serializable {
public class SimplePoseTranslatorFactory extends BaseImageTranslatorFactory<Joints>
implements Serializable {

private static final long serialVersionUID = 1L;

private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();

static {
SUPPORTED_TYPES.add(new Pair<>(Image.class, Joints.class));
SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
}

/** {@inheritDoc} */
@Override
public Set<Pair<Type, Type>> getSupportedTypes() {
return SUPPORTED_TYPES;
protected Translator<Image, Joints> buildBaseTranslator(Model model, Map<String, ?> arguments) {
return SimplePoseTranslator.builder(arguments).build();
}

/** {@inheritDoc} */
@Override
@SuppressWarnings("unchecked")
public <I, O> Translator<I, O> newInstance(
Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
SimplePoseTranslator translator = SimplePoseTranslator.builder(arguments).build();
if (input == Image.class && output == Joints.class) {
return (Translator<I, O>) translator;
} else if (input == Input.class && output == Output.class) {
return (Translator<I, O>) new ImageServingTranslator(translator);
}
throw new IllegalArgumentException("Unsupported input/output types.");
public Class<Joints> getBaseOutputType() {
return Joints.class;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,25 @@
import ai.djl.modality.cv.Image;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;

import java.io.Serializable;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

/** A {@link TranslatorFactory} that creates a {@link StyleTransferTranslator} instance. */
public class StyleTransferTranslatorFactory implements TranslatorFactory, Serializable {
public class StyleTransferTranslatorFactory extends BaseImageTranslatorFactory<Image>
implements Serializable {

private static final long serialVersionUID = 1L;

/** {@inheritDoc} */
@Override
public Set<Pair<Type, Type>> getSupportedTypes() {
return Collections.singleton(new Pair<>(Image.class, Image.class));
protected Translator<Image, Image> buildBaseTranslator(Model model, Map<String, ?> arguments) {
return new StyleTransferTranslator();
}

/** {@inheritDoc} */
@Override
@SuppressWarnings("unchecked")
public <I, O> Translator<I, O> newInstance(
Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
if (!isSupported(input, output)) {
throw new IllegalArgumentException("Unsupported input/output types.");
}
return (Translator<I, O>) new StyleTransferTranslator();
public Class<Image> getBaseOutputType() {
return Image.class;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.CategoryMask;
import ai.djl.translate.Translator;
Expand All @@ -37,7 +36,7 @@ public void setUp() {

@Test
public void testGetSupportedTypes() {
Assert.assertEquals(factory.getSupportedTypes().size(), 1);
Assert.assertEquals(factory.getSupportedTypes().size(), 6);
}

@Test
Expand All @@ -49,7 +48,7 @@ public void testNewInstance() {
Assert.assertTrue(translator instanceof SemanticSegmentationTranslator);
Assert.assertThrows(
IllegalArgumentException.class,
() -> factory.newInstance(Input.class, Output.class, model, arguments));
() -> factory.newInstance(Input.class, Image.class, model, arguments));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public void setUp() {

@Test
public void testGetSupportedTypes() {
Assert.assertEquals(factory.getSupportedTypes().size(), 2);
Assert.assertEquals(factory.getSupportedTypes().size(), 6);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.Image;
import ai.djl.translate.Translator;

Expand All @@ -36,7 +35,7 @@ public void setUp() {

@Test
public void testGetSupportedTypes() {
Assert.assertEquals(factory.getSupportedTypes().size(), 1);
Assert.assertEquals(factory.getSupportedTypes().size(), 6);
}

@Test
Expand All @@ -48,7 +47,7 @@ public void testNewInstance() {
Assert.assertTrue(translator instanceof StyleTransferTranslator);
Assert.assertThrows(
IllegalArgumentException.class,
() -> factory.newInstance(Input.class, Output.class, model, arguments));
() -> factory.newInstance(Input.class, Image.class, model, arguments));
}
}
}

0 comments on commit 281a780

Please sign in to comment.