memory copy + fixes and doc
bengioe committed Mar 11, 2024
1 parent c048e77 commit acfe070
Showing 4 changed files with 50 additions and 17 deletions.
4 changes: 4 additions & 0 deletions src/gflownet/config.py
@@ -78,6 +78,10 @@ class Config:
The hostname of the machine on which the experiment is run
pickle_mp_messages : bool
Whether to pickle messages sent between processes (only relevant if num_workers > 0)
+mp_buffer_size : Optional[int]
+    If specified, use a buffer of this size (in bytes) for passing tensors between processes.
+    Note that this is only relevant if num_workers > 0.
+    Also note that this will allocate `num_workers + 2 * number of wrapped objects` buffers.
git_hash : Optional[str]
The git hash of the current commit
overwrite_existing_exp : bool
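As a usage sketch (not part of the commit): with the field names documented above and a default-constructed Config, enabling the shared buffers could look as follows; the exact way the config object is built may differ in the repo.

from gflownet.config import Config

cfg = Config()
cfg.num_workers = 4                  # buffers are only used when num_workers > 0
cfg.mp_buffer_size = 32 * 1024**2    # 32 MiB staging buffer for inter-process tensors
# With one wrapped object (e.g. a model proxy), this allocates
# num_workers + 2 * 1 = 6 shared pinned buffers of 32 MiB each.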
3 changes: 2 additions & 1 deletion src/gflownet/data/data_source.py
@@ -10,6 +10,7 @@
from gflownet.config import Config
from gflownet.data.replay_buffer import ReplayBuffer
from gflownet.envs.graph_building_env import GraphBuildingEnvContext
+from gflownet.envs.seq_building_env import SeqBatch
from gflownet.utils.misc import get_worker_rng
from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer

@@ -332,7 +333,7 @@ def setup_mp_buffers(self):

def _maybe_put_in_mp_buffer(self, batch):
if self.mp_buffer_size:
-if not (isinstance(batch, Batch)):
+if not (isinstance(batch, (Batch, SeqBatch))):
warnings.warn(f"Expected a Batch object, but got {type(batch)}. Not using mp buffers.")
return batch
return (BufferPickler(self.result_buffer[self._wid]).dumps(batch), self._wid)
9 changes: 5 additions & 4 deletions src/gflownet/trainer.py
@@ -247,11 +247,13 @@ def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0
info["eval_time"] = time.time() - tick
return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()}

-def _maybe_resolve_shared_buffer(self, batch: Union[Batch, tuple, list], dl: DataLoader) -> Batch:
-    if dl.dataset.mp_buffer_size > 0 and isinstance(batch, (tuple, list)):
+def _maybe_resolve_shared_buffer(
+    self, batch: Union[Batch, SeqBatch, tuple, list], dl: DataLoader
+) -> Union[Batch, SeqBatch]:
+    if dl.dataset.mp_buffer_size and isinstance(batch, (tuple, list)):
batch, wid = batch
batch = BufferUnpickler(dl.dataset.result_buffer[wid], batch, self.device).load()
-elif isinstance(batch, Batch):
+elif isinstance(batch, (Batch, SeqBatch)):
batch = batch.to(self.device)
return batch

@@ -285,7 +287,6 @@ def run(self, logger=None):
if it % 1024 == 0:
gc.collect()
torch.cuda.empty_cache()
-_bd = batch
batch = self._maybe_resolve_shared_buffer(batch, train_dl)
t1 = time.time()
times.append(t1 - t0)
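Putting the two sides together: the worker (`_maybe_put_in_mp_buffer` in data_source.py) writes the batch's tensor bytes into a shared buffer and sends only a small `(payload, worker_id)` message through the queue, while the trainer (`_maybe_resolve_shared_buffer`) reads the bytes back and moves them to the device. The following self-contained sketch illustrates that pattern with plain torch; the function names and the single fixed-size CPU buffer are illustrative stand-ins, not the repo's API.

import torch

# Stand-in for SharedPinnedBuffer: the real code keeps one buffer per worker.
staging = torch.empty(1024, dtype=torch.uint8).share_memory_()

def put_in_buffer(t: torch.Tensor, wid: int):
    """Worker side: copy tensor bytes into the shared buffer, return a tiny message."""
    flat = t.contiguous().flatten().view(torch.uint8)
    staging[: flat.numel()] = flat
    return (tuple(t.shape), t.dtype, flat.numel()), wid

def resolve_from_buffer(msg, device="cpu"):
    """Trainer side: rebuild the tensor from the buffer and copy it off (the `+ 0`)."""
    (shape, dtype, nbytes), wid = msg  # the real code uses wid to pick result_buffer[wid]
    return staging[:nbytes].view(dtype).view(shape).to(device) + 0

batch = torch.arange(6, dtype=torch.float32)
restored = resolve_from_buffer(put_in_buffer(batch, wid=0))
assert torch.equal(batch, restored)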
51 changes: 39 additions & 12 deletions src/gflownet/utils/multiprocessing_proxy.py
@@ -28,10 +28,11 @@ def __init__(self, size):
assert self.buffer.is_pinned()

def __del__(self):
-if self.do_unreg and torch.utils.data.get_worker_info() is None:
-    cudart = torch.cuda.cudart()
-    r = cudart.cudaHostUnregister(self.buffer.data_ptr())
-    assert r == 0
+if torch.utils.data.get_worker_info() is None:
+    if self.do_unreg:
+        cudart = torch.cuda.cudart()
+        r = cudart.cudaHostUnregister(self.buffer.data_ptr())
+        assert r == 0


class _BufferPicklerSentinel:
@@ -43,20 +44,39 @@ def __init__(self, buf: SharedPinnedBuffer):
self._f = io.BytesIO()
super().__init__(self._f)
self.buf = buf
-# The lock will be released by the consumer of this buffer once the memory has been transferred to the device
+# The lock will be released by the consumer (BufferUnpickler) of this buffer once
+# the memory has been transferred to the device and copied
self.buf.lock.acquire()
self.buf_offset = 0
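A stand-alone illustration of the hand-off described in the comment above, with `threading.Lock` and a plain `bytearray` standing in for the shared pinned buffer and its lock: the writer acquires before filling the buffer, and only the reader releases, once it has copied the data out.

import threading

lock = threading.Lock()
buf = bytearray(16)

def producer_write(payload: bytes):
    lock.acquire()              # held until the consumer is done with the buffer
    buf[: len(payload)] = payload

def consumer_read(n: int) -> bytes:
    data = bytes(buf[:n])       # copy out (the real code copies to the device)
    lock.release()              # only now may the producer reuse the buffer
    return data

producer_write(b"hello")
assert consumer_read(5) == b"hello"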

def persistent_id(self, v):
if not isinstance(v, torch.Tensor):
return None
numel = v.numel() * v.element_size()
+if self.buf_offset + numel > self.buf.size:
+    raise RuntimeError(
+        f"Tried to allocate {self.buf_offset + numel} bytes in a buffer of size {self.buf.size}. "
+        "Consider increasing cfg.mp_buffer_size"
+    )
start = self.buf_offset
+shape = tuple(v.shape)
+if v.ndim > 0 and v.stride(-1) != 1 or not v.is_contiguous():
+    v = v.contiguous().reshape(-1)
+    if v.ndim > 0 and v.stride(-1) != 1:
+        # We're still not contiguous; this unfortunately happens occasionally, e.g.:
+        # x = torch.arange(10).reshape((10, 1))
+        # y = x.T[::2].T
+        # y.stride(), y.is_contiguous(), y.contiguous().stride()
+        # -> (1, 2), True, (1, 2)
+        v = v.flatten() + 0
+        # I don't know if this comes from my misunderstanding of strides or if it's a bug in torch,
+        # but either way torch will refuse to view this tensor as a uint8 tensor, so we have to + 0
+        # to force torch to materialize it into a new tensor (it may otherwise be lazy and not materialize).
if numel > 0:
-self.buf.buffer[start : start + numel] = v.view(-1).view(torch.uint8)
+self.buf.buffer[start : start + numel] = v.flatten().view(torch.uint8)
self.buf_offset += numel
self.buf_offset += (8 - self.buf_offset % 8) % 8 # align to 8 bytes
-return (_BufferPicklerSentinel, (start, tuple(v.shape), v.dtype))
+return (_BufferPicklerSentinel, (start, shape, v.dtype))

def dumps(self, obj):
self.dump(obj)
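A worked example of the offset bookkeeping in `persistent_id` above: each tensor's bytes are written at the current offset, and the offset is then rounded up to the next multiple of 8 so the next tensor's dtype view stays aligned (pure arithmetic, values chosen for illustration):

offset = 0

# float32 tensor of shape (3,): 3 * 4 = 12 bytes, written at offset 0
offset += 12
offset += (8 - offset % 8) % 8   # (8 - 4) % 8 = 4 -> padded to 16

# int64 tensor of shape (2,): 2 * 8 = 16 bytes, written at offset 16
offset += 16
offset += (8 - offset % 8) % 8   # already a multiple of 8 -> no padding

assert offset == 32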
Expand All @@ -68,11 +88,19 @@ def __init__(self, buf: SharedPinnedBuffer, data, device):
self._f, total_size = io.BytesIO(data[0]), data[1]
super().__init__(self._f)
self.buf = buf
-self.target_buf = buf.buffer[:total_size].to(device)
+self.target_buf = buf.buffer[:total_size].to(device) + 0
+# Why the `+ 0`? Unfortunately, we have no way to know exactly when the consumer of the object we're
+# unpickling will be done using the buffer underlying the tensor, so we have to create a copy.
+# If we don't, and another consumer starts using the buffer, and this consumer transfers this pinned
+# buffer to the GPU, the first consumer's tensors will be corrupted, because (depending on the CUDA
+# memory manager) the pinned buffer will transfer to the same GPU location.
+# Hopefully, especially if the target device is the GPU, the copy will be fast and/or async.
+# Note that this could be fixed by using one buffer for each worker, but that would use significantly
+# more memory.

def load_tensor(self, offset, shape, dtype):
numel = prod(shape) * dtype.itemsize
-tensor = self.target_buf[offset : offset + numel].view(dtype).view(shape)
+tensor: torch.Tensor = self.target_buf[offset : offset + numel].view(dtype).view(shape)
return tensor
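To see why the `+ 0` copy above matters: a dtype view of the staging buffer aliases it, so if the buffer is reused before the consumer is done, the view silently changes underneath it, while a copy does not. A small illustration with plain CPU tensors standing in for the pinned buffer:

import torch

staging = torch.zeros(8, dtype=torch.uint8)      # stand-in for the shared staging buffer
view = staging.view(torch.float32)               # aliases the buffer (2 float32 values)
copy = staging.view(torch.float32) + 0           # materializes a private copy

staging.fill_(255)                               # the next producer reuses the buffer
print(view)   # corrupted: now reflects the new bytes (nan, nan)
print(copy)   # unaffected: still tensor([0., 0.])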

def persistent_load(self, pid):
@@ -107,7 +135,7 @@ def __init__(self, in_queues, out_queues, pickle_messages=False, shared_buffer_s
self.pickle_messages = pickle_messages
self._is_init = False
self.shared_buffer_size = shared_buffer_size
-if shared_buffer_size is not None:
+if shared_buffer_size:
self._buffer_to_main = SharedPinnedBuffer(shared_buffer_size)
self._buffer_from_main = SharedPinnedBuffer(shared_buffer_size)

@@ -194,7 +222,7 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo
self.in_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore
self.out_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore
self.pickle_messages = pickle_messages
-self.use_shared_buffer = sb_size is not None
+self.use_shared_buffer = bool(sb_size)
self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, sb_size)
self.obj = obj
if hasattr(obj, "parameters"):
@@ -226,7 +254,6 @@ def to_cpu(self, i):

def run(self):
timeouts = 0

while not self.stop.is_set() and timeouts < 5 / 1e-5:
for qi, q in enumerate(self.in_queues):
try:
