Fixes GPU memory overflow when using replay buffer (#130)
* fix: detach tensors entering the buffer and send them to CPU

* minor: added case 'tuple'

* fix: added detach() and cpu() at the beginning of create_batch()

* minor: removed type cast, now support tuples

* feat: added StrictDataClass to prevent creating new config attributes outside the class definition and avoid silent bugs due to unintended default configs

* minor: better error message

* fix: made focus_dir and preferences accessible at the batch level

* tox
julienroyd authored Apr 4, 2024
1 parent f064713 commit 2a24bb0
Showing 9 changed files with 72 additions and 27 deletions.
14 changes: 8 additions & 6 deletions src/gflownet/algo/config.py
@@ -2,6 +2,8 @@
from enum import IntEnum
from typing import Optional

from gflownet.utils.misc import StrictDataClass


class Backward(IntEnum):
"""
@@ -53,7 +55,7 @@ class LossFN(IntEnum):


@dataclass
class TBConfig:
class TBConfig(StrictDataClass):
"""Trajectory Balance config.
Attributes
@@ -113,7 +115,7 @@ class TBConfig:


@dataclass
class MOQLConfig:
class MOQLConfig(StrictDataClass):
gamma: float = 1
num_omega_samples: int = 32
num_objectives: int = 2
@@ -122,29 +124,29 @@ class MOQLConfig:


@dataclass
class A2CConfig:
class A2CConfig(StrictDataClass):
entropy: float = 0.01
gamma: float = 1
penalty: float = -10


@dataclass
class FMConfig:
class FMConfig(StrictDataClass):
epsilon: float = 1e-38
balanced_loss: bool = False
leaf_coef: float = 10
correct_idempotent: bool = False


@dataclass
class SQLConfig:
class SQLConfig(StrictDataClass):
alpha: float = 0.01
gamma: float = 1
penalty: float = -10


@dataclass
class AlgoConfig:
class AlgoConfig(StrictDataClass):
"""Generic configuration for algorithms
Attributes
5 changes: 3 additions & 2 deletions src/gflownet/config.py
@@ -8,10 +8,11 @@
from gflownet.models.config import ModelConfig
from gflownet.tasks.config import TasksConfig
from gflownet.utils.config import ConditionalsConfig
from gflownet.utils.misc import StrictDataClass


@dataclass
class OptimizerConfig:
class OptimizerConfig(StrictDataClass):
"""Generic configuration for optimizers
Attributes
@@ -45,7 +46,7 @@ class OptimizerConfig:


@dataclass
class Config:
class Config(StrictDataClass):
"""Base configuration for training
Attributes
4 changes: 3 additions & 1 deletion src/gflownet/data/config.py
@@ -1,9 +1,11 @@
from dataclasses import dataclass
from typing import Optional

from gflownet.utils.misc import StrictDataClass


@dataclass
class ReplayConfig:
class ReplayConfig(StrictDataClass):
"""Replay buffer configuration
Attributes
11 changes: 6 additions & 5 deletions src/gflownet/data/data_source.py
@@ -7,7 +7,7 @@

from gflownet import GFNAlgorithm, GFNTask
from gflownet.config import Config
from gflownet.data.replay_buffer import ReplayBuffer
from gflownet.data.replay_buffer import ReplayBuffer, detach_and_cpu
from gflownet.envs.graph_building_env import GraphBuildingEnvContext
from gflownet.utils.misc import get_worker_rng

@@ -214,16 +214,17 @@ def call_sampling_hooks(self, trajs):
return batch_info

def create_batch(self, trajs, batch_info):
trajs = detach_and_cpu(trajs)
ci = torch.stack([t["cond_info"]["encoding"] for t in trajs])
log_rewards = torch.stack([t["log_reward"] for t in trajs])
batch = self.algo.construct_batch(trajs, ci, log_rewards)
batch.num_online = sum(t.get("is_online", 0) for t in trajs)
batch.num_offline = len(trajs) - batch.num_online
batch.extra_info = batch_info
if "preferences" in trajs[0]:
batch.preferences = torch.stack([t["preferences"] for t in trajs])
if "focus_dir" in trajs[0]:
batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs])
if "preferences" in trajs[0]["cond_info"].keys():
batch.preferences = torch.stack([t["cond_info"]["preferences"] for t in trajs])
if "focus_dir" in trajs[0]["cond_info"].keys():
batch.focus_dir = torch.stack([t["cond_info"]["focus_dir"] for t in trajs])

if self.ctx.has_n() and self.cfg.algo.tb.do_predict_n:
log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]
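A hedged illustration of the change above (not code from the commit): create_batch() now reads preferences and focus_dir from each trajectory's cond_info dict instead of from the trajectory's top level. A minimal sketch with made-up shapes and values:

import torch

# Hypothetical trajectories carrying only the keys relevant to this change;
# real trajectories also hold "traj", "is_online", etc.
trajs = [
    {
        "log_reward": torch.tensor(0.5),
        "cond_info": {
            "encoding": torch.zeros(32),
            "preferences": torch.tensor([0.3, 0.7]),
            "focus_dir": torch.tensor([1.0, 0.0]),
        },
    }
    for _ in range(4)
]

# Mirrors the new batch-level access pattern in create_batch()
preferences = torch.stack([t["cond_info"]["preferences"] for t in trajs])  # shape (4, 2)
focus_dir = torch.stack([t["cond_info"]["focus_dir"] for t in trajs])      # shape (4, 2)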
17 changes: 17 additions & 0 deletions src/gflownet/data/replay_buffer.py
@@ -8,6 +8,10 @@

class ReplayBuffer(object):
def __init__(self, cfg: Config, rng: np.random.Generator = None):
"""
Replay buffer for storing and sampling arbitrary data (e.g. transitions or trajectories).
In self.push(), the buffer detaches any torch tensor and sends it to the CPU.
"""
self.capacity = cfg.replay.capacity
self.warmup = cfg.replay.warmup
assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity"
@@ -23,6 +27,7 @@ def push(self, *args):
assert self._input_size == len(args), "ReplayBuffer input size must be constant"
if len(self.buffer) < self.capacity:
self.buffer.append(None)
args = detach_and_cpu(args)
self.buffer[self.position] = args
self.position = (self.position + 1) % self.capacity

@@ -42,3 +47,15 @@ def sample(self, batch_size):

def __len__(self):
return len(self.buffer)


def detach_and_cpu(x):
if isinstance(x, torch.Tensor):
x = x.detach().cpu()
elif isinstance(x, dict):
x = {k: detach_and_cpu(v) for k, v in x.items()}
elif isinstance(x, list):
x = [detach_and_cpu(v) for v in x]
elif isinstance(x, tuple):
x = tuple(detach_and_cpu(v) for v in x)
return x
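
A minimal sketch (not part of the commit) of what detach_and_cpu() does to a nested structure; ReplayBuffer.push() applies the same conversion to everything entering the buffer, which is what stops buffered trajectories from pinning GPU memory:

import torch

from gflownet.data.replay_buffer import detach_and_cpu

device = "cuda" if torch.cuda.is_available() else "cpu"
traj = {
    "obs": torch.randn(4, 8, device=device, requires_grad=True),
    "actions": [torch.tensor([1, 2], device=device)],
    "meta": ("step", torch.tensor(3.0, device=device)),  # tuples are now handled too
}

safe = detach_and_cpu(traj)
assert safe["obs"].device.type == "cpu" and not safe["obs"].requires_grad
assert safe["actions"][0].device.type == "cpu"
assert safe["meta"][1].device.type == "cpu"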
8 changes: 5 additions & 3 deletions src/gflownet/models/config.py
@@ -1,9 +1,11 @@
from dataclasses import dataclass, field
from enum import Enum

from gflownet.utils.misc import StrictDataClass


@dataclass
class GraphTransformerConfig:
class GraphTransformerConfig(StrictDataClass):
num_heads: int = 2
ln_type: str = "pre"
num_mlp_layers: int = 0
@@ -15,13 +17,13 @@ class SeqPosEnc(int, Enum):


@dataclass
class SeqTransformerConfig:
class SeqTransformerConfig(StrictDataClass):
num_heads: int = 2
posenc: SeqPosEnc = SeqPosEnc.Rotary


@dataclass
class ModelConfig:
class ModelConfig(StrictDataClass):
"""Generic configuration for models
Attributes
12 changes: 7 additions & 5 deletions src/gflownet/tasks/config.py
@@ -1,14 +1,16 @@
from dataclasses import dataclass, field
from typing import List

from gflownet.utils.misc import StrictDataClass


@dataclass
class SEHTaskConfig:
class SEHTaskConfig(StrictDataClass):
reduced_frag: bool = False


@dataclass
class SEHMOOTaskConfig:
class SEHMOOTaskConfig(StrictDataClass):
"""Config for the SEHMOOTask
Attributes
@@ -31,13 +33,13 @@ class SEHMOOTaskConfig:


@dataclass
class QM9TaskConfig:
class QM9TaskConfig(StrictDataClass):
h5_path: str = "./data/qm9/qm9.h5" # see src/gflownet/data/qm9.py
model_path: str = "./data/qm9/qm9_model.pt"


@dataclass
class QM9MOOTaskConfig:
class QM9MOOTaskConfig(StrictDataClass):
"""
Config for the QM9MooTask
@@ -61,7 +63,7 @@ class QM9MOOTaskConfig:


@dataclass
class TasksConfig:
class TasksConfig(StrictDataClass):
qm9: QM9TaskConfig = field(default_factory=QM9TaskConfig)
qm9_moo: QM9MOOTaskConfig = field(default_factory=QM9MOOTaskConfig)
seh: SEHTaskConfig = field(default_factory=SEHTaskConfig)
12 changes: 7 additions & 5 deletions src/gflownet/utils/config.py
@@ -1,9 +1,11 @@
from dataclasses import dataclass, field
from typing import Any, List, Optional

from gflownet.utils.misc import StrictDataClass


@dataclass
class TempCondConfig:
class TempCondConfig(StrictDataClass):
"""Config for the temperature conditional.
Attributes
@@ -28,13 +30,13 @@ class TempCondConfig:


@dataclass
class MultiObjectiveConfig:
class MultiObjectiveConfig(StrictDataClass):
num_objectives: int = 2 # TODO: Change that as it can conflict with cfg.task.seh_moo.num_objectives
num_thermometer_dim: int = 16


@dataclass
class WeightedPreferencesConfig:
class WeightedPreferencesConfig(StrictDataClass):
"""Config for the weighted preferences conditional.
Attributes
@@ -51,7 +53,7 @@ class WeightedPreferencesConfig:


@dataclass
class FocusRegionConfig:
class FocusRegionConfig(StrictDataClass):
"""Config for the focus region conditional.
Attributes
@@ -71,7 +73,7 @@ class FocusRegionConfig:


@dataclass
class ConditionalsConfig:
class ConditionalsConfig(StrictDataClass):
valid_sample_cond_info: bool = True
temperature: TempCondConfig = field(default_factory=TempCondConfig)
moo: MultiObjectiveConfig = field(default_factory=MultiObjectiveConfig)
16 changes: 16 additions & 0 deletions src/gflownet/utils/misc.py
@@ -54,3 +54,19 @@ def set_main_process_device(device):
def get_worker_device():
worker_info = torch.utils.data.get_worker_info()
return _main_process_device[0] if worker_info is None else torch.device("cpu")


class StrictDataClass:
"""
A dataclass that raises an error if any field is created outside of the __init__ method.
"""

def __setattr__(self, name, value):
if hasattr(self, name) or name in self.__annotations__:
super().__setattr__(name, value)
else:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'."
f" '{type(self).__name__}' is a StrictDataClass object."
f" Attributes can only be defined in the class definition."
)
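
A minimal usage sketch (illustrative, not from the commit) of the StrictDataClass behavior: a declared field can be updated normally, while a typo'd attribute name raises immediately instead of silently creating a config entry that is never read. MyConfig below is a hypothetical config; the configs touched by this commit (TBConfig, ReplayConfig, ...) behave the same way once they inherit from StrictDataClass:

from dataclasses import dataclass

from gflownet.utils.misc import StrictDataClass


@dataclass
class MyConfig(StrictDataClass):  # hypothetical config, for illustration only
    learning_rate: float = 1e-3


cfg = MyConfig()
cfg.learning_rate = 1e-4  # declared field: allowed

try:
    cfg.learning_rte = 1e-4  # typo: not a declared field
except AttributeError as e:
    print(e)  # "'MyConfig' object has no attribute 'learning_rte'. ..."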
