Skip to content

Commit

Permalink
Refactor BlockFactory interface (#1045)
Browse files Browse the repository at this point in the history
Change-Id: I7e9aa60f541c00852c548332338ee0fc914ee92f
  • Loading branch information
frankfliu authored Jun 23, 2021
1 parent 6a81d9d commit d4e93e6
Show file tree
Hide file tree
Showing 19 changed files with 216 additions and 291 deletions.
10 changes: 0 additions & 10 deletions api/src/main/java/ai/djl/BaseModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
Expand Down Expand Up @@ -217,14 +215,6 @@ protected void setModelDir(Path modelDir) {
this.modelDir = modelDir.toAbsolutePath();
}

protected Block loadFromBlockFactory() {
BlockFactory factory = ClassLoaderUtils.findImplementation(modelDir, null);
if (factory == null) {
return null;
}
return factory.newBlock(manager);
}

/** {@inheritDoc} */
@Override
public void save(Path modelPath, String newModelName) throws IOException {
Expand Down
12 changes: 9 additions & 3 deletions api/src/main/java/ai/djl/nn/BlockFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
*/
package ai.djl.nn;

import ai.djl.ndarray.NDManager;
import ai.djl.Model;
import ai.djl.repository.zoo.ModelZoo;
import java.io.IOException;
import java.io.Serializable;
import java.nio.file.Path;
import java.util.Map;

/**
* Block factory is a component to make standard for block creating and saving procedure. Block
Expand All @@ -27,8 +30,11 @@ public interface BlockFactory extends Serializable {
/**
* Constructs the uninitialized block.
*
* @param manager the manager to assign to block
* @param model the model of the block
* @param modelPath the directory of the model location
* @param arguments the block creation arguments
* @return the uninitialized block
* @throws IOException if IO operation fails during creating block
*/
Block newBlock(NDManager manager);
Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) throws IOException;
}
31 changes: 25 additions & 6 deletions api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.translator.ImageClassificationTranslatorFactory;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
Expand All @@ -32,6 +34,7 @@
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
Expand Down Expand Up @@ -155,10 +158,14 @@ public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)
modelName = artifact.getName();
}

Model model = createModel(modelName, criteria.getDevice(), artifact, arguments, engine);
if (criteria.getBlock() != null) {
model.setBlock(criteria.getBlock());
}
Model model =
createModel(
modelPath,
modelName,
criteria.getDevice(),
criteria.getBlock(),
arguments,
engine);
model.load(modelPath, null, options);
Translator<I, O> translator = factory.newInstance(model, arguments);
return new ZooModel<>(model, translator);
Expand All @@ -182,13 +189,25 @@ public List<Artifact> listModels() throws IOException {
}

