Commit
fix softmax flag behavior (#639)
* fix softmax flag behavior

Change-Id: I18ee08116a7ca302a0542ef5d361a64c9e5e2227

* fix rnn test

Change-Id: I25f80a4f965e820d7d16aba515928b009d1a8b76

* fix flag

Change-Id: I855f71ec7f5ba30e11b3d3ca11c21937873dff6d
roywei authored Feb 17, 2021
1 parent 7b0f062 commit d879306
Showing 3 changed files with 33 additions and 10 deletions.
@@ -43,7 +43,7 @@ public SoftmaxCrossEntropyLoss() {
      * @param name the name of the loss
      */
     public SoftmaxCrossEntropyLoss(String name) {
-        this(name, 1, -1, true, false);
+        this(name, 1, -1, true, true);
     }

/**
@@ -52,9 +52,10 @@ public SoftmaxCrossEntropyLoss(String name) {
      * @param name the name of the loss
      * @param weight the weight to apply on the loss value, default 1
      * @param classAxis the axis that represents the class probabilities, default -1
-     * @param sparseLabel whether labels are integer array or probabilities, default true
-     * @param fromLogit whether predictions are log probabilities or un-normalized numbers, default
-     *     false
+     * @param sparseLabel whether labels are 1-D integer array or 2-D probabilities of [batch_size,
+     *     n-class], default true
+     * @param fromLogit whether predictions are un-normalized numbers or log probabilities, if true,
+     *     logSoftmax will be applied to input, default true
      */
     public SoftmaxCrossEntropyLoss(
             String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
@@ -69,7 +70,7 @@ public SoftmaxCrossEntropyLoss(
     @Override
     public NDArray evaluate(NDList label, NDList prediction) {
         NDArray pred = prediction.singletonOrThrow();
-        if (!fromLogit) {
+        if (fromLogit) {
             pred = pred.logSoftmax(classAxis);
         }
         NDArray loss;
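As a usage note (not part of the diff above), a minimal sketch of the corrected flag semantics, assuming the standard ai.djl.ndarray and ai.djl.training.loss packages: with the new default fromLogit=true the loss applies logSoftmax to raw logits inside evaluate(), while fromLogit=false tells it the caller already passes log probabilities.

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDList;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.training.loss.Loss;

    public class SoftmaxFlagSketch {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                NDArray logits = manager.create(new float[] {1, 2, 3, 4, 5});
                NDArray label = manager.ones(new Shape(1));

                // New default (fromLogit=true): evaluate() applies logSoftmax to the raw logits.
                NDArray fromRawLogits =
                        Loss.softmaxCrossEntropyLoss()
                                .evaluate(new NDList(label), new NDList(logits));

                // fromLogit=false: the caller passes log probabilities, so no logSoftmax is applied.
                NDArray fromLogProbs =
                        Loss.softmaxCrossEntropyLoss("loss", 1, -1, true, false)
                                .evaluate(new NDList(label), new NDList(logits.logSoftmax(-1)));

                // After this fix both paths should yield the same value (about 3.4519).
                System.out.println(fromRawLogits + " " + fromLogProbs);
            }
        }
    }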
@@ -473,7 +473,7 @@ public void testRNNTanh() throws IOException, MalformedModelException {
             Assertions.assertAlmostEquals(result.size(), 2);
             NDArray lossValue =
                     loss.evaluate(new NDList(labels), new NDList(result.head()));
-            Assertions.assertAlmostEquals(lossValue.getFloat(), -18);
+            Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533);
             testEncode(manager, block);
         }
     }
@@ -521,7 +521,9 @@ public void testRNNRelu() throws IOException, MalformedModelException {
             Assertions.assertAlmostEquals(result.size(), 2);
             NDArray lossValue =
                     loss.evaluate(new NDList(labels), new NDList(result.head()));
-            Assertions.assertAlmostEquals(lossValue.getFloat(), -908);
+            // loss should be the same as testRNNTanh because outputs are equal for each
+            // class
+            Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533);
             testEncode(manager, block);
         }
     }
@@ -571,7 +573,7 @@ public void testLstm() throws IOException, MalformedModelException {
             Assertions.assertAlmostEquals(result.size(), 3);
             NDArray lossValue =
                     loss.evaluate(new NDList(labels), new NDList(result.head()));
-            Assertions.assertAlmostEquals(lossValue.getFloat(), -16.340019);
+            Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533);
             testEncode(manager, block);
         }
     }
@@ -628,7 +630,7 @@ public void testGRU() throws IOException, MalformedModelException {
             Assertions.assertAlmostEquals(result.size(), 1);
             NDArray lossValue =
                     loss.evaluate(new NDList(labels), new NDList(result.head()));
-            Assertions.assertAlmostEquals(lossValue.getFloat(), -8.17537307E-4);
+            Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533);
             testEncode(manager, block);
         }
     }
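The comment added to the RNN tests relies on a simple identity: when every class receives the same raw output, logSoftmax assigns -log(numClasses) to each class, so the picked loss does not depend on the label and the tanh and relu tests assert the same value. A minimal plain-Java check of that identity (the class count and logit value below are made up; drop it into any main method):

    // Hypothetical values; only the identity matters: equal logits -> loss = log(numClasses).
    int numClasses = 4;
    double logit = 7.0;
    double logSumExp = logit + Math.log(numClasses); // log of numClasses equal exponentials
    double lossForAnyLabel = logSumExp - logit;      // = log(numClasses), label-independent
    System.out.println(lossForAnyLabel);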
@@ -46,11 +46,31 @@ public void l2LossTest() {
     @Test
     public void softmaxCrossEntropyTest() {
         try (NDManager manager = NDManager.newBaseManager()) {
+            // test fromLogits=true, sparseLabel=true
             NDArray pred = manager.create(new float[] {1, 2, 3, 4, 5});
             NDArray label = manager.ones(new Shape(1));
             Assertions.assertAlmostEquals(
-                    Loss.softmaxCrossEntropyLoss().evaluate(new NDList(label), new NDList(pred)),
+                    Loss.softmaxCrossEntropyLoss("loss", 1, -1, true, true)
+                            .evaluate(new NDList(label), new NDList(pred)),
                     manager.create(3.45191431f));
+
+            // test fromLogits=false, sparseLabel=true
+            pred =
+                    manager.create(
+                            new float[] {4.0f, 2.0f, 1.0f, 0.0f, 5.0f, 1.0f}, new Shape(2, 3));
+            label = manager.create(new float[] {0, 1}, new Shape(2));
+            NDArray nonSparseLabel =
+                    manager.create(new float[] {1f, 0f, 0f, 0f, 1f, 0f}, new Shape(2, 3));
+            NDArray sparseOutput =
+                    Loss.softmaxCrossEntropyLoss()
+                            .evaluate(new NDList(label), new NDList(pred.logSoftmax(-1)));
+            // test fromLogits=false, sparseLabel=false
+            NDArray nonSparseOutput =
+                    Loss.softmaxCrossEntropyLoss("loss", 1, -1, false, false)
+                            .evaluate(new NDList(nonSparseLabel), new NDList(pred.logSoftmax(-1)));
+
+            Assertions.assertAlmostEquals(sparseOutput, nonSparseOutput);
+            Assertions.assertAlmostEquals(sparseOutput, manager.create(0.09729549f));
         }
     }

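The 3.45191431f expected value in the first assertion can be checked by hand, under the assumption that the sparse-label path with fromLogit=true reduces to -logSoftmax(pred)[label], i.e. log of the sum of exponentials of the logits minus the logit of the labeled class. A quick plain-Java check (independent of DJL; runs in any main method):

    double[] pred = {1, 2, 3, 4, 5};
    double sumExp = 0;
    for (double p : pred) {
        sumExp += Math.exp(p);
    }
    double loss = Math.log(sumExp) - pred[1]; // label picks class 1
    // loss is approximately 3.4519, matching manager.create(3.45191431f)
    System.out.println(loss);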
