Add wandb-sweep example and clean up use of configs (#118)
* feat: added support for wandb logging (single runs, not sweeps yet)

* feat: replaced hps provided as a dict with a Config() object in seh_frag_moo

* feat: added config.desc

* fix: added train/valid for wandb log

* fix: allow JSON serialization of Enum objects

* chore: tox

* chore: replaced hps (dict) with Config() in all tasks. Moved qm9.py out of qm9/

* fix: changed default focus_region for frag_moo

* fix: added assert to prevent inadvertently manipulating the Config class rather than a Config() instance

* removed cfg.use_wandb; instead, simply test whether wandb has been initialized in the trainer

* chore: adding comment

* fix: typo

* chore: added cfg.task.seh_moo.log_topk to de-clutter a bit

* fix: added wandb to dependencies

* minor: file name change for consistency

* chore: centralised self.cfg.overwrite_existing_exp in GFNTrainer() (removed from all tasks to simplify mains)

* feat: added hyperopt/wandb_demo

* feat: removed wandb_agent_main.py to have the search and entrypoint defined in a single file

* chore: tox

* fix: minor in wandb_demo

* fix: storage path

* chore: tox
julienroyd authored Feb 27, 2024
1 parent 74d6acc commit 96dde6b
Showing 17 changed files with 265 additions and 183 deletions.
5 changes: 1 addition & 4 deletions src/gflownet/algo/config.py
@@ -3,7 +3,7 @@
from typing import Optional


class TBVariant(Enum):
class TBVariant(int, Enum):
"""See algo.trajectory_balance.TrajectoryBalance for details."""

TB = 0
@@ -116,8 +116,6 @@ class AlgoConfig:
Do not take random actions after this number of steps
valid_random_action_prob : float
The probability of taking a random action during validation
valid_sample_cond_info : bool
Whether to sample conditioning information during validation (if False, expects a validation set of cond_info)
sampling_tau : float
The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta)
"""
@@ -133,7 +131,6 @@
train_random_action_prob: float = 0.0
train_det_after: Optional[int] = None
valid_random_action_prob: float = 0.0
valid_sample_cond_info: bool = True
sampling_tau: float = 0.0
tb: TBConfig = TBConfig()
moql: MOQLConfig = MOQLConfig()
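For context, here is a small illustration (not part of the diff; `Plain` and `Mixed` are made-up names) of why mixing in `int` makes these enums JSON-serializable, e.g. when a config containing a `TBVariant` is dumped:

```python
import json
from enum import Enum


class Plain(Enum):  # old style: rejected by json.dumps
    TB = 0


class Mixed(int, Enum):  # new style: encodes as a plain integer
    TB = 0


try:
    json.dumps(Plain.TB)
except TypeError as err:
    print(err)  # "Object of type Plain is not JSON serializable"
print(json.dumps(Mixed.TB))  # prints: 0
```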
25 changes: 23 additions & 2 deletions src/gflownet/config.py
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, fields, is_dataclass
from typing import Optional

from omegaconf import MISSING
@@ -50,6 +50,8 @@ class Config:
Attributes
----------
desc : str
A description of the experiment
log_dir : str
The directory where to store logs, checkpoints, and samples.
device : str
@@ -82,6 +84,7 @@ class Config:
Whether to overwrite the contents of the log_dir if it already exists
"""

desc: str = "noDesc"
log_dir: str = MISSING
device: str = "cuda"
seed: int = 0
@@ -96,10 +99,28 @@
hostname: Optional[str] = None
pickle_mp_messages: bool = False
git_hash: Optional[str] = None
overwrite_existing_exp: bool = True
overwrite_existing_exp: bool = False
algo: AlgoConfig = AlgoConfig()
model: ModelConfig = ModelConfig()
opt: OptimizerConfig = OptimizerConfig()
replay: ReplayConfig = ReplayConfig()
task: TasksConfig = TasksConfig()
cond: ConditionalsConfig = ConditionalsConfig()


def init_empty(cfg):
    """
    Initialize a dataclass instance with all fields set to MISSING,
    including nested dataclasses.

    This is meant to be used on the user side (tasks) to provide
    some configuration using the Config class while overriding only
    the fields that the user has explicitly set.
    """
    for f in fields(cfg):
        if is_dataclass(f.type):
            setattr(cfg, f.name, init_empty(f.type()))
        else:
            setattr(cfg, f.name, MISSING)

    return cfg
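A minimal usage sketch of `init_empty` (illustrative, not part of the diff; the exact merge step in the trainer is an assumption here, though OmegaConf's documented behavior is that MISSING values never overwrite existing ones):

```python
from omegaconf import OmegaConf

from gflownet.config import Config, init_empty

# Build a sparse config: every field starts out as MISSING ("???").
user_cfg = init_empty(Config())
user_cfg.log_dir = "./logs/demo"  # override only the fields you care about
user_cfg.algo.sampling_tau = 0.95

# Merging over the defaults keeps every library default except the
# two fields set above, since MISSING values are skipped by the merge.
defaults = OmegaConf.structured(Config())
merged = OmegaConf.merge(defaults, OmegaConf.structured(user_cfg))
print(merged.algo.sampling_tau)  # 0.95
print(merged.device)             # "cuda" (default preserved)
```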
6 changes: 6 additions & 0 deletions src/gflownet/hyperopt/wandb_demo/README.md
@@ -0,0 +1,6 @@
Everything is contained in one file: `init_wandb_sweep.py` both defines the sweep's search space and serves as the entrypoint for the wandb agents.

To launch the search:
1. `python init_wandb_sweep.py` to initialize the sweep
2. `sbatch launch_wandb_agents.sh <SWEEP_ID>` to schedule a job array in Slurm, which will launch the wandb agents.
The number of jobs in the sbatch file should reflect the size of the hyperparameter space being swept.
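If you are not running under Slurm, an agent can presumably also be launched by hand with the same command the sbatch script runs: `wandb agent --count 1 --entity <ENTITY> --project <PROJECT> <SWEEP_ID>`.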
73 changes: 73 additions & 0 deletions src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py
@@ -0,0 +1,73 @@
import os
import sys
import time

import wandb

from gflownet.config import Config, init_empty
from gflownet.tasks.seh_frag_moo import SEHMOOFragTrainer

TIME = time.strftime("%m-%d-%H-%M")
ENTITY = "valencelabs"
PROJECT = "gflownet"
SWEEP_NAME = f"{TIME}-sehFragMoo-Zlr-Zlrdecay"
STORAGE_DIR = f"~/storage/wandb_sweeps/{SWEEP_NAME}"


# Define the search space of the sweep
sweep_config = {
    "name": SWEEP_NAME,
    "program": "init_wandb_sweep.py",
    "controller": {
        "type": "cloud",
    },
    "method": "grid",
    "parameters": {
        "config.algo.tb.Z_learning_rate": {"values": [1e-4, 1e-3, 1e-2]},
        "config.algo.tb.Z_lr_decay": {"values": [2_000, 50_000]},
    },
}


def wandb_config_merger():
    config = init_empty(Config())
    wandb_config = wandb.config

    # Set desired config values
    config.log_dir = f"{STORAGE_DIR}/{wandb.run.name}-id-{wandb.run.id}"
    config.print_every = 100
    config.validate_every = 1000
    config.num_final_gen_steps = 1000
    config.num_training_steps = 40_000
    config.pickle_mp_messages = True
    config.overwrite_existing_exp = False
    config.algo.sampling_tau = 0.95
    config.algo.train_random_action_prob = 0.01
    config.algo.tb.Z_learning_rate = 1e-3
    config.task.seh_moo.objectives = ["seh", "qed"]
    config.cond.temperature.sample_dist = "constant"
    config.cond.temperature.dist_params = [60.0]
    config.cond.weighted_prefs.preference_type = "dirichlet"
    config.cond.focus_region.focus_type = None
    config.replay.use = False

    # Merge the wandb sweep config with the nested config from gflownet
    config.algo.tb.Z_learning_rate = wandb_config["config.algo.tb.Z_learning_rate"]
    config.algo.tb.Z_lr_decay = wandb_config["config.algo.tb.Z_lr_decay"]

    return config


if __name__ == "__main__":
    # If there are no arguments, initialize the sweep; otherwise, this is a wandb agent run
    if len(sys.argv) == 1:
        if os.path.exists(STORAGE_DIR):
            raise ValueError(f"Sweep storage directory {STORAGE_DIR} already exists.")

        wandb.sweep(sweep_config, entity=ENTITY, project=PROJECT)
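        # The sweep ID is printed to stdout; pass it to launch_wandb_agents.sh to schedule the agents.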

    else:
        wandb.init(entity=ENTITY, project=PROJECT)
        config = wandb_config_merger()
        trial = SEHMOOFragTrainer(config)
        trial.run()
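A note on the single-file pattern: a wandb agent re-invokes `program` (this same file) and by default passes the sampled hyperparameters as command-line arguments, which is why a non-empty `sys.argv` identifies an agent run; the same values are then read back through `wandb.config` in `wandb_config_merger()`.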
19 changes: 19 additions & 0 deletions src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh
@@ -0,0 +1,19 @@
#!/bin/bash

# Purpose: Script to allocate a node and run a wandb sweep agent on it
# Usage: sbatch launch_wandb_agents.sh <SWEEP_ID>

#SBATCH --job-name=wandb_sweep_agent
#SBATCH --array=1-6
#SBATCH --time=23:59:00
#SBATCH --output=slurm_output_files/%x_%N_%A_%a.out
#SBATCH --gpus=1
#SBATCH --cpus-per-task=16
#SBATCH --mem=16GB
#SBATCH --partition compute

source activate gfn-py39-torch113
echo "Using environment={$CONDA_DEFAULT_ENV}"

# launch wandb agent
wandb agent --count 1 --entity valencelabs --project gflownet $1
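Since each agent exits after a single trial (`--count 1`), the `--array=1-6` range is sized to match the 3 × 2 = 6 grid points defined in `init_wandb_sweep.py`.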
2 changes: 1 addition & 1 deletion src/gflownet/models/config.py
@@ -9,7 +9,7 @@ class GraphTransformerConfig:
num_mlp_layers: int = 0


class SeqPosEnc(Enum):
class SeqPosEnc(int, Enum):
Pos = 0
Rotary = 1

12 changes: 6 additions & 6 deletions src/gflownet/online_trainer.py
@@ -105,13 +105,13 @@ def setup(self):
git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7]
self.cfg.git_hash = git_hash

yaml = OmegaConf.to_yaml(self.cfg)
os.makedirs(self.cfg.log_dir, exist_ok=True)
if self.print_hps:
yaml_cfg = OmegaConf.to_yaml(self.cfg)
if self.print_config:
print("\n\nHyperparameters:\n")
print(yaml)
with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w", encoding="utf8") as f:
f.write(yaml)
print(yaml_cfg)
os.makedirs(self.cfg.log_dir, exist_ok=True)
with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w", encoding="utf8") as f:
f.write(yaml_cfg)

def step(self, loss: Tensor):
loss.backward()
1 change: 1 addition & 0 deletions src/gflownet/tasks/config.py
@@ -26,6 +26,7 @@ class SEHMOOTaskConfig:
n_valid: int = 15
n_valid_repeats: int = 128
objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"])
log_topk: bool = False
online_pareto_front: bool = True


24 changes: 11 additions & 13 deletions src/gflownet/tasks/make_rings.py
@@ -1,4 +1,3 @@
import os
import socket
from typing import Dict, List, Tuple, Union

@@ -8,7 +7,7 @@
from rdkit.Chem.rdchem import Mol as RDMol
from torch import Tensor

from gflownet.config import Config
from gflownet.config import Config, init_empty
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.online_trainer import StandardOnlineTrainer
from gflownet.trainer import FlatRewards, GFNTask, RewardScalar
@@ -72,17 +71,16 @@ def setup_env_context(self):


def main():
    hps = {
        "log_dir": "./logs/debug_run_mr4",
        "device": "cuda",
        "num_training_steps": 10_000,
        "num_workers": 8,
        "algo": {"tb": {"do_parameterize_p_b": True}},
    }
    os.makedirs(hps["log_dir"], exist_ok=True)

    trial = MakeRingsTrainer(hps)
    trial.print_every = 1
    """Example of how this model can be run."""
    config = init_empty(Config())
    config.print_every = 1
    config.log_dir = "./logs/debug_run_mr4"
    config.device = "cuda"
    config.num_training_steps = 10_000
    config.num_workers = 8
    config.algo.tb.do_parameterize_p_b = True

    trial = MakeRingsTrainer(config)
    trial.run()


21 changes: 20 additions & 1 deletion src/gflownet/tasks/qm9/qm9.py → src/gflownet/tasks/qm9.py
@@ -9,7 +9,7 @@
from torch.utils.data import Dataset

import gflownet.models.mxmnet as mxmnet
from gflownet.config import Config
from gflownet.config import Config, init_empty
from gflownet.data.qm9 import QM9Dataset
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.online_trainer import StandardOnlineTrainer
@@ -162,3 +162,22 @@ def setup(self):
super().setup()
self.training_data.setup(self.task, self.ctx)
self.test_data.setup(self.task, self.ctx)


def main():
    """Example of how this model can be run."""
    config = init_empty(Config())
    config.num_workers = 0
    config.num_training_steps = 100000
    config.validate_every = 100
    config.log_dir = "./logs/debug_qm9"
    config.opt.lr_decay = 10000
    config.task.qm9.h5_path = "/rxrx/data/chem/qm9/qm9.h5"
    config.task.qm9.model_path = "/rxrx/data/chem/qm9/mxmnet_gap_model.pt"

    trial = QM9GapTrainer(config)
    trial.run()


if __name__ == "__main__":
    main()
Empty file removed src/gflownet/tasks/qm9/__init__.py
@@ -15,7 +15,7 @@
from gflownet.config import Config
from gflownet.data.qm9 import QM9Dataset
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.tasks.qm9.qm9 import QM9GapTask, QM9GapTrainer
from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer
from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks
from gflownet.trainer import FlatRewards, RewardScalar
from gflownet.utils import metrics
@@ -197,7 +197,7 @@ def set_default_hps(self, cfg: Config):
cfg.algo.sampling_tau = 0.95
# We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info)
# sampling and set the offline ratio to 1
cfg.algo.valid_sample_cond_info = False
cfg.cond.valid_sample_cond_info = False
cfg.algo.valid_offline_ratio = 1

def setup_algo(self):
46 changes: 16 additions & 30 deletions src/gflownet/tasks/seh_frag.py
@@ -1,5 +1,3 @@
import os
import shutil
import socket
from typing import Callable, Dict, List, Tuple, Union

@@ -13,7 +11,7 @@
from torch.utils.data import Dataset
from torch_geometric.data import Data

from gflownet.config import Config
from gflownet.config import Config, init_empty
from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext, Graph
from gflownet.models import bengio2021flow
from gflownet.online_trainer import StandardOnlineTrainer
@@ -200,33 +198,21 @@ def setup(self):


def main():
"""Example of how this trainer can be run"""
hps = {
"log_dir": "./logs/debug_run_seh_frag_pb",
"device": "cuda" if torch.cuda.is_available() else "cpu",
"overwrite_existing_exp": True,
"num_training_steps": 10_000,
"num_workers": 8,
"opt": {
"lr_decay": 20000,
},
"algo": {"sampling_tau": 0.99, "offline_ratio": 0.0},
"cond": {
"temperature": {
"sample_dist": "uniform",
"dist_params": [0, 64.0],
}
},
}
if os.path.exists(hps["log_dir"]):
if hps["overwrite_existing_exp"]:
shutil.rmtree(hps["log_dir"])
else:
raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.")
os.makedirs(hps["log_dir"])

trial = SEHFragTrainer(hps)
trial.print_every = 1
"""Example of how this model can be run."""
config = init_empty(Config())
config.print_every = 1
config.log_dir = "./logs/debug_run_seh_frag_pb"
config.device = "cuda" if torch.cuda.is_available() else "cpu"
config.overwrite_existing_exp = True
config.num_training_steps = 10_000
config.num_workers = 8
config.opt.lr_decay = 20_000
config.algo.sampling_tau = 0.99
config.algo.offline_ratio = 0.0
config.cond.temperature.sample_dist = "uniform"
config.cond.temperature.dist_params = [0, 64.0]

trial = SEHFragTrainer(config)
trial.run()


