Remove memory scope and improve memory management
The MemoryScope revealed a number of shortcomings in DJL's memory management.
This PR deletes the MemoryScope and fixes many of those shortcomings.

First, NDManager.{attach, detach} were renamed to {attach, detach}Internal. This
differentiates them from the attach and detach methods on the resources
themselves, which are the ones callers are intended to use.

There are two new concepts in memory management. The first is an NDResource
interface, which combines the managed-memory concepts that were used in NDArray
and NDList. It declares getManager, attach, and detach, and it could be used in
more classes in the future.
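
A minimal sketch of that idea, not DJL's actual code: `ToyManager`,
`ToyResource`, `ToyArray`, and `ToyList` are hypothetical stand-ins for
NDManager, NDResource, NDArray, and NDList. It only illustrates how a single
interface can cover both one array and a list that delegates to its elements.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.Set;

/** Hypothetical stand-in for NDManager: closes everything it owns. */
class ToyManager implements AutoCloseable {
    final Set<ToyResource> owned = new LinkedHashSet<>();

    @Override
    public void close() {
        // Iterate over a copy, because closing a resource removes it from `owned`.
        for (ToyResource r : new ArrayList<>(owned)) {
            r.close();
        }
    }
}

/** Hypothetical stand-in for NDResource: a closeable thing that tracks its manager. */
interface ToyResource extends AutoCloseable {
    ToyManager getManager();

    void attach(ToyManager manager);

    void detach();

    @Override
    void close();
}

/** Single resource, playing the role of NDArray. */
class ToyArray implements ToyResource {
    private ToyManager manager;
    boolean closed;

    ToyArray(ToyManager manager) {
        attach(manager);
    }

    @Override
    public ToyManager getManager() {
        return manager;
    }

    @Override
    public void attach(ToyManager newManager) {
        detach(); // leave the previous manager, if any
        manager = newManager;
        newManager.owned.add(this);
    }

    @Override
    public void detach() {
        if (manager != null) {
            manager.owned.remove(this);
            manager = null;
        }
    }

    @Override
    public void close() {
        detach();
        closed = true;
    }
}

/** Collection resource, playing the role of NDList: every call is forwarded to the elements. */
class ToyList extends ArrayList<ToyArray> implements ToyResource {
    @Override
    public ToyManager getManager() {
        return get(0).getManager(); // convention in this sketch: the first element's manager
    }

    @Override
    public void attach(ToyManager manager) {
        forEach(a -> a.attach(manager));
    }

    @Override
    public void detach() {
        forEach(ToyArray::detach);
    }

    @Override
    public void close() {
        forEach(ToyArray::close);
    }
}
```

In this sketch the elements of a `ToyList` may start out in different managers,
so the list has no single "previous manager" to report when it is attached
somewhere else.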

The second is that the NDManager gains a second "management convention". Under
the first convention, normal resources are added to the manager and then closed
when the manager closes. This works when only a small number of things are
created on the NDArray, but not when operations transitively create further
arrays. So, the second convention is a tempResource. Instead of being freed when
the manager is closed, tempResources are returned to their original manager.
This is used to create a temporary scope, do operations within it, and then
return the inputs and return value to the parent while the intermediate work is
cleaned up. This also matches the concepts of ownership and borrowing.
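
The difference between the two conventions can be sketched with a toy model.
This is an illustration under assumptions, not the DJL implementation:
`SketchManager` and `SketchArray` are hypothetical names, and the real classes
track resources by id rather than by object.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/** Hypothetical array that remembers which manager currently holds it. */
class SketchArray {
    SketchManager manager;
    boolean closed;

    SketchArray(SketchManager manager) {
        attach(manager);
    }

    /** First convention: the new manager owns this array and will close it. */
    void attach(SketchManager newManager) {
        detach();
        manager = newManager;
        newManager.resources.add(this);
    }

    /** Second convention: the new manager only borrows this array. */
    void tempAttach(SketchManager newManager) {
        SketchManager original = manager;
        if (original == null) {
            throw new IllegalStateException("a detached array has no manager to return to");
        }
        detach();
        manager = newManager;
        newManager.tempResources.put(this, original);
    }

    void detach() {
        if (manager != null) {
            manager.resources.remove(this);
            manager.tempResources.remove(this);
            manager = null;
        }
    }

    void close() {
        detach();
        closed = true;
    }
}

/** Hypothetical manager holding owned resources and borrowed (temp) resources. */
class SketchManager implements AutoCloseable {
    final Set<SketchArray> resources = new LinkedHashSet<>();
    final Map<SketchArray, SketchManager> tempResources = new LinkedHashMap<>();

    @Override
    public void close() {
        // Owned resources are freed. Copies are iterated because close/attach mutate the originals.
        for (SketchArray owned : new ArrayList<>(resources)) {
            owned.close();
        }
        // Borrowed resources are handed back to the manager they came from.
        for (Map.Entry<SketchArray, SketchManager> e :
                new LinkedHashMap<>(tempResources).entrySet()) {
            e.getKey().attach(e.getValue());
        }
    }
}
```

With this model, an input that is temp-attached to a scope survives the scope's
`close()` and ends up back in its original manager, while an array created
directly in the scope is closed.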

Using these, a few additional helper methods were created. There is
`NDManager.from(resource)` to ease creation of managers based on a resource.
There is also `scopeManager.ret(returnValue)` to help with returning values
outside of the scopeManager. Lastly, there is `scopeManager.{temp,}AttachAll`
to attach a number of resources to a manager within a single call.
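
How the helpers combine into a scoped computation can be sketched as follows.
This is a self-contained toy, not DJL code: `ScopeManager`, `Value`, and
`ScopedMath` are hypothetical, and only the method names `from`, `ret`, and
`tempAttachAll` mirror the helpers described above.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/** Hypothetical, much-simplified manager with the three helper methods. */
class ScopeManager implements AutoCloseable {
    final ScopeManager parent;
    final Set<Value> resources = new LinkedHashSet<>();
    final Map<Value, ScopeManager> tempResources = new LinkedHashMap<>();

    ScopeManager(ScopeManager parent) {
        this.parent = parent;
    }

    /** Like from(resource): a sub-manager of the manager that holds the resource. */
    static ScopeManager from(Value resource) {
        return new ScopeManager(resource.manager);
    }

    /** Like ret(value): move the value to the parent so it outlives this scope. */
    Value ret(Value value) {
        value.attach(parent);
        return value;
    }

    /** Like tempAttachAll(...): borrow several values in one call. */
    void tempAttachAll(Value... values) {
        for (Value v : values) {
            v.tempAttach(this);
        }
    }

    @Override
    public void close() {
        for (Value v : new ArrayList<>(resources)) {
            v.close(); // intermediates are freed
        }
        for (Map.Entry<Value, ScopeManager> e : new LinkedHashMap<>(tempResources).entrySet()) {
            e.getKey().attach(e.getValue()); // borrowed inputs go back
        }
    }
}

/** Hypothetical scalar "array"; `live` counts values that have not been closed. */
class Value {
    static int live;
    final double data;
    ScopeManager manager;
    boolean closed;

    Value(ScopeManager manager, double data) {
        this.data = data;
        live++;
        attach(manager);
    }

    void attach(ScopeManager m) {
        detach();
        manager = m;
        m.resources.add(this);
    }

    void tempAttach(ScopeManager m) {
        ScopeManager original = manager;
        detach();
        manager = m;
        m.tempResources.put(this, original);
    }

    void detach() {
        if (manager != null) {
            manager.resources.remove(this);
            manager.tempResources.remove(this);
            manager = null;
        }
    }

    void close() {
        if (!closed) {
            detach();
            closed = true;
            live--;
        }
    }

    // Results are created in the left operand's current manager, so they land in the scope.
    Value mul(Value other) {
        return new Value(manager, data * other.data);
    }

    Value add(Value other) {
        return new Value(manager, data + other.data);
    }
}

class ScopedMath {
    /** Computes a*a + b*b; the two products are freed when the scope closes. */
    static Value sumOfSquares(Value a, Value b) {
        try (ScopeManager scope = ScopeManager.from(a)) {
            scope.tempAttachAll(a, b);
            return scope.ret(a.mul(a).add(b.mul(b)));
        }
    }
}
```

Note that `from(a)` is called before `tempAttachAll`, so the scope's parent is
the manager that originally held `a`. That is also where `ret` sends the result.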

Using these improvements, the new methods were applied to the locations where
MemoryScope was previously used, as well as to an additional case in NDManagerEx.

Also, the old attach methods were changed to return `void`. Because the return
values are no longer used anywhere and are less necessary in the current
scheme, I figured dropping them would simplify things. It also helps for cases
like `NDList.attach`, which does not have a single original NDManager when
attaching.

Change-Id: I91d109cd14d70fa64fd8fffa0b50d88ab053013e
zachgk committed Feb 25, 2021
1 parent b4a93e4 commit 0b48feb
Showing 30 changed files with 410 additions and 416 deletions.
26 changes: 23 additions & 3 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
@@ -15,6 +15,7 @@
import ai.djl.Device;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.nio.Buffer;
import java.nio.file.Path;
@@ -34,12 +35,14 @@ public abstract class BaseNDManager implements NDManager {
protected String name;
protected Device device;
protected ConcurrentHashMap<String, AutoCloseable> resources;
protected ConcurrentHashMap<String, Pair<NDResource, NDManager>> tempResources;
protected AtomicBoolean closed = new AtomicBoolean(false);

protected BaseNDManager(NDManager parent, Device device) {
this.parent = parent;
this.device = Device.defaultIfNull(device, getEngine());
resources = new ConcurrentHashMap<>();
tempResources = new ConcurrentHashMap<>();
uid = UUID.randomUUID().toString();
}

@@ -197,7 +200,7 @@ public String toString() {

/** {@inheritDoc} */
@Override
public synchronized void attach(String resourceId, AutoCloseable resource) {
public synchronized void attachInternal(String resourceId, AutoCloseable resource) {
if (closed.get()) {
throw new IllegalStateException("NDManager has been closed already.");
}
@@ -206,7 +209,17 @@ public synchronized void attach(String resourceId, AutoCloseable resource) {

/** {@inheritDoc} */
@Override
public synchronized void detach(String resourceId) {
public void tempAttachInternal(
NDManager originalManager, String resourceId, NDResource resource) {
if (closed.get()) {
throw new IllegalStateException("NDManager has been closed already.");
}
tempResources.put(resourceId, new Pair<>(resource, originalManager));
}

/** {@inheritDoc} */
@Override
public synchronized void detachInternal(String resourceId) {
if (closed.get()) {
// This may happen in the middle of BaseNDManager.close()
return;
@@ -238,7 +251,14 @@ public synchronized void close() {
logger.error("Resource close failed.", e);
}
}
parent.detach(uid);
for (Pair<NDResource, NDManager> resource : tempResources.values()) {
try {
resource.getKey().attach(resource.getValue());
} catch (Exception e) {
logger.error("Temporary resource return failed.", e);
}
}
parent.detachInternal(uid);
resources.clear();
}
}
30 changes: 1 addition & 29 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
@@ -40,7 +40,7 @@
* href="https:/awslabs/djl/blob/master/docs/development/memory_management.md">NDArray
* Memory Management Guide</a>
*/
public interface NDArray extends AutoCloseable {
public interface NDArray extends NDResource {

/**
* Decodes {@code NDArray} from bytes.
@@ -53,13 +53,6 @@ static NDArray decode(NDManager manager, byte[] byteArray) {
return manager.decode(byteArray);
}

/**
* Returns the {@link NDManager} used to create this {@code NDArray}.
*
* @return the {@link NDManager} used to create this {@code NDArray}
*/
NDManager getManager();

/**
* Returns the name of this {@code NDArray}.
*
@@ -146,27 +139,6 @@ default byte[] encode() {
return NDSerializer.encode(this);
}

/**
* Attaches this {@code NDArray} to the specified {@link NDManager}.
*
* <p>Attached resource will be closed when the {@link NDManager} is closed.
*
* @param manager the {@link NDManager} to be attached
* @return the original {@link NDManager}
*/
NDManager attach(NDManager manager);

/**
* Detaches the {@code NDArray} from current {@link NDManager}'s lifecycle.
*
* <p>The {@code NDArray} becomes un-managed, it is the user's responsibility to close the
* {@code NDArray}. Failure to close the resource might cause your machine to run out of native
* memory.
*
* @see NDManager
*/
void detach();

/**
* Moves this {@code NDArray} to a different {@link Device}.
*
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
@@ -83,7 +83,7 @@ default SparseFormat getSparseFormat() {

/** {@inheritDoc} */
@Override
default NDManager attach(NDManager manager) {
default void attach(NDManager manager) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

47 changes: 18 additions & 29 deletions api/src/main/java/ai/djl/ndarray/NDList.java
@@ -22,9 +22,6 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
* An {@code NDList} represents a sequence of {@link NDArray}s with names.
@@ -34,7 +31,7 @@
*
* @see NDArray
*/
public class NDList extends ArrayList<NDArray> implements AutoCloseable {
public class NDList extends ArrayList<NDArray> implements NDResource {

private static final long serialVersionUID = 1L;

@@ -200,36 +197,28 @@ public NDList toDevice(Device device, boolean copy) {
return newNDList;
}

/**
* Attaches each ndarray in this list to the specified manager.
*
* @param manager the manager to attach the lists to
* @return a list of {@code NDManager} with which original NDArray are attached
* @see NDArray#attach(NDManager)
*/
public List<NDManager> attach(NDManager manager) {
return stream().map(array -> array.attach(manager)).collect(Collectors.toList());
/** {@inheritDoc} */
@Override
public NDManager getManager() {
return head().getManager();
}

/**
* Attaches each ndarray in this list to the specified manager.
*
* @param managers the list of managers to attach
* @return a list of {@code NDManager} with which original NDArray are attached
*/
public List<NDManager> attach(List<NDManager> managers) {
return IntStream.range(0, size())
.mapToObj(i -> get(i).attach(managers.get(i)))
.collect(Collectors.toList());
/** {@inheritDoc} */
@Override
public void attach(NDManager manager) {
stream().forEach(array -> array.attach(manager));
}

/**
* Detaches each ndarray in this list from their current managers.
*
* @see NDArray#detach()
*/
/** {@inheritDoc} */
@Override
public void tempAttach(NDManager manager) {
stream().forEach(array -> array.tempAttach(manager));
}

/** {@inheritDoc} */
@Override
public void detach() {
forEach(NDArray::detach);
stream().forEach(NDResource::detach);
}

/**
78 changes: 74 additions & 4 deletions api/src/main/java/ai/djl/ndarray/NDManager.java
@@ -133,6 +133,16 @@ static NDManager newBaseManager(Device device, String engineName) {
return Engine.getEngine(engineName).newBaseManager(device);
}

/**
* Creates a new manager based on the given resource.
*
* @param resource the resource to use
* @return a new memory scope containing the array
*/
static NDManager from(NDResource resource) {
return resource.getManager().newSubManager();
}

/**
* Allocates a new engine specific direct byte buffer.
*
@@ -1274,14 +1284,34 @@ default NDArray randomNormal(
Device getDevice();

/**
* Attaches a {@link NDArray} or {@code NDManager} to this {@code NDManager}.
* Attaches a resource to this {@code NDManager}.
*
* <p>The attached resource will be closed when this {@code NDManager} is closed.
*
* <p>This attachment is internal. Many resources will internally track which manager they are
* attached to. In that case, you should call {@link NDResource#attach(NDManager)} instead and
* that should then call attachInternal.
*
* @param resourceId the unique resourceId
* @param resource the {@link AutoCloseable} resource to be attached
*/
void attachInternal(String resourceId, AutoCloseable resource);

/**
* Temporarily attaches a resource to this {@code NDManager} to be returned when this is closed.
*
* <p>The attached resource will be returned to its original manager when this {@code
* NDManager} is closed.
*
* <p>Attached resource will be closed when this {@code NDManager} is closed.
* <p>This attachment is internal. Many resources will internally track which manager they are
* attached to. In that case, you should call {@link NDResource#attach(NDManager)} instead and
* that should then call tempAttachInternal.
*
* @param originalManager the original manager to return the resource to
* @param resourceId the unique resourceId
* @param resource the {@link AutoCloseable} resource to be attached
*/
void attach(String resourceId, AutoCloseable resource);
void tempAttachInternal(NDManager originalManager, String resourceId, NDResource resource);

/**
* Detaches a {@link NDArray} from this {@code NDManager}'s lifecycle.
@@ -1290,9 +1320,49 @@ default NDArray randomNormal(
* resource. Failed to close the resource has to wait on GC to be freed, and might cause out of
* native memory.
*
* <p>This detach is internal. Many resources will internally track which manager they are
* attached to. In that case, you should call {@link NDResource#detach()} instead and that
* should then call detachInternal.
*
* @param resourceId the resourceId to be removed from this {@code NDManager}'s lifecycle
*/
void detach(String resourceId);
void detachInternal(String resourceId);

/**
* Returns a value outside of this manager by attaching to this manager's parent.
*
* @param resource the resource to return
* @param <T> the type of the resource
* @return the passed in resource, after attaching to a new manager
*/
default <T extends NDResource> T ret(T resource) {
resource.attach(getParentManager());
return resource;
}

/**
* Attaches all resources to this manager.
*
* @param resources the resources to attach
* @see NDResource#attach(NDManager)
*/
default void attachAll(NDResource... resources) {
for (NDResource resource : resources) {
resource.attach(this);
}
}

/**
* Temporarily attaches all resources to this manager.
*
* @param resources the resources to attach
* @see NDResource#tempAttach(NDManager)
*/
default void tempAttachAll(NDResource... resources) {
for (NDResource resource : resources) {
resource.tempAttach(this);
}
}

/**
* An engine specific generic invocation to native operation.
57 changes: 57 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDResource.java
@@ -0,0 +1,57 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray;

/** An object which is managed by an {@link NDManager} and tracks the manager it is attached to. */
public interface NDResource extends AutoCloseable {

/**
* Returns the {@link NDManager} that manages this.
*
* @return the {@link NDManager} that manages this.
*/
NDManager getManager();

/**
* Attaches this {@link NDResource} to the specified {@link NDManager}.
*
* <p>Attached resource will be closed when the {@link NDManager} is closed.
*
* @param manager the {@link NDManager} to be attached to
*/
void attach(NDManager manager);

/**
* Temporarily attaches this {@link NDResource} to the specified {@link NDManager}.
*
* <p>Attached resource will be returned to the original manager when the {@link NDManager} is
* closed.
*
* @param manager the {@link NDManager} to be attached to
*/
void tempAttach(NDManager manager);

/**
* Detaches the {@link NDResource} from current {@link NDManager}'s lifecycle.
*
* <p>This becomes un-managed and it is the user's responsibility to close this. Failure to
* close the resource might cause your machine to run out of native memory.
*
* @see NDManager
*/
void detach();

/** {@inheritDoc} */
@Override
void close();
}
20 changes: 10 additions & 10 deletions api/src/main/java/ai/djl/nn/transformer/BertBlock.java
@@ -231,7 +231,9 @@ public NDList forward(final ParameterStore ps, final NDList inputs, boolean trai
*/
public NDList forward(
ParameterStore ps, NDArray tokenIds, NDArray typeIds, NDArray masks, boolean training) {
MemoryScope initScope = MemoryScope.from(tokenIds).add(typeIds, masks);
NDManager initScope = NDManager.from(tokenIds);
initScope.tempAttachAll(tokenIds, typeIds, masks);

// Create embeddings for inputs
NDArray embeddedTokens = tokenEmbedding.forward(ps, tokenIds, training);
NDArray embeddedTypes = typeEmbedding.forward(ps, typeIds, training);
@@ -257,16 +259,14 @@ public NDList forward(
.mul(-100000f); // turn 1s (original 0s) into -100000
// Run through all transformer blocks
NDList lastOutput = dropoutEmbedding;
initScope
.remove(tokenIds, typeIds, masks)
.waitToRead(dropoutEmbedding)
.waitToRead(offsetMask)
.close();
for (final TransformerEncoderBlock block : transformerEncoderBlocks) {
initScope.close();

for (TransformerEncoderBlock block : transformerEncoderBlocks) {
NDList input = new NDList(lastOutput.head(), offsetMask);
MemoryScope innerScope = MemoryScope.from(input);
lastOutput = block.forward(ps, input, training);
innerScope.remove(offsetMask).waitToRead(lastOutput).close();
try (NDManager innerManager = NDManager.from(input)) {
innerManager.tempAttachAll(input);
lastOutput = block.forward(ps, input, training);
}
}
// We also return the pooled output - this is an additional fully connected layer
// only applied to the first token, assumed to be the CLS token to be used for training