Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Block usability #712

Merged
merged 4 commits into from
Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/modality/nlp/Decoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
return block.getOutputShapes(manager, inputShapes);
public Shape[] getOutputShapes(Shape[] inputShapes) {
return block.getOutputShapes(inputShapes);
}

/** {@inheritDoc} */
Expand Down
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/modality/nlp/Encoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
return block.getOutputShapes(manager, inputShapes);
public Shape[] getOutputShapes(Shape[] inputShapes) {
return block.getOutputShapes(inputShapes);
}

/** {@inheritDoc} */
Expand Down
9 changes: 4 additions & 5 deletions api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,18 @@ public NDList forward(
* @param manager the NDManager to initialize the parameters
* @param dataType the datatype of the parameters
* @param inputShapes the shapes of the inputs to the block
* @return the shapes of the outputs of the block
*/
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
encoder.initialize(manager, dataType, inputShapes[0]);
return decoder.initialize(manager, dataType, inputShapes[1]);
decoder.initialize(manager, dataType, inputShapes[1]);
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
return decoder.getOutputShapes(manager, new Shape[] {inputShapes[1]});
public Shape[] getOutputShapes(Shape[] inputShapes) {
return decoder.getOutputShapes(new Shape[] {inputShapes[1]});
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
return trainableWordEmbedding.getOutputShapes(manager, inputShapes);
public Shape[] getOutputShapes(Shape[] inputShapes) {
return trainableWordEmbedding.getOutputShapes(inputShapes);
}
}
145 changes: 51 additions & 94 deletions api/src/main/java/ai/djl/nn/AbstractBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;
import java.util.function.Predicate;

/**
* {@code AbstractBlock} is an abstract implementation of {@link Block}.
Expand All @@ -43,12 +43,11 @@
* <ul>
* <li>Define a version for serializing parameter and metadata and pass it to the parent
* constructor
* <li>Use {@link AbstractBlock#addParameter(Parameter, Shape)} or {@link
* AbstractBlock#addParameter(Parameter, Function)} to add parameters to your block in the
* <li>Use {@link AbstractBlock#addParameter(Parameter)} to add parameters to your block in the
* constructor if necessary.
* <li>Use {@link AbstractBlock#addChildBlock(String, Block)} to add child blocks if necessary.
* <li>Override {@link AbstractBlock#getOutputShapes(NDManager, Shape[])} to determine the shape
* of your custom block's output based on the input it will receive.
* <li>Override {@link Block#getOutputShapes(Shape[])} to determine the shape of your custom
* block's output based on the input it will receive.
* <li>Override {@link AbstractBlock#initializeChildBlocks(NDManager, DataType, Shape...)} if you
* added child blocks to initialize them based on the input shape your block will receive. You
* can skip this if your block does not contain child blocks
Expand All @@ -61,9 +60,9 @@
* </ul>
*
* <p>If you use {@link AbstractBlock#addParameter(Parameter)} to add parameters, you have to take
* care of parameter initialization yourself. In this case, you need to override {@link
* AbstractBlock#getParameterShape(String, Shape[])} to determine the shape of your parameters. If
* you use the other variants of {@code addParameter} this is done for you.
* care of parameter initialization yourself. In this case, you need to call {@code setShape} on
* your parameters if you already know their shapes, or you can override {@code prepare} to set
* the shapes once the input shapes are known.
*/
// Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers
// of this API know the children and parameters are always iterated over in insertion order.
Expand Down Expand Up @@ -99,14 +98,6 @@ public abstract class AbstractBlock implements Block {
*/
protected LinkedHashMap<String, Parameter> parameters = new LinkedHashMap<>();

/**
* Callbacks to determine the shape of a parameter. Values may be null in which case extending
* classes need to override {@link Block#getParameterShape(String, Shape[])} and implement
* parameter shape resolution manually.
*/
protected LinkedHashMap<String, Function<Shape[], Shape>> parameterShapeCallbacks =
new LinkedHashMap<>();

/**
* Builds an empty block with the given version for parameter serialization.
*
Expand Down Expand Up @@ -195,73 +186,20 @@ protected final <B extends Block> B addChildBlock(String name, B block) {
return block;
}

/**
* Adds a parameter to this block. If parameters are added with this method, subclasses need to
* override {@link Block#getParameterShape(String, Shape[])} and return the shapes of parameters
* themselves.
*
* @param parameter the parameter to add, not null
* @param <P> the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign parameters in
* one line
*/
protected final <P extends Parameter> P addParameter(P parameter) {
return addParameter(parameter, (Function<Shape[], Shape>) null);
}

/**
* Adds a parameter to this block. If parameters are added with this method, initialization of
* the parameter works out of the box
*
* @param parameter the parameter to add, not null
* @param shape the shape of the parameter
* @param <P> the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign parameters in
* one line
*/
protected final <P extends Parameter> P addParameter(P parameter, Shape shape) {
return addParameter(parameter, (inputShapes) -> shape);
}

/**
* Adds a parameter to this block. If parameters are added with this method, initialization of
* the parameter works out of the box
*
* @param parameter the parameter to add, not null
* @param shapeCallback the method to call once the input shape of this block is known to
* determine the shape of the given parameter
* @param <P> the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign parameters
* in one line
*/
protected final <P extends Parameter> P addParameter(
P parameter, Function<Shape[], Shape> shapeCallback) {
protected final <P extends Parameter> P addParameter(P parameter) {
parameters.put(parameter.getName(), parameter);
parameterShapeCallbacks.put(parameter.getName(), shapeCallback);
return parameter;
}

/** {@inheritDoc} */
@Override
public Shape getParameterShape(String name, Shape[] inputShapes) {
Function<Shape[], Shape> callback = parameterShapeCallbacks.get(name);
if (callback == null) {
Parameter parameter = parameters.get(name);
if (parameter == null) {
throw new IllegalArgumentException(
"No parameter named " + name + " found in this block.");
} else {
throw new IllegalStateException(
"No shape initializer for parameter "
+ name
+ "found. "
+ "Either pass an initializer for the shape when adding the "
+ "parameter or override getParameterShape in the subclass.");
}
}
return callback.apply(inputShapes);
}

/** {@inheritDoc} */
@Override
public BlockList getChildren() {
Expand All @@ -285,13 +223,9 @@ public PairList<String, Shape> describeInput() {

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer) {
for (Parameter parameter : parameters.values()) {
parameter.setInitializer(initializer, false);
}
for (Block child : children.values()) {
child.setInitializer(initializer);
}
public void setInitializer(Initializer initializer, Parameter.Type params) {
Predicate<Parameter> predicate = parameter -> parameter.getType().equals(params);
setInitializer(initializer, predicate);
}

/** {@inheritDoc} */
Expand All @@ -301,18 +235,50 @@ public void setInitializer(Initializer initializer, String paramName) {
if (parameter == null) {
throw new IllegalArgumentException("Could not find parameter " + paramName);
}
parameter.setInitializer(initializer, true);
parameter.setInitializer(initializer);
}

/** {@inheritDoc} */
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
public void setInitializer(Initializer initializer, Predicate<Parameter> predicate) {
List<Parameter> params = getParameters().values();
for (Parameter param : params) {
if (predicate.test(param)) {
param.setInitializer(initializer);
}
}
}

/** {@inheritDoc} */
@Override
public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
// if parameters are initialized, skip it
if (!isInitialized()) {
// setShape for all params
prepare(inputShapes);
}
for (Parameter parameter : parameters.values()) {
parameter.initialize(manager, dataType, inputShapes);
parameter.initialize(manager, dataType);
}
initializeChildBlocks(manager, dataType, inputShapes);
return getOutputShapes(manager, inputShapes);
}

/**
* Performs any action necessary before initialization. For example, keep the input information
* or verify the layout.
*
* @param inputShapes the expected shapes of the input
*/
protected void beforeInitialize(Shape... inputShapes) {
if (inputNames.isEmpty()) {
// automatically assign input names
inputNames = new ArrayList<>();
for (int i = 0; i < inputShapes.length; ++i) {
inputNames.add("data" + i);
}
}
this.inputShapes = inputShapes;
}

/**
Expand Down Expand Up @@ -355,20 +321,11 @@ public ParameterList getDirectParameters() {
}

/**
* Performs any action necessary before initialization.
* Sets the shape of {@link Parameter}s.
*
* @param inputShapes the expected shapes of the input
* @param inputShapes the shapes of inputs
*/
protected void beforeInitialize(Shape[] inputShapes) {
if (inputNames.isEmpty()) {
// automatically assign input names
inputNames = new ArrayList<>();
for (int i = 0; i < inputShapes.length; ++i) {
inputNames.add("data" + i);
}
}
this.inputShapes = inputShapes;
}
protected void prepare(Shape[] inputShapes) {}

/** {@inheritDoc} */
@Override
Expand Down Expand Up @@ -494,7 +451,7 @@ public String toString() {
appendShape(sb, inputShapeDescription.values().toArray(new Shape[0]));
sb.append(" -> ");
Shape[] outputShapes =
getOutputShapes(null, inputShapeDescription.values().toArray(new Shape[0]));
getOutputShapes(inputShapeDescription.values().toArray(new Shape[0]));
appendShape(sb, outputShapes);
} else {
sb.append("Uninitialized");
Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.nn;

import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

/** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */
Expand All @@ -29,7 +28,7 @@ public AbstractSymbolBlock(byte version) {

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("not implement!");
}
}
31 changes: 14 additions & 17 deletions api/src/main/java/ai/djl/nn/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.function.Predicate;

/**
* A {@code Block} is a composable function that forms a neural network.
Expand Down Expand Up @@ -158,11 +159,12 @@ default NDList forward(
}

/**
* Sets an {@link Initializer} to the block.
* Sets an {@link Initializer} on all parameters in the block that match the given parameter type.
*
* @param initializer the initializer to set
* @param type the {@link Parameter.Type} of the parameters whose initializer should be set
*/
void setInitializer(Initializer initializer);
void setInitializer(Initializer initializer, Parameter.Type type);

/**
* Sets an {@link Initializer} to the specified direct parameter of the block, overriding the
Expand All @@ -173,15 +175,22 @@ default NDList forward(
*/
void setInitializer(Initializer initializer, String paramName);

/**
* Sets an {@link Initializer} on all parameters in the block that match the given predicate.
*
* @param initializer the initializer to be set
* @param predicate predicate function to indicate parameters you want to set
*/
void setInitializer(Initializer initializer, Predicate<Parameter> predicate);

/**
* Initializes the parameters of the block. This method must be called before calling `forward`.
*
* @param manager the NDManager to initialize the parameters
* @param dataType the datatype of the parameters
* @param inputShapes the shapes of the inputs to the block
* @return the shapes of the outputs of the block
*/
Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes);
void initialize(NDManager manager, DataType dataType, Shape... inputShapes);

/**
* Returns whether the block is initialized.
Expand Down Expand Up @@ -232,25 +241,13 @@ default NDList forward(
*/
ParameterList getParameters();

/**
* Returns the shape of the specified direct parameter of this block given the shapes of the
* input to the block.
*
* @param name the name of the parameter
* @param inputShapes the shapes of the input to the block
* @return the shape of the parameter specified
* @throws IllegalArgumentException if the parameter name specified is invalid
*/
Shape getParameterShape(String name, Shape[] inputShapes);

/**
* Returns the expected output shapes of the block for the specified input shapes.
*
* @param manager an NDManager
* @param inputShapes the shapes of the inputs
* @return the expected output shapes of the block
*/
Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes);
Shape[] getOutputShapes(Shape[] inputShapes);

/**
* Writes the parameters of the block to the given outputStream.
Expand Down
6 changes: 3 additions & 3 deletions api/src/main/java/ai/djl/nn/LambdaBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ protected NDList forwardInternal(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
try (NDManager subManager = manager.newSubManager()) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
try (NDManager manager = NDManager.newBaseManager()) {
NDList input = new NDList(inputShapes.length);
for (Shape shape : inputShapes) {
input.add(subManager.zeros(shape));
input.add(manager.zeros(shape));
}
NDList output = lambda.apply(input);
Shape[] outputShapes = new Shape[output.size()];
Expand Down
Loading