diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 1b80170d908..fa0ee3b9b09 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -15,6 +15,7 @@ import ai.djl.Device; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.util.Pair; import ai.djl.util.PairList; import java.nio.Buffer; import java.nio.file.Path; @@ -34,12 +35,14 @@ public abstract class BaseNDManager implements NDManager { protected String name; protected Device device; protected ConcurrentHashMap resources; + protected ConcurrentHashMap> tempResources; protected AtomicBoolean closed = new AtomicBoolean(false); protected BaseNDManager(NDManager parent, Device device) { this.parent = parent; this.device = Device.defaultIfNull(device, getEngine()); resources = new ConcurrentHashMap<>(); + tempResources = new ConcurrentHashMap<>(); uid = UUID.randomUUID().toString(); } @@ -197,7 +200,7 @@ public String toString() { /** {@inheritDoc} */ @Override - public synchronized void attach(String resourceId, AutoCloseable resource) { + public synchronized void attachInternal(String resourceId, AutoCloseable resource) { if (closed.get()) { throw new IllegalStateException("NDManager has been closed already."); } @@ -206,7 +209,17 @@ public synchronized void attach(String resourceId, AutoCloseable resource) { /** {@inheritDoc} */ @Override - public synchronized void detach(String resourceId) { + public void tempAttachInternal( + NDManager originalManager, String resourceId, NDResource resource) { + if (closed.get()) { + throw new IllegalStateException("NDManager has been closed already."); + } + tempResources.put(resourceId, new Pair<>(resource, originalManager)); + } + + /** {@inheritDoc} */ + @Override + public synchronized void detachInternal(String resourceId) { if (closed.get()) { // This may happen in the middle of 
BaseNDManager.close() return; @@ -238,7 +251,14 @@ public synchronized void close() { logger.error("Resource close failed.", e); } } - parent.detach(uid); + for (Pair resource : tempResources.values()) { + try { + resource.getKey().attach(resource.getValue()); + } catch (Exception e) { + logger.error("Temporary resource return failed.", e); + } + } + parent.detachInternal(uid); resources.clear(); } } diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 758c052537f..4f73cdb4dee 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -40,7 +40,7 @@ * href="https://github.com/awslabs/djl/blob/master/docs/development/memory_management.md">NDArray * Memory Management Guide */ -public interface NDArray extends AutoCloseable { +public interface NDArray extends NDResource { /** * Decodes {@code NDArray} from bytes. @@ -53,13 +53,6 @@ static NDArray decode(NDManager manager, byte[] byteArray) { return manager.decode(byteArray); } - /** - * Returns the {@link NDManager} used to create this {@code NDArray}. - * - * @return the {@link NDManager} used to create this {@code NDArray} - */ - NDManager getManager(); - /** * Returns the name of this {@code NDArray}. * @@ -146,27 +139,6 @@ default byte[] encode() { return NDSerializer.encode(this); } - /** - * Attaches this {@code NDArray} to the specified {@link NDManager}. - * - *

Attached resource will be closed when the {@link NDManager} is closed. - * - * @param manager the {@link NDManager} to be attached - * @return the original {@link NDManager} - */ - NDManager attach(NDManager manager); - - /** - * Detaches the {@code NDArray} from current {@link NDManager}'s lifecycle. - * - *

The {@code NDArray} becomes un-managed, it is the user's responsibility to close the - * {@code NDArray}. Failure to close the resource might cause your machine to run out of native - * memory. - * - * @see NDManager - */ - void detach(); - /** * Moves this {@code NDArray} to a different {@link Device}. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 8afb9ca3ec0..7a435d12e64 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -83,7 +83,7 @@ default SparseFormat getSparseFormat() { /** {@inheritDoc} */ @Override - default NDManager attach(NDManager manager) { + default void attach(NDManager manager) { throw new UnsupportedOperationException(UNSUPPORTED_MSG); } diff --git a/api/src/main/java/ai/djl/ndarray/NDList.java b/api/src/main/java/ai/djl/ndarray/NDList.java index 88f727cf807..85a460f1cf9 100644 --- a/api/src/main/java/ai/djl/ndarray/NDList.java +++ b/api/src/main/java/ai/djl/ndarray/NDList.java @@ -22,9 +22,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; /** * An {@code NDList} represents a sequence of {@link NDArray}s with names. @@ -34,7 +31,7 @@ * * @see NDArray */ -public class NDList extends ArrayList implements AutoCloseable { +public class NDList extends ArrayList implements NDResource { private static final long serialVersionUID = 1L; @@ -200,36 +197,28 @@ public NDList toDevice(Device device, boolean copy) { return newNDList; } - /** - * Attaches each ndarray in this list to the specified manager. 
- * - * @param manager the manager to attach the lists to - * @return a list of {@code NDManager} with which original NDArray are attached - * @see NDArray#attach(NDManager) - */ - public List attach(NDManager manager) { - return stream().map(array -> array.attach(manager)).collect(Collectors.toList()); + /** {@inheritDoc} */ + @Override + public NDManager getManager() { + return head().getManager(); } - /** - * Attaches each ndarray in this list to the specified manager. - * - * @param managers the list of managers to attach - * @return a list of {@code NDManager} with which original NDArray are attached - */ - public List attach(List managers) { - return IntStream.range(0, size()) - .mapToObj(i -> get(i).attach(managers.get(i))) - .collect(Collectors.toList()); + /** {@inheritDoc} */ + @Override + public void attach(NDManager manager) { + stream().forEach(array -> array.attach(manager)); } - /** - * Detaches each ndarray in this list from their current managers. - * - * @see NDArray#detach() - */ + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { + stream().forEach(array -> array.tempAttach(manager)); + } + + /** {@inheritDoc} */ + @Override public void detach() { - forEach(NDArray::detach); + stream().forEach(NDResource::detach); } /** diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index cfcf1609b1c..48771bba805 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -133,6 +133,16 @@ static NDManager newBaseManager(Device device, String engineName) { return Engine.getEngine(engineName).newBaseManager(device); } + /** + * Creates a new manager based on the given resource. 
+ * + * @param resource the resource to use + * @return a new memory scope containing the array + */ + static NDManager from(NDResource resource) { + return resource.getManager().newSubManager(); + } + /** * Allocates a new engine specific direct byte buffer. * @@ -1274,14 +1284,34 @@ default NDArray randomNormal( Device getDevice(); /** - * Attaches a {@link NDArray} or {@code NDManager} to this {@code NDManager}. + * Attaches a resource to this {@code NDManager}. + + *

The attached resource will be closed when this {@code NDManager} is closed. + * + *

This attachment is internal. Many resources will internally track which manager they are + * attached to. In that case, you should call {@link NDResource#attach(NDManager)} instead and + * that should then call attachInternal. + * + * @param resourceId the unique resourceId + * @param resource the {@link AutoCloseable} resource to be attached + */ + void attachInternal(String resourceId, AutoCloseable resource); + + /** + * Temporarily attaches a resource to this {@code NDManager} to be returned when this is closed. + * + *

The attached resource will be returned to its original manager when this {@code + * NDManager} is closed. - *

Attached resource will be closed when this {@code NDManager} is closed. + *

This attachment is internal. Many resources will internally track which manager they are + * attached to. In that case, you should call {@link NDResource#attach(NDManager)} instead and + * that should then call tempAttachInternal. * + * @param originalManager the original manager to return the resource to * @param resourceId the unique resourceId * @param resource the {@link AutoCloseable} resource to be attached */ - void attach(String resourceId, AutoCloseable resource); + void tempAttachInternal(NDManager originalManager, String resourceId, NDResource resource); /** * Detaches a {@link NDArray} from this {@code NDManager}'s lifecycle. @@ -1290,9 +1320,49 @@ default NDArray randomNormal( * resource. Failed to close the resource has to wait on GC to be freed, and might cause out of * native memory. * + *

This detach is internal. Many resources will internally track which manager they are + * attached to. In that case, you should call {@link NDResource#detach()} instead and that + * should then call detachInternal. + * * @param resourceId the resourceId to be removed from this {@code NDManager}'s lifecycle */ - void detach(String resourceId); + void detachInternal(String resourceId); + + /** + * Returns a value outside of this manager by attaching to this manager's parent. + * + * @param resource the resource to return + * @param the type of the resource + * @return the passed in resource, after attaching to a new manager + */ + default T ret(T resource) { + resource.attach(getParentManager()); + return resource; + } + + /** + * Attaches all resources to this manager. + * + * @param resources the resources to attach + * @see NDResource#attach(NDManager) + */ + default void attachAll(NDResource... resources) { + for (NDResource resource : resources) { + resource.attach(this); + } + } + + /** + * Temporarily attaches all resources to this manager. + * + * @param resources the resources to attach + * @see NDResource#tempAttach(NDManager) + */ + default void tempAttachAll(NDResource... resources) { + for (NDResource resource : resources) { + resource.tempAttach(this); + } + } /** * An engine specific generic invocation to native operation. diff --git a/api/src/main/java/ai/djl/ndarray/NDResource.java b/api/src/main/java/ai/djl/ndarray/NDResource.java new file mode 100644 index 00000000000..8033d022608 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/NDResource.java @@ -0,0 +1,57 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray; + +/** An object which is managed by an {@link NDManager} and tracks the manager it is attached to. */ +public interface NDResource extends AutoCloseable { + + /** + * Returns the {@link NDManager} that manages this. + * + * @return the {@link NDManager} that manages this. + */ + NDManager getManager(); + + /** + * Attaches this {@link NDResource} to the specified {@link NDManager}. + * + *

Attached resource will be closed when the {@link NDManager} is closed. + * + * @param manager the {@link NDManager} to be attached to + */ + void attach(NDManager manager); + + /** + * Temporarily attaches this {@link NDResource} to the specified {@link NDManager}. + * + *

Attached resource will be returned to the original manager when the {@link NDManager} is + * closed. + * + * @param manager the {@link NDManager} to be attached to + */ + void tempAttach(NDManager manager); + + /** + * Detaches the {@link NDResource} from current {@link NDManager}'s lifecycle. + * + *

This becomes un-managed and it is the user's responsibility to close this. Failure to + * close the resource might cause your machine to run out of native memory. + * + * @see NDManager + */ + void detach(); + + /** {@inheritDoc} */ + @Override + void close(); +} diff --git a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java index b8ca2caf120..868f6ace5fc 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java @@ -213,7 +213,8 @@ protected NDList forwardInternal( NDArray typeIds = inputs.get(1); // Third are the masks for the input NDArray masks = inputs.get(2); - MemoryScope initScope = MemoryScope.from(tokenIds).add(typeIds, masks); + NDManager initScope = NDManager.from(tokenIds); + initScope.tempAttachAll(inputs); // Create embeddings for inputs NDArray embeddedTokens = tokenEmbedding.forward(ps, new NDList(tokenIds), training).singletonOrThrow(); @@ -241,16 +242,15 @@ protected NDList forwardInternal( .mul(-100000f); // turn 1s (original 0s) into -100000 // Run through all transformer blocks NDList lastOutput = dropoutEmbedding; - initScope - .remove(tokenIds, typeIds, masks) - .waitToRead(dropoutEmbedding) - .waitToRead(offsetMask) - .close(); + initScope.ret(lastOutput); + initScope.ret(offsetMask); + initScope.close(); for (final TransformerEncoderBlock block : transformerEncoderBlocks) { NDList input = new NDList(lastOutput.head(), offsetMask); - MemoryScope innerScope = MemoryScope.from(input); - lastOutput = block.forward(ps, input, training); - innerScope.remove(offsetMask).waitToRead(lastOutput).close(); + try (NDManager innerScope = NDManager.from(input)) { + innerScope.tempAttachAll(input); + lastOutput = innerScope.ret(block.forward(ps, input, training)); + } } // We also return the pooled output - this is an additional fully connected layer // only applied to the first token, assumed to be the CLS token to be used 
for training diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java index 7fc336434b6..4c52e1e5f0a 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java @@ -114,30 +114,32 @@ protected NDList forwardInternal( NDArray sequenceOutput = inputs.get(0); // (B, S, E) NDArray maskedIndices = inputs.get(1); // (B, I) NDArray embeddingTable = inputs.get(2); // (D, E) - MemoryScope scope = MemoryScope.from(sequenceOutput).add(maskedIndices); - NDArray gatheredTokens = gatherFromIndices(sequenceOutput, maskedIndices); // (B * I, E) - NDArray projectedTokens = - hiddenActivation.apply( - sequenceProjection - .forward(ps, new NDList(gatheredTokens), training) - .head()); // (B * I, E) - NDArray normalizedTokens = - sequenceNorm - .forward(ps, new NDList(projectedTokens), training) - .head(); // (B * I, E) - // raw logits for each position to correspond to an entry in the embedding table - NDArray embeddingTransposed = embeddingTable.transpose(); - embeddingTransposed.attach(gatheredTokens.getManager()); - NDArray logits = normalizedTokens.dot(embeddingTransposed); // (B * I, D) - // we add an offset for each dictionary entry - NDArray logitsWithBias = - logits.add(ps.getValue(dictionaryBias, logits.getDevice(), training)); // (B * I, D) - // now we apply log Softmax to get proper log probabilities - NDArray logProbs = logitsWithBias.logSoftmax(1); // (B * I, D) + try (NDManager scope = NDManager.from(sequenceOutput)) { + scope.tempAttachAll(sequenceOutput, maskedIndices); + NDArray gatheredTokens = gatherFromIndices(sequenceOutput, maskedIndices); // (B * I, E) + NDArray projectedTokens = + hiddenActivation.apply( + sequenceProjection + .forward(ps, new NDList(gatheredTokens), training) + .head()); // (B * I, E) + NDArray normalizedTokens = + sequenceNorm + 
.forward(ps, new NDList(projectedTokens), training) + .head(); // (B * I, E) + // raw logits for each position to correspond to an entry in the embedding table + NDArray embeddingTransposed = embeddingTable.transpose(); + embeddingTransposed.attach(gatheredTokens.getManager()); + NDArray logits = normalizedTokens.dot(embeddingTransposed); // (B * I, D) + // we add an offset for each dictionary entry + NDArray logitsWithBias = + logits.add( + ps.getValue( + dictionaryBias, logits.getDevice(), training)); // (B * I, D) + // now we apply log Softmax to get proper log probabilities + NDArray logProbs = logitsWithBias.logSoftmax(1); // (B * I, D) - scope.remove(sequenceOutput, maskedIndices).waitToRead(logProbs).close(); - - return new NDList(logProbs); + return scope.ret(new NDList(logProbs)); + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelLoss.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelLoss.java index 203fa1cf857..dc640e3d171 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelLoss.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelLoss.java @@ -14,6 +14,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.training.loss.Loss; @@ -40,29 +41,30 @@ public BertMaskedLanguageModelLoss(int labelIdx, int maskIdx, int logProbsIdx) { @Override public NDArray evaluate(NDList labels, NDList predictions) { - MemoryScope scope = MemoryScope.from(labels).add(predictions); + try (NDManager scope = NDManager.from(labels)) { + scope.tempAttachAll(labels, predictions); - NDArray logProbs = predictions.get(logProbsIdx); // (B * I, D) - int dictionarySize = (int) logProbs.getShape().get(1); - NDArray targetIds = labels.get(labelIdx).flatten(); // (B * I) - NDArray mask = labels.get(maskIdx).flatten().toType(DataType.FLOAT32, false); // (B * I) - NDArray 
targetOneHots = targetIds.oneHot(dictionarySize); - // Multiplying log_probs and one_hot_labels leaves the log probabilities of the correct - // entries. - // By summing we get the total predicition quality. We want to minimize the error, - // so we negate the value - as we have logarithms, probability = 1 means log(prob) = 0, - // the less sure we are the smaller the log value. - NDArray perExampleLoss = logProbs.mul(targetOneHots).sum(new int[] {1}).mul(-1); - // Multiplying log_probs and one_hot_labels leaves the log probabilities of the correct - // entries. - // By summing we get the total prediction quality. - NDArray numerator = perExampleLoss.mul(mask).sum(); - // We normalize the loss by the actual number of predictions we had to make - NDArray denominator = mask.sum().add(1e-5f); - NDArray result = numerator.div(denominator); + NDArray logProbs = predictions.get(logProbsIdx); // (B * I, D) + int dictionarySize = (int) logProbs.getShape().get(1); + NDArray targetIds = labels.get(labelIdx).flatten(); // (B * I) + NDArray mask = labels.get(maskIdx).flatten().toType(DataType.FLOAT32, false); // (B * I) + NDArray targetOneHots = targetIds.oneHot(dictionarySize); + // Multiplying log_probs and one_hot_labels leaves the log probabilities of the correct + // entries. + // By summing we get the total predicition quality. We want to minimize the error, + // so we negate the value - as we have logarithms, probability = 1 means log(prob) = 0, + // the less sure we are the smaller the log value. + NDArray perExampleLoss = logProbs.mul(targetOneHots).sum(new int[] {1}).mul(-1); + // Multiplying log_probs and one_hot_labels leaves the log probabilities of the correct + // entries. + // By summing we get the total prediction quality. 
+ NDArray numerator = perExampleLoss.mul(mask).sum(); + // We normalize the loss by the actual number of predictions we had to make + NDArray denominator = mask.sum().add(1e-5f); + NDArray result = numerator.div(denominator); - scope.remove(labels, predictions).waitToRead(result).close(); - return result; + return scope.ret(result); + } } /** @@ -73,19 +75,19 @@ public NDArray evaluate(NDList labels, NDList predictions) { * @return the percentage of correctly predicted masked tokens */ public NDArray accuracy(NDList labels, NDList predictions) { - MemoryScope scope = MemoryScope.from(labels).add(predictions); + try (NDManager scope = NDManager.from(labels)) { + scope.tempAttachAll(labels, predictions); - NDArray mask = labels.get(maskIdx).flatten(); // (B * I) - NDArray targetIds = labels.get(labelIdx).flatten(); // (B * I) - NDArray logProbs = predictions.get(logProbsIdx); // (B * I, D) - NDArray predictedIs = logProbs.argMax(1).toType(DataType.INT32, false); // (B * I) - NDArray equal = predictedIs.eq(targetIds).mul(mask); - NDArray equalCount = equal.sum().toType(DataType.FLOAT32, false); - NDArray count = mask.sum().toType(DataType.FLOAT32, false); - NDArray result = equalCount.div(count); + NDArray mask = labels.get(maskIdx).flatten(); // (B * I) + NDArray targetIds = labels.get(labelIdx).flatten(); // (B * I) + NDArray logProbs = predictions.get(logProbsIdx); // (B * I, D) + NDArray predictedIs = logProbs.argMax(1).toType(DataType.INT32, false); // (B * I) + NDArray equal = predictedIs.eq(targetIds).mul(mask); + NDArray equalCount = equal.sum().toType(DataType.FLOAT32, false); + NDArray count = mask.sum().toType(DataType.FLOAT32, false); + NDArray result = equalCount.div(count); - scope.remove(labels, predictions).waitToRead(result); - - return result; + return scope.ret(result); + } } } diff --git a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceLoss.java b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceLoss.java index 
0916e096c1c..b11e2828d82 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceLoss.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceLoss.java @@ -14,6 +14,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.training.loss.Loss; @@ -38,20 +39,21 @@ public BertNextSentenceLoss(int labelIdx, int nextSentencePredictionIdx) { @Override public NDArray evaluate(NDList labels, NDList predictions) { - MemoryScope scope = MemoryScope.from(labels).add(predictions); - NDArray label = labels.get(labelIdx).toType(DataType.FLOAT32, false); - // predictions are log(softmax) - NDArray logPredictions = predictions.get(nextSentencePredictionIdx); - NDArray oneHotLabels = label.oneHot(2); - // we use negative log likelihood as loss: log(softmax) turns high confidence into - // negative values near one, low confidence into negative values near -inf, - // negating gives almost 0 for high confidence and near +inf for very low confidence - NDArray logPredictionForLabels = oneHotLabels.mul(logPredictions); - NDArray summedPredictions = logPredictionForLabels.sum(new int[] {1}); - NDArray perExampleLoss = summedPredictions.mul(-1f); - NDArray result = perExampleLoss.mean(); - scope.remove(labels, predictions).waitToRead(result).close(); - return result; + try (NDManager scope = NDManager.from(labels)) { + scope.tempAttachAll(labels, predictions); + NDArray label = labels.get(labelIdx).toType(DataType.FLOAT32, false); + // predictions are log(softmax) + NDArray logPredictions = predictions.get(nextSentencePredictionIdx); + NDArray oneHotLabels = label.oneHot(2); + // we use negative log likelihood as loss: log(softmax) turns high confidence into + // negative values near one, low confidence into negative values near -inf, + // negating gives almost 0 for high confidence and near +inf for very low confidence + NDArray logPredictionForLabels = 
oneHotLabels.mul(logPredictions); + NDArray summedPredictions = logPredictionForLabels.sum(new int[] {1}); + NDArray perExampleLoss = summedPredictions.mul(-1f); + NDArray result = perExampleLoss.mean(); + return scope.ret(result); + } } /** @@ -62,15 +64,16 @@ public NDArray evaluate(NDList labels, NDList predictions) { * @return the fraction of correct predictions. */ public NDArray accuracy(NDList labels, NDList predictions) { - MemoryScope scope = MemoryScope.from(labels).add(predictions); - NDArray label = labels.get(labelIdx); - NDArray predictionLogProbs = predictions.get(nextSentencePredictionIdx); - // predictions are log(softmax) -> highest confidence is highest (negative) value near 0 - NDArray prediction = predictionLogProbs.argMax(1).toType(DataType.INT32, false); - NDArray equalCount = label.eq(prediction).sum().toType(DataType.FLOAT32, false); - NDArray result = equalCount.div(label.getShape().size()); - scope.remove(labels, predictions).waitToRead(result).close(); + try (NDManager scope = NDManager.from(labels)) { + scope.tempAttachAll(labels, predictions); + NDArray label = labels.get(labelIdx); + NDArray predictionLogProbs = predictions.get(nextSentencePredictionIdx); + // predictions are log(softmax) -> highest confidence is highest (negative) value near 0 + NDArray prediction = predictionLogProbs.argMax(1).toType(DataType.INT32, false); + NDArray equalCount = label.eq(prediction).sum().toType(DataType.FLOAT32, false); + NDArray result = equalCount.div(label.getShape().size()); - return result; + return scope.ret(result); + } } } diff --git a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java index d7f44089564..5e60f28c25c 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java @@ -70,30 +70,31 @@ protected NDList forwardInternal( NDArray typeIds = inputs.get(1); NDArray 
sequenceMasks = inputs.get(2); NDArray maskedIndices = inputs.get(3); - MemoryScope scope = MemoryScope.from(tokenIds).add(typeIds, sequenceMasks, maskedIndices); - // run the core bert model - NDList bertResult = - bertBlock.forward(ps, new NDList(tokenIds, typeIds, sequenceMasks), training); - NDArray embeddedSequence = bertResult.get(0); - NDArray pooledOutput = bertResult.get(1); - // apply pooled output to the classifier - NDArray nextSentenceProbabilities = - nsBlock.forward(ps, new NDList(pooledOutput), training).singletonOrThrow(); - // de-mask masked tokens - NDArray embeddingTable = - bertBlock.getTokenEmbedding().getValue(ps, embeddedSequence.getDevice(), training); - NDArray logProbs = - mlBlock.forward( - ps, - new NDList(embeddedSequence, maskedIndices, embeddingTable), - training) - .singletonOrThrow(); + try (NDManager scope = NDManager.from(tokenIds)) { + scope.tempAttachAll(inputs); + // run the core bert model + NDList bertResult = + bertBlock.forward(ps, new NDList(tokenIds, typeIds, sequenceMasks), training); + NDArray embeddedSequence = bertResult.get(0); + NDArray pooledOutput = bertResult.get(1); + // apply pooled output to the classifier + NDArray nextSentenceProbabilities = + nsBlock.forward(ps, new NDList(pooledOutput), training).singletonOrThrow(); + // de-mask masked tokens + NDArray embeddingTable = + bertBlock + .getTokenEmbedding() + .getValue(ps, embeddedSequence.getDevice(), training); + NDArray logProbs = + mlBlock.forward( + ps, + new NDList(embeddedSequence, maskedIndices, embeddingTable), + training) + .singletonOrThrow(); - scope.remove(tokenIds, typeIds, sequenceMasks, maskedIndices) - .waitToRead(nextSentenceProbabilities, logProbs) - .close(); - // return the next sentence & masked language result to apply the loss to - return new NDList(nextSentenceProbabilities, logProbs); + // return the next sentence & masked language result to apply the loss to + return scope.ret(new NDList(nextSentenceProbabilities, logProbs)); + } } 
/** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/transformer/MemoryScope.java b/api/src/main/java/ai/djl/nn/transformer/MemoryScope.java deleted file mode 100644 index e1bf4ca159a..00000000000 --- a/api/src/main/java/ai/djl/nn/transformer/MemoryScope.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.nn.transformer; - -import ai.djl.ndarray.LazyNDArray; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; - -/** - * Helper class for more complicated memory management scenarios. Allows to avoid boilerplate for - * memory handling. Makes sure the sub NDManager used is connected to the correct GPU to avoid - * crashes. - */ -public final class MemoryScope implements AutoCloseable { - - private NDManager parentManager; - private NDManager subManager; - - private MemoryScope(NDManager parentManager, NDManager subManager) { - this.parentManager = parentManager; - this.subManager = subManager; - } - - /** - * Adds all arrays in the given lists to this memory scope. - * - * @param lists the lists whose arrays to add to this scope, may be empty - * @return this scope - */ - public MemoryScope add(NDList... lists) { - for (NDList list : lists) { - list.attach(subManager); - } - return this; - } - - /** - * Adds the given arrays to this scopes sub manager. 
- * - * @param arrays the arrays to add - * @return this scope - */ - public MemoryScope add(NDArray... arrays) { - for (NDArray array : arrays) { - array.attach(subManager); - } - return this; - } - - /** - * Remove the given arrays from this scope and attach them back to this scopes parent NDManager. - * - * @param lists the lists containing the arrays to remove - * @return this scope - */ - public MemoryScope remove(NDList... lists) { - for (NDList list : lists) { - list.attach(parentManager); - } - return this; - } - - /** - * Remove the given arrays from this scope and attach them back to this scopes parent NDManager. - * - * @param arrays arrays to remove - * @return this scope - */ - public MemoryScope remove(NDArray... arrays) { - for (NDArray array : arrays) { - array.attach(parentManager); - } - return this; - } - - /** - * Returns the NDManager used to manage this scopes resources. - * - * @return the NDManager used to manage this scopes resources - */ - public NDManager getScopeManager() { - return subManager; - } - - /** - * Waits for all given arrays to be ready to read, i.e. waits for pending computations that - * write to them, then removes them from this scope. - * - * @param arrays arrays to wait for - * @return this scope - */ - public MemoryScope waitToRead(NDArray... arrays) { - for (NDArray array : arrays) { - if (array instanceof LazyNDArray) { - LazyNDArray lazyNDArray = (LazyNDArray) array; - lazyNDArray.waitToRead(); - } - remove(array); - } - return this; - } - - /** - * Waits for all arrays in all given lists to be ready to be read, i.e. waits for pending - * computations that write to them, then removes them from this scope. - * - * @param lists may be empty - * @return this scope - */ - public MemoryScope waitToRead(NDList... 
lists) { - for (NDList list : lists) { - if (list != null) { - for (NDArray array : list) { - waitToRead(array); - } - } - } - return this; - } - - /** - * Closes this scope by closing the sub manager used to manage it. This causes all arrays still - * attached to this scope to be closed as well. - */ - @Override - public void close() { - subManager.close(); - } - - /** - * Creates a new memory scope for the device of the given array and adds the array. - * - * @param ndArray an array - * @return a new memory scrope containing the array - */ - public static MemoryScope from(final NDArray ndArray) { - return new MemoryScope( - ndArray.getManager(), - ndArray.getManager().newSubManager(ndArray.getDevice())) - .add(ndArray); - } - - /** - * Creates a new memory scope that fits the device of the first array in the given list, adds - * all arrays in the given list. - * - * @param list a list of arrays, must not be empty - * @return a new memory scope - */ - public static MemoryScope from(final NDList list) { - NDArray ndArray = list.head(); - return new MemoryScope( - ndArray.getManager(), - ndArray.getManager().newSubManager(ndArray.getDevice())) - .add(list); - } -} diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java index 8cb8fa653cc..f8c1b8fc119 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java @@ -45,7 +45,7 @@ public class DlrNDArray implements NDArrayAdapter { this.data = data; this.shape = shape; uid = UUID.randomUUID().toString(); - manager.attach(uid, this); + manager.attachInternal(uid, this); } /** {@inheritDoc} */ @@ -94,18 +94,25 @@ public Shape getShape() { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (DlrNDManager) manager; + manager.attachInternal(getUid(), this); + } + 
+ /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (DlrNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = DlrNDManager.getSystemManager(); } diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java index f879c907cc3..ec87ba6675d 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java @@ -54,7 +54,7 @@ public ByteBuffer allocateDirect(int capacity) { @Override public DlrNDManager newSubManager(Device dev) { DlrNDManager manager = new DlrNDManager(this, dev); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -105,11 +105,11 @@ private static final class SystemManager extends DlrNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java index 30460a25592..d13c75e684a 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java @@ -73,7 +73,7 @@ public CachedOp( this.dataIndicesMap = dataIndices.toMap(); // holds all parameter and data NDArray values, final inputs to CachedOp this.manager = manager; - manager.attach(getUid(), this); + 
manager.attachInternal(getUid(), this); } /** @@ -139,7 +139,7 @@ public NDList forward(ParameterStore parameterStore, NDList data, boolean traini public void close() { Pointer pointer = handle.getAndSet(null); if (pointer != null) { - manager.detach(getUid()); + manager.detachInternal(getUid()); JnaUtils.freeCachedOp(pointer); manager = null; } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index e8ecde776a9..e9514b61d71 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -92,7 +92,7 @@ public class MxNDArray extends NativeResource implements LazyNDArray { super(handle); this.manager = manager; mxNDArrayEx = new MxNDArrayEx(this); - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); } /** @@ -163,18 +163,25 @@ public SparseFormat getSparseFormat() { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (MxNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { NDManager original = this.manager; detach(); this.manager = (MxNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = MxNDManager.getSystemManager(); } @@ -1602,7 +1609,7 @@ public void close() { if (pointer != null) { JnaUtils.waitToRead(pointer); JnaUtils.freeNdArray(pointer); - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = null; } } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java 
b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java index 09f154e964d..5e9c32cf047 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java @@ -377,8 +377,7 @@ public void adadeltaUpdate( // create a baseManager to close all intermediate NDArrays try (NDManager subManager = NDManager.newBaseManager()) { - List inputManagers = inputs.attach(subManager); - List weightManagers = weights.attach(subManager); + subManager.tempAttachAll(inputs, weights); // Preprocess Gradient grad.muli(rescaleGrad); @@ -394,10 +393,6 @@ public void adadeltaUpdate( // Update weight weight.subi(g); - - // attach back to their previous managers - inputs.attach(inputManagers); - weights.attach(weightManagers); } } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java index 1effa33d3a5..b76f0492c52 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java @@ -272,7 +272,7 @@ public NDArray randomMultinomial(int n, NDArray pValues) { @Override public MxNDManager newSubManager(Device dev) { MxNDManager manager = new MxNDManager(this, dev, version); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -386,11 +386,11 @@ private static final class SystemManager extends MxNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java 
b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java index 9bd825b3c7d..8fcc92e5061 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java @@ -52,7 +52,7 @@ public class Symbol extends NativeResource { Symbol(MxNDManager manager, Pointer pointer) { super(pointer); this.manager = manager; - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); // argParams = JnaUtils.listSymbolArguments(getHandle()); // auxParams = JnaUtils.listSymbolAuxiliaryStates(getHandle()); } @@ -311,7 +311,7 @@ public String toString() { public void close() { Pointer pointer = handle.getAndSet(null); if (pointer != null) { - manager.detach(getUid()); + manager.detachInternal(getUid()); JnaUtils.freeSymbol(pointer); manager = null; } diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java index 42556a788ca..be01b25b9d4 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java @@ -44,7 +44,7 @@ public class OrtNDArray implements NDArrayAdapter { this.manager = manager; this.tensor = tensor; uid = UUID.randomUUID().toString(); - manager.attach(uid, this); + manager.attachInternal(uid, this); } OnnxTensor getTensor() { @@ -102,18 +102,25 @@ public Shape getShape() { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (OrtNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (OrtNDManager) manager; - manager.attach(getUid(), this); - return original; + 
manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = OrtNDManager.getSystemManager(); } diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index ee427153209..c2d2b0671e4 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -109,7 +109,7 @@ public NDArray ones(Shape shape, DataType dataType) { @Override public OrtNDManager newSubManager(Device device) { OrtNDManager manager = new OrtNDManager(this, device, env); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -128,11 +128,11 @@ private static final class SystemManager extends OrtNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java index 121213161ee..9757efb99eb 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java @@ -39,7 +39,7 @@ public class PpNDArray extends NativeResource implements NDArrayAdapter { public PpNDArray(PpNDManager manager, long handle) { super(handle); this.manager = manager; - manager.attach(getUid(), this); + 
manager.attachInternal(getUid(), this); } /** @@ -56,7 +56,7 @@ public PpNDArray(PpNDManager manager, long pointer, Shape shape, DataType dataTy this.manager = manager; this.shape = shape; this.dataType = dataType; - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); } /** {@inheritDoc} */ @@ -103,18 +103,25 @@ public Shape getShape() { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (PpNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (PpNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = PpNDManager.getSystemManager(); } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java index b277456af7b..f0aca98460d 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java @@ -51,7 +51,7 @@ public PpNDManager newSubManager() { @Override public PpNDManager newSubManager(Device device) { PpNDManager manager = new PpNDManager(this, device); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -156,11 +156,11 @@ private static final class SystemManager extends PpNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - 
public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index dedf1988561..fea19f8b576 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -69,7 +69,7 @@ public PtNDArray(PtNDManager manager, long handle) { super(handle); this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); } /** @@ -84,7 +84,7 @@ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { super(handle); this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); dataRef = data; } @@ -279,18 +279,25 @@ public void copyTo(NDArray array) { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (PtNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (PtNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = PtNDManager.getSystemManager(); } @@ -1436,7 +1443,7 @@ public void close() { Long pointer = handle.getAndSet(null); if (pointer != null) { JniUtils.deleteNDArray(pointer); - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = null; dataRef = null; } diff --git 
a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index ac51d54180c..7d2fa9a9123 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -181,7 +181,7 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy @Override public PtNDManager newSubManager(Device device) { PtNDManager manager = new PtNDManager(this, device); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -200,11 +200,11 @@ private static final class SystemManager extends PtNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 964649cd5a1..9eebb11846b 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -65,7 +65,7 @@ public PtSymbolBlock(PtNDManager manager, long handle) { this.handle = new AtomicReference<>(handle); this.manager = manager; uid = String.valueOf(handle); - manager.attach(uid, this); + manager.attachInternal(uid, this); // training mode is on by default isTrain = true; first = true; @@ -90,7 +90,7 @@ public void close() { Long pointer = handle.getAndSet(null); if (pointer != null) { JniUtils.deleteModule(pointer); - manager.detach(uid); + manager.detachInternal(uid); manager = null; } } @@ -177,7 +177,7 @@ public void 
loadParameters(NDManager manager, DataInputStream is) long rawHandle = JniUtils.loadModuleHandle(is, manager.getDevice(), true); this.handle = new AtomicReference<>(rawHandle); uid = String.valueOf(rawHandle); - manager.attach(uid, this); + manager.attachInternal(uid, this); } /** diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 478257d23b7..e9931638df6 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -73,7 +73,7 @@ public class TfNDArray implements NDArray { this.manager = (TfNDManager) manager; this.tf = this.manager.getTf(); uid = UUID.randomUUID().toString(); - manager.attach(uid, this); + manager.attachInternal(uid, this); this.operand = this.manager .getEagerSession() @@ -279,18 +279,25 @@ public void set(Buffer data) { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (TfNDManager) manager; + manager.attachInternal(uid, this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (TfNDManager) manager; - manager.attach(uid, this); - return original; + manager.tempAttachInternal(original, uid, this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = TfNDManager.getSystemManager(); } diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java index efa991c8008..4d22ad6f761 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java +++ 
b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java @@ -416,7 +416,7 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy @Override public TfNDManager newSubManager(Device device) { TfNDManager manager = new TfNDManager(this, device); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); // initialize eager sessions and operators only for sub managers manager.getEagerSession(); manager.getTf(); @@ -440,11 +440,11 @@ private static final class SystemManager extends TfNDManager { /** {@inheritDoc} */ @Override - public void attach(String resrouceId, AutoCloseable resource) {} + public void attachInternal(String resrouceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java index 9b59b4893ab..86a87cc8fac 100644 --- a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java +++ b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java @@ -39,7 +39,7 @@ public class TfLiteNDArray implements NDArrayAdapter { TfLiteNDArray(TfLiteNDManager manager, Tensor tensor) { this.manager = manager; uid = UUID.randomUUID().toString(); - manager.attach(uid, this); + manager.attachInternal(uid, this); this.tensor = tensor; shape = new Shape(Arrays.stream(tensor.shape()).mapToLong(i -> i).toArray()); dataType = TfLiteDataType.fromTf(tensor.dataType()); @@ -103,18 +103,25 @@ public SparseFormat getSparseFormat() { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (TfLiteNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public 
void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (TfLiteNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = TfLiteNDManager.getSystemManager(); } diff --git a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java index 8d55da2e7a1..64da62767f7 100644 --- a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java +++ b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java @@ -132,7 +132,7 @@ public NDArray ones(Shape shape, DataType dataType) { @Override public TfLiteNDManager newSubManager(Device device) { TfLiteNDManager manager = new TfLiteNDManager(this, device); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -151,11 +151,11 @@ private static final class SystemManager extends TfLiteNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override