
Commit f36eca7

Update to latest master

Change-Id: I5e46341fc48853f7ae0dfa4deaa1923fa5bb5c6a
zachgk committed Jan 14, 2021
1 parent f70f153 commit f36eca7
Showing 13 changed files with 106 additions and 122 deletions.
17 changes: 7 additions & 10 deletions api/src/main/java/ai/djl/nn/transformer/BertBlock.java
@@ -94,7 +94,7 @@ private BertBlock(Builder builder) {
this.embeddingDropout =
addChildBlock(
"embeddingDropout",
- Dropout.builder().optProbability(builder.hiddenDropoutProbability).build());
+ Dropout.builder().optRate(builder.hiddenDropoutProbability).build());
// the transformer blocks
this.transformerEncoderBlocks = new ArrayList<>(builder.transformerBlockCount);
for (int i = 0; i < builder.transformerBlockCount; ++i) {
@@ -112,10 +112,7 @@ private BertBlock(Builder builder) {
this.pooling =
addChildBlock(
"poolingProjection",
- Linear.builder()
- .setOutChannels(builder.embeddingSize)
- .optBias(true)
- .build());
+ Linear.builder().setUnits(builder.embeddingSize).optBias(true).build());
}

/**
@@ -238,9 +235,9 @@ public NDList forward(
ParameterStore ps, NDArray tokenIds, NDArray typeIds, NDArray masks, boolean training) {
MemoryScope initScope = MemoryScope.from(tokenIds).add(typeIds, masks);
// Create embeddings for inputs
- NDArray embeddedTokens = tokenEmbedding.forward(ps, tokenIds);
- NDArray embeddedTypes = typeEmbedding.forward(ps, typeIds);
- NDArray embeddedPositions = ps.getValue(positionEmebdding, tokenIds.getDevice());
+ NDArray embeddedTokens = tokenEmbedding.forward(ps, tokenIds, training);
+ NDArray embeddedTypes = typeEmbedding.forward(ps, typeIds, training);
+ NDArray embeddedPositions = ps.getValue(positionEmebdding, tokenIds.getDevice(), training);
// Merge them to one embedding by adding them
// (We can just add the position embedding, even though it does not have a batch dimension:
// the tensor is automagically "broadcast" i.e. repeated in the batch dimension. That
@@ -400,8 +397,8 @@ public Builder optHiddenDropoutProbability(float hiddenDropoutProbability) {
}*/

/**
- * Sets the maximum sequence length this model can process. Memory & compute requirements of
- * the attention mechanism is O(n²), so large values can easily exhaust your GPU memory!
+ * Sets the maximum sequence length this model can process. Memory and compute requirements
+ * of the attention mechanism is O(n²), so large values can easily exhaust your GPU memory!
*
* @param maxSequenceLength the maximum sequence length this model can process.
* @return this builder
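The two renames in BertBlock.java recur throughout this commit: Dropout's builder gains optRate in place of optProbability, and Linear's builder gains setUnits in place of setOutChannels. A minimal sketch of the new-style calls, using illustrative values rather than the ones BertBlock.Builder actually supplies:

```java
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;

public class BuilderRenameSketch {
    public static void main(String[] args) {
        float hiddenDropoutProbability = 0.1f; // illustrative, not from the commit
        int embeddingSize = 768;               // illustrative, not from the commit

        // Post-commit API: optRate replaces optProbability.
        Dropout embeddingDropout =
                Dropout.builder().optRate(hiddenDropoutProbability).build();

        // Post-commit API: setUnits replaces setOutChannels.
        Linear poolingProjection =
                Linear.builder().setUnits(embeddingSize).optBias(true).build();

        System.out.println(embeddingDropout.getClass().getSimpleName()
                + " / " + poolingProjection.getClass().getSimpleName());
    }
}
```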
@@ -53,7 +53,7 @@ public BertMaskedLanguageModelBlock(
addChildBlock(
"sequenceProjection",
Linear.builder()
- .setOutChannels(bertBlock.getEmbeddingSize())
+ .setUnits(bertBlock.getEmbeddingSize())
.optBias(true)
.build());
this.sequenceNorm = addChildBlock("sequenceNorm", BatchNorm.builder().optAxis(1).build());
@@ -158,7 +158,7 @@ public NDArray forward(
final NDArray logits = normalizedTokens.dot(embeddingTransposed); // (B * I, D)
// we add an offset for each dictionary entry
final NDArray logitsWithBias =
- logits.add(ps.getValue(dictionaryBias, logits.getDevice())); // (B * I, D)
+ logits.add(ps.getValue(dictionaryBias, logits.getDevice(), training)); // (B * I, D)
// now we apply log Softmax to get proper log probabilities
final NDArray logProbs = logitsWithBias.logSoftmax(1); // (B * I, D)

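The shape comments above trace the de-masking projection: gathered token vectors of shape (B*I, E) are multiplied with the transposed embedding table to give (B*I, D) logits, a per-entry dictionary bias is added, and logSoftmax over axis 1 turns them into log probabilities. A standalone sketch with made-up sizes; none of the values below come from the commit:

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class MaskedLmShapeSketch {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            int maskedTokens = 4;    // B * I, illustrative
            int embeddingSize = 8;   // E, illustrative
            int dictionarySize = 32; // D, illustrative

            NDArray tokens = manager.randomNormal(new Shape(maskedTokens, embeddingSize));
            NDArray embeddingTable = manager.randomNormal(new Shape(dictionarySize, embeddingSize));
            NDArray dictionaryBias = manager.zeros(new Shape(dictionarySize));

            NDArray logits = tokens.dot(embeddingTable.transpose()).add(dictionaryBias); // (B*I, D)
            NDArray logProbs = logits.logSoftmax(1);                                     // (B*I, D)
            System.out.println(logProbs.getShape()); // expect (4, 32)
        }
    }
}
```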
@@ -68,7 +68,7 @@ public NDArray evaluate(NDList labels, NDList predictions) {
/**
* Calculates the percentage of correctly predicted masked tokens.
*
- * @param labels expected tokens & mask
+ * @param labels expected tokens and mask
* @param predictions prediction of a bert model
* @return the percentage of correctly predicted masked tokens
*/
@@ -35,8 +35,7 @@ public BertNextSentenceBlock() {
super(VERSION);
this.binaryClassifier =
addChildBlock(
- "binaryClassifier",
- Linear.builder().setOutChannels(2).optBias(true).build());
+ "binaryClassifier", Linear.builder().setUnits(2).optBias(true).build());
}

@Override
@@ -23,7 +23,7 @@
import ai.djl.util.PairList;
import java.util.Arrays;

- /** Creates a block that performs all bert pretraining tasks (next sentence & masked language). */
+ /** Creates a block that performs all bert pretraining tasks (next sentence and masked language). */
public class BertPretrainingBlock extends AbstractBlock {

private static final byte VERSION = 1;
@@ -104,7 +104,7 @@ public NDList forward(
final NDArray nextSentenceProbabilities = nsBlock.forward(ps, pooledOutput, training);
// de-mask masked tokens
final NDArray embeddingTable =
- bertBlock.getTokenEmbedding().getValue(ps, embeddedSequence.getDevice());
+ bertBlock.getTokenEmbedding().getValue(ps, embeddedSequence.getDevice(), training);
final NDArray logProbs =
mlBlock.forward(ps, embeddedSequence, maskedIndices, embeddingTable, training);

15 changes: 9 additions & 6 deletions api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
@@ -73,13 +73,14 @@ public NDList forward(ParameterStore ps, NDList inputs, boolean training) {
*
* @param parameterStore used to get the current state of the embedding table
* @param input an ndarry of token ids
+ * @param training true for a training forward pass
* @return the embeddings for the given ids
*/
- public NDArray forward(ParameterStore parameterStore, NDArray input) {
+ public NDArray forward(ParameterStore parameterStore, NDArray input, boolean training) {
// on info to the right shapes, see: http://beta.mxnet.io/r/api/mx.symbol.gather_nd.html
NDArray ids = input.flatten().reshape(1, input.getShape().size());
// create the embedding Table
- NDArray embeddingTable = parameterStore.getValue(embedding, ids.getDevice());
+ NDArray embeddingTable = parameterStore.getValue(embedding, ids.getDevice(), training);
// We do not perform a sparse lookup, instead we just project into the table
NDArray result = MissingOps.gatherNd(embeddingTable, ids);
// we want the original shape of the input + the last dimension of the embedding
@@ -94,14 +95,15 @@ public NDArray forward(ParameterStore parameterStore, NDArray input) {
*
* @param parameterStore the parameters store
* @param input the embeddings to create log probabilities for
+ * @param training true for a training forward pass
* @return log probabilities for each embedding
*/
- public NDArray probabilities(ParameterStore parameterStore, NDArray input) {
+ public NDArray probabilities(ParameterStore parameterStore, NDArray input, boolean training) {
// reshape input into a matrix
NDArray asMatrix = input.reshape(-1, embeddingSize);
// get embedding table
NDArray embeddingTableTransposed =
- parameterStore.getValue(embedding, input.getDevice()).transpose();
+ parameterStore.getValue(embedding, input.getDevice(), training).transpose();
embeddingTableTransposed.attach(input.getManager());
// Create raw logits by taking the scalar product of the tokens and the embedding table
NDArray logitsFlat = asMatrix.dot(embeddingTableTransposed);
@@ -121,10 +123,11 @@ public NDArray probabilities(ParameterStore parameterStore, NDArray input) {
*
* @param ps the parameter store
* @param device device to get internal table for
+ * @param training true for a training forward pass
* @return this embedding table as an array on the given device
*/
- public NDArray getValue(ParameterStore ps, Device device) {
- return ps.getValue(embedding, device);
+ public NDArray getValue(ParameterStore ps, Device device, boolean training) {
+ return ps.getValue(embedding, device, training);
}

@Override
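All three IdEmbedding entry points now take a training flag and pass it on to ParameterStore.getValue, matching the three-argument getValue calls elsewhere in this commit. A hedged caller-side sketch; the helper class and its names are illustrative, not part of the commit:

```java
import ai.djl.ndarray.NDArray;
import ai.djl.nn.transformer.IdEmbedding;
import ai.djl.training.ParameterStore;

/** Illustrative helper showing how the new training flag is threaded through. */
final class IdEmbeddingCalls {
    private IdEmbeddingCalls() {}

    /** Looks up embeddings for token ids, forwarding the training flag. */
    static NDArray embed(IdEmbedding embedding, ParameterStore ps, NDArray ids, boolean training) {
        return embedding.forward(ps, ids, training);
    }

    /** Projects vectors back onto the dictionary, forwarding the training flag. */
    static NDArray logProbs(IdEmbedding embedding, ParameterStore ps, NDArray vectors, boolean training) {
        return embedding.probabilities(ps, vectors, training);
    }
}
```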
@@ -51,18 +51,12 @@ public PointwiseFeedForwardBlock(
int count = 0;
for (final int hiddenSize : hiddenSizes) {
addChildBlock(
- "linear_" + count,
- Linear.builder()
- .optBias(true)
- .optFlatten(false)
- .setOutChannels(hiddenSize)
- .build());
+ "linear_" + count, Linear.builder().optBias(true).setUnits(hiddenSize).build());
addChildBlock("activation_" + count, new LambdaBlock(activationFunction));
++count;
}
// add output layer without activation
- addChildBlock(
- "output_layer", Linear.builder().optBias(true).setOutChannels(outputSize).build());
+ addChildBlock("output_layer", Linear.builder().optBias(true).setUnits(outputSize).build());
}

@Override
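Apart from the builder calls, the constructor is unchanged; a hedged usage sketch with made-up layer sizes and ReLU as the activation (TransformerEncoderBlock below supplies its own configuration instead):

```java
import ai.djl.nn.Activation;
import ai.djl.nn.transformer.PointwiseFeedForwardBlock;
import java.util.Collections;

public class FeedForwardSketch {
    public static void main(String[] args) {
        // One hidden projection of 3072 units feeding back into 768 (illustrative sizes).
        PointwiseFeedForwardBlock ffn =
                new PointwiseFeedForwardBlock(
                        Collections.singletonList(3072), // hidden layer sizes
                        768,                             // output size
                        Activation::relu);               // activation after each hidden layer
        System.out.println(ffn.getClass().getSimpleName());
    }
}
```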
@@ -67,7 +67,7 @@ public TransformerEncoderBlock(
.setHeadCount(headCount)
.optAttentionProbsDropoutProb(dropoutProbability)
.build());
- this.selfAttentionDropout = Dropout.builder().optProbability(dropoutProbability).build();
+ this.selfAttentionDropout = Dropout.builder().optRate(dropoutProbability).build();
this.attentionNorm = addChildBlock("attentionNorm", BatchNorm.builder().optAxis(2).build());
this.pointWisefullyConnected =
addChildBlock(
@@ -76,7 +76,7 @@
Collections.singletonList(hiddenSize),
embeddingSize,
activationFunction));
- this.fullyConnectedDropout = Dropout.builder().optProbability(dropoutProbability).build();
+ this.fullyConnectedDropout = Dropout.builder().optRate(dropoutProbability).build();
this.outputNorm = addChildBlock("outputNorm", BatchNorm.builder().optAxis(2).build());
}


This file was deleted.

@@ -10,45 +10,42 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
- package ai.djl.training.optimizer.learningrate;
+ package ai.djl.training.tracker;

- /** Polynomial decay learning rate. */
- @SuppressWarnings("PMD")
- public class PolynomialDecayLearningRateTracker extends LearningRateTracker {
+ /** Polynomial decay {@link Tracker}. */
+ public class PolynomialDecayTracker implements Tracker {

- protected float endLearningRate;
- protected int decaySteps;
- protected float power;
+ private float baseValue;
+ private float endLearningRate;
+ private int decaySteps;
+ private float power;

/**
- * Builds a PolynomialDecayLearningRateTracker.
+ * Builds a PolynomialDecayTracker.
*
* @param builder parameters
*/
- public PolynomialDecayLearningRateTracker(final Builder builder) {
- super(builder);
+ public PolynomialDecayTracker(Builder builder) {
if (Float.isNaN(builder.endLearningRate)) {
throw new IllegalArgumentException("End learning rate is not set.");
}
if (builder.decaySteps <= 0) {
throw new IllegalArgumentException("Decay steps is not set.");
}
+ this.baseValue = builder.baseValue;
this.endLearningRate = builder.endLearningRate;
this.decaySteps = builder.decaySteps;
this.power = builder.power;
}

/** {@inheritDoc} */
@Override
- public float getNewLearningRate(final int numUpdate) {
- if (numUpdate < warmUpSteps) {
- return getWarmUpLearningRate(numUpdate);
- }
- int step = Math.max(0, Math.min(numUpdate - warmUpSteps, decaySteps));
- double decayedLearningRate =
- (baseLearningRate - endLearningRate)
+ public float getNewValue(int numUpdate) {
+ int step = Math.max(0, Math.min(numUpdate, decaySteps));
+ return (float)
+ ((baseValue - endLearningRate)
* Math.pow(1.0 - (double) step / (double) decaySteps, power)
- + endLearningRate;
- return (float) decayedLearningRate;
+ + endLearningRate);
}

/**
@@ -60,12 +57,24 @@ public static Builder builder() {
return new Builder();
}

- /** Builder for PolynomialDecayLearningRateTracker. */
- public static class Builder extends LearningRateTracker.LrBaseBuilder<Builder> {
+ /** Builder for PolynomialDecayTracker. */
+ public static class Builder {

+ private float baseValue;
+ private float endLearningRate = Float.NaN;
+ private int decaySteps = -1;
+ private float power = 1f;

- protected float endLearningRate = Float.NaN;
- protected int decaySteps = -1;
- protected float power = 1f;
+ /**
+ * Sets the initial value after no steps.
+ *
+ * @param baseValue the initial value
+ * @return this {@code Builder}
+ */
+ public Builder setBaseValue(float baseValue) {
+ this.baseValue = baseValue;
+ return this;
+ }

/**
* Sets the learning rate at which to end rate decay.
@@ -75,7 +84,7 @@ public static class Builder extends LearningRateTracker.LrBaseBuilder<Builder> {
*/
public Builder setEndLearningRate(float endLearningRate) {
this.endLearningRate = endLearningRate;
- return self();
+ return this;
}

/**
@@ -86,7 +95,7 @@ public Builder setEndLearningRate(float endLearningRate) {
*/
public Builder setDecaySteps(int decaySteps) {
this.decaySteps = decaySteps;
- return self();
+ return this;
}

/**
@@ -97,21 +106,16 @@ public Builder setDecaySteps(int decaySteps) {
*/
public Builder optPower(float power) {
this.power = power;
- return self();
+ return this;
}

/**
- * Builds a PolynomialDecayLearningRateTracker.
+ * Builds a PolynomialDecayTracker.
*
- * @return a PolynomialDecayLearningRateTracker
+ * @return a PolynomialDecayTracker
*/
- public PolynomialDecayLearningRateTracker build() {
- return new PolynomialDecayLearningRateTracker(this);
- }
-
- @Override
- protected Builder self() {
- return this;
+ public PolynomialDecayTracker build() {
+ return new PolynomialDecayTracker(this);
}
}
}
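The migrated tracker is a plain Tracker with no built-in warm-up: it interpolates from baseValue down to endLearningRate over decaySteps updates as (baseValue - endLearningRate) * (1 - step / decaySteps)^power + endLearningRate, then stays at endLearningRate. A small sketch of building one and sampling a few steps; the values are illustrative:

```java
import ai.djl.training.tracker.PolynomialDecayTracker;

public class PolynomialDecaySketch {
    public static void main(String[] args) {
        PolynomialDecayTracker tracker =
                PolynomialDecayTracker.builder()
                        .setBaseValue(0.01f)        // value at step 0
                        .setEndLearningRate(0.001f) // floor reached after decaySteps
                        .setDecaySteps(1000)
                        .optPower(2f)               // quadratic decay; 1f gives linear decay
                        .build();

        // With these settings: step 0 -> 0.01, step 500 -> 0.00325, step >= 1000 -> 0.001.
        for (int step : new int[] {0, 500, 1000, 2000}) {
            System.out.println(step + " -> " + tracker.getNewValue(step));
        }
    }
}
```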
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/training/tracker/Tracker.java
@@ -43,7 +43,7 @@ static FactorTracker.Builder factor() {
* @return the {@link WarmUpTracker} {@link WarmUpTracker.Builder}
*/
static WarmUpTracker.Builder warmUp() {
- return new WarmUpTracker.Builder();
+ return WarmUpTracker.builder();
}

/**
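The only change here is that Tracker.warmUp now goes through WarmUpTracker's static factory instead of instantiating the nested Builder directly; callers get the same builder either way. A minimal sketch, with setter calls omitted because they are not part of this diff:

```java
import ai.djl.training.tracker.Tracker;
import ai.djl.training.tracker.WarmUpTracker;

public class WarmUpEntryPoints {
    public static void main(String[] args) {
        // Both expressions yield a WarmUpTracker.Builder after this commit.
        WarmUpTracker.Builder viaInterface = Tracker.warmUp();
        WarmUpTracker.Builder viaFactory = WarmUpTracker.builder();
        System.out.println(viaInterface.getClass() == viaFactory.getClass()); // true
    }
}
```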