diff --git a/api/src/main/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactory.java index 1bc7f5a5cd5..6840eca4d71 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactory.java @@ -17,39 +17,26 @@ import ai.djl.modality.cv.output.CategoryMask; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorFactory; -import ai.djl.util.Pair; import java.io.Serializable; -import java.lang.reflect.Type; -import java.util.HashSet; import java.util.Map; -import java.util.Set; /** A {@link TranslatorFactory} that creates a {@link SemanticSegmentationTranslator} instance. */ -public class SemanticSegmentationTranslatorFactory implements TranslatorFactory, Serializable { +public class SemanticSegmentationTranslatorFactory extends BaseImageTranslatorFactory + implements Serializable { private static final long serialVersionUID = 1L; - private static final Set> SUPPORTED_TYPES = new HashSet<>(); - - static { - SUPPORTED_TYPES.add(new Pair<>(Image.class, CategoryMask.class)); - } - /** {@inheritDoc} */ @Override - @SuppressWarnings("unchecked") - public Translator newInstance( - Class input, Class output, Model model, Map arguments) { - if (input == Image.class && output == CategoryMask.class) { - return (Translator) SemanticSegmentationTranslator.builder(arguments).build(); - } - throw new IllegalArgumentException("Unsupported input/output types."); + protected Translator buildBaseTranslator( + Model model, Map arguments) { + return SemanticSegmentationTranslator.builder(arguments).build(); } /** {@inheritDoc} */ @Override - public Set> getSupportedTypes() { - return SUPPORTED_TYPES; + public Class getBaseOutputType() { + return CategoryMask.class; } } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/SimplePoseTranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/SimplePoseTranslatorFactory.java index c8fe84991fb..fd5696e5b16 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/SimplePoseTranslatorFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/SimplePoseTranslatorFactory.java @@ -13,49 +13,29 @@ package ai.djl.modality.cv.translator; import ai.djl.Model; -import ai.djl.modality.Input; -import ai.djl.modality.Output; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.output.Joints; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorFactory; -import ai.djl.util.Pair; import java.io.Serializable; -import java.lang.reflect.Type; -import java.util.HashSet; import java.util.Map; -import java.util.Set; /** An {@link TranslatorFactory} that creates a {@link SimplePoseTranslator} instance. */ -public class SimplePoseTranslatorFactory implements TranslatorFactory, Serializable { +public class SimplePoseTranslatorFactory extends BaseImageTranslatorFactory + implements Serializable { private static final long serialVersionUID = 1L; - private static final Set> SUPPORTED_TYPES = new HashSet<>(); - - static { - SUPPORTED_TYPES.add(new Pair<>(Image.class, Joints.class)); - SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); - } - /** {@inheritDoc} */ @Override - public Set> getSupportedTypes() { - return SUPPORTED_TYPES; + protected Translator buildBaseTranslator(Model model, Map arguments) { + return SimplePoseTranslator.builder(arguments).build(); } /** {@inheritDoc} */ @Override - @SuppressWarnings("unchecked") - public Translator newInstance( - Class input, Class output, Model model, Map arguments) { - SimplePoseTranslator translator = SimplePoseTranslator.builder(arguments).build(); - if (input == Image.class && output == Joints.class) { - return (Translator) translator; - } else if (input == Input.class && output == Output.class) { - return (Translator) new ImageServingTranslator(translator); - } - throw new IllegalArgumentException("Unsupported input/output types."); + public Class getBaseOutputType() { + return Joints.class; } } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactory.java index 4c492af5044..a0c05321074 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactory.java @@ -16,33 +16,25 @@ import ai.djl.modality.cv.Image; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorFactory; -import ai.djl.util.Pair; import java.io.Serializable; -import java.lang.reflect.Type; -import java.util.Collections; import java.util.Map; -import java.util.Set; /** A {@link TranslatorFactory} that creates a {@link StyleTransferTranslator} instance. */ -public class StyleTransferTranslatorFactory implements TranslatorFactory, Serializable { +public class StyleTransferTranslatorFactory extends BaseImageTranslatorFactory + implements Serializable { private static final long serialVersionUID = 1L; /** {@inheritDoc} */ @Override - public Set> getSupportedTypes() { - return Collections.singleton(new Pair<>(Image.class, Image.class)); + protected Translator buildBaseTranslator(Model model, Map arguments) { + return new StyleTransferTranslator(); } /** {@inheritDoc} */ @Override - @SuppressWarnings("unchecked") - public Translator newInstance( - Class input, Class output, Model model, Map arguments) { - if (!isSupported(input, output)) { - throw new IllegalArgumentException("Unsupported input/output types."); - } - return (Translator) new StyleTransferTranslator(); + public Class getBaseOutputType() { + return Image.class; } } diff --git a/api/src/test/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactoryTest.java index ec54e3bcb47..9f24729503b 100644 --- a/api/src/test/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactoryTest.java +++ b/api/src/test/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactoryTest.java @@ -14,7 +14,6 @@ import ai.djl.Model; import ai.djl.modality.Input; -import ai.djl.modality.Output; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.output.CategoryMask; import ai.djl.translate.Translator; @@ -37,7 +36,7 @@ public void setUp() { @Test public void testGetSupportedTypes() { - Assert.assertEquals(factory.getSupportedTypes().size(), 1); + Assert.assertEquals(factory.getSupportedTypes().size(), 6); } @Test @@ -49,7 +48,7 @@ public void testNewInstance() { Assert.assertTrue(translator instanceof SemanticSegmentationTranslator); Assert.assertThrows( IllegalArgumentException.class, - () -> factory.newInstance(Input.class, Output.class, model, arguments)); + () -> factory.newInstance(Input.class, Image.class, model, arguments)); } } } diff --git a/api/src/test/java/ai/djl/modality/cv/translator/SimplePoseTranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/SimplePoseTranslatorFactoryTest.java index 0a2b54fd843..4eab9fa2eea 100644 --- a/api/src/test/java/ai/djl/modality/cv/translator/SimplePoseTranslatorFactoryTest.java +++ b/api/src/test/java/ai/djl/modality/cv/translator/SimplePoseTranslatorFactoryTest.java @@ -37,7 +37,7 @@ public void setUp() { @Test public void testGetSupportedTypes() { - Assert.assertEquals(factory.getSupportedTypes().size(), 2); + Assert.assertEquals(factory.getSupportedTypes().size(), 6); } @Test diff --git a/api/src/test/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactoryTest.java index 785ffa87206..1db9902b000 100644 --- a/api/src/test/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactoryTest.java +++ b/api/src/test/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactoryTest.java @@ -14,7 +14,6 @@ import ai.djl.Model; import ai.djl.modality.Input; -import ai.djl.modality.Output; import ai.djl.modality.cv.Image; import ai.djl.translate.Translator; @@ -36,7 +35,7 @@ public void setUp() { @Test public void testGetSupportedTypes() { - Assert.assertEquals(factory.getSupportedTypes().size(), 1); + Assert.assertEquals(factory.getSupportedTypes().size(), 6); } @Test @@ -48,7 +47,7 @@ public void testNewInstance() { Assert.assertTrue(translator instanceof StyleTransferTranslator); Assert.assertThrows( IllegalArgumentException.class, - () -> factory.newInstance(Input.class, Output.class, model, arguments)); + () -> factory.newInstance(Input.class, Image.class, model, arguments)); } } }