protected Model createModel(
Path modelPath,
String name,
Device device,
Artifact artifact,
Block block,
Map<String, Object> arguments,
String engine)
throws IOException {
return Model.newInstance(name, device, engine);
Model model = Model.newInstance(name, device, engine);
if (block == null) {
String className = (String) arguments.get("blockFactory");
BlockFactory factory = ClassLoaderUtils.findImplementation(modelPath, className);
if (factory != null) {
block = factory.newBlock(model, modelPath, arguments);
}
}
if (block != null) {
model.setBlock(block);
}
return model;
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import ai.djl.fasttext.FtModel;
import ai.djl.fasttext.zoo.FtModelZoo;
import ai.djl.modality.Classifications;
import ai.djl.repository.Artifact;
import ai.djl.nn.Block;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
Expand All @@ -30,6 +30,7 @@
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Map;

/** Model loader for fastText cooking stackexchange models. */
Expand Down Expand Up @@ -68,9 +69,10 @@ public ZooModel<String, Classifications> loadModel()
/** {@inheritDoc} */
@Override
protected Model createModel(
Path modelPath,
String name,
Device device,
Artifact artifact,
Block block,
Map<String, Object> arguments,
String engine) {
return new FtModel(name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,59 +30,21 @@
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.Assertions;
import ai.djl.training.ParameterStore;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.util.Utils;
import ai.djl.util.ZipUtils;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import org.testng.Assert;
import org.testng.annotations.Test;

public class BlockFactoryTest {

@Test
public void testBlockLoadingSaving()
throws IOException, ModelNotFoundException, MalformedModelException,
TranslateException {
TestBlockFactory factory = new TestBlockFactory();
Model model = factory.getRemoveLastBlockModel();
try (NDManager manager = NDManager.newBaseManager()) {
Block block = model.getBlock();
block.forward(
new ParameterStore(manager, true),
new NDList(manager.ones(new Shape(1, 3, 32, 32))),
true);
ByteArrayOutputStream os = new ByteArrayOutputStream();
block.saveParameters(new DataOutputStream(os));
ByteArrayInputStream bis = new ByteArrayInputStream(os.toByteArray());
Block newBlock = factory.newBlock(manager);
newBlock.loadParameters(manager, new DataInputStream(bis));
try (Model test = Model.newInstance("test")) {
test.setBlock(newBlock);
try (Predictor<NDList, NDList> predOrigin =
model.newPredictor(new NoopTranslator());
Predictor<NDList, NDList> predDest =
test.newPredictor(new NoopTranslator())) {
NDList input = new NDList(manager.ones(new Shape(1, 3, 32, 32)));
NDList originOut = predOrigin.predict(input);
NDList destOut = predDest.predict(input);
Assertions.assertAlmostEquals(originOut, destOut);
}
}
}
model.close();
}

@Test
public void testBlockFactoryLoadingFromZip()
throws MalformedModelException, ModelNotFoundException, IOException,
Expand All @@ -97,9 +59,9 @@ public void testBlockFactoryLoadingFromZip()
.optModelPath(zipPath)
.optModelName("exported")
.build();
try (NDManager manager = NDManager.newBaseManager();
ZooModel<NDList, NDList> model = criteria.loadModel();
try (ZooModel<NDList, NDList> model = criteria.loadModel();
Predictor<NDList, NDList> pred = model.newPredictor()) {
NDManager manager = model.getNDManager();
NDList destOut = pred.predict(new NDList(manager.ones(new Shape(1, 3, 32, 32))));
Assert.assertEquals(destOut.singletonOrThrow().getShape(), new Shape(1, 10));
}
Expand Down Expand Up @@ -136,9 +98,9 @@ public static class TestBlockFactory implements BlockFactory {
private static final long serialVersionUID = 1234567L;

@Override
public Block newBlock(NDManager manager) {
public Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) {
SequentialBlock newBlock = new SequentialBlock();
newBlock.add(SymbolBlock.newInstance(manager));
newBlock.add(SymbolBlock.newInstance(model.getNDManager()));
newBlock.add(Linear.builder().setUnits(10).build());
return newBlock;
}
Expand Down
16 changes: 10 additions & 6 deletions model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
*/
package ai.djl.basicmodelzoo;

import ai.djl.basicmodelzoo.cv.classification.MlpModelLoader;
import ai.djl.basicmodelzoo.cv.classification.ResNetModelLoader;
import ai.djl.basicmodelzoo.cv.object_detection.ssd.SsdModelLoader;
import ai.djl.modality.cv.zoo.ImageClassificationModelLoader;
import ai.djl.modality.cv.zoo.ObjectDetectionModelLoader;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.ModelLoader;
import ai.djl.repository.zoo.ModelZoo;
import java.util.HashSet;
import java.util.Set;
Expand All @@ -25,11 +25,15 @@ public class BasicModelZoo implements ModelZoo {

private static final String REPO_URL = "https://mlrepo.djl.ai/";
private static final Repository REPOSITORY = Repository.newInstance("zoo", REPO_URL);
private static final ModelZoo ZOO = new BasicModelZoo();
public static final String GROUP_ID = "ai.djl.zoo";

public static final ResNetModelLoader RESNET = new ResNetModelLoader(REPOSITORY);
public static final MlpModelLoader MLP = new MlpModelLoader(REPOSITORY);
public static final SsdModelLoader SSD = new SsdModelLoader(REPOSITORY);
public static final ModelLoader RESNET =
new ImageClassificationModelLoader(REPOSITORY, GROUP_ID, "resnet", "0.0.2", ZOO);
public static final ModelLoader MLP =
new ImageClassificationModelLoader(REPOSITORY, GROUP_ID, "mlp", "0.0.3", ZOO);
public static final ModelLoader SSD =
new ObjectDetectionModelLoader(REPOSITORY, GROUP_ID, "ssd", "0.0.2", ZOO);

/** {@inheritDoc} */
@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.basicmodelzoo.basic;

import ai.djl.Model;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;

/** A {@link BlockFactory} class that creates MLP block. */
public class MlpBlockFactory implements BlockFactory {

private static final long serialVersionUID = 1L;

/** {@inheritDoc} */
@Override
@SuppressWarnings("unchecked")
public Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) {
Double width = (Double) arguments.get("width");
if (width == null) {
width = 28d;
}
Double height = (Double) arguments.get("height");
if (height == null) {
height = 28d;
}
int input = width.intValue() * height.intValue();
int output = ((Double) arguments.get("output")).intValue();
int[] hidden =
((List<Double>) arguments.get("hidden"))
.stream()
.mapToInt(Double::intValue)
.toArray();

return new Mlp(input, output, hidden);
}
}

This file was deleted.

Loading

0 comments on commit d4e93e6

Please sign in to comment.