chore: added cfg.task.seh_moo.log_topk to de-clutter a bit
julienroyd committed Feb 14, 2024
1 parent d64e19f commit d28d860
Showing 3 changed files with 20 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/gflownet/tasks/config.py
@@ -24,6 +24,7 @@ class SEHMOOTaskConfig:
     n_valid: int = 15
     n_valid_repeats: int = 128
     objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"])
+    log_topk: bool = False


 @dataclass
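The new flag defaults to False, so the top-k validation logging touched below becomes opt-in. A minimal standalone sketch, mirroring the dataclass fields shown in the hunk above rather than importing from the package:

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class SEHMOOTaskConfig:
        n_valid: int = 15
        n_valid_repeats: int = 128
        objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"])
        log_topk: bool = False  # added in this commit: top-k logging stays off unless requested


    # Opting back in to the pre-commit behaviour:
    task_cfg = SEHMOOTaskConfig(log_topk=True)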
26 changes: 16 additions & 10 deletions src/gflownet/tasks/seh_frag_moo.py
@@ -7,7 +7,6 @@
 import torch
 import torch.nn as nn
 import torch_geometric.data as gd
-import wandb
 from rdkit.Chem import QED, Descriptors
 from rdkit.Chem.rdchem import Mol as RDMol
 from torch import Tensor
@@ -320,25 +319,32 @@ def setup(self):
         else:
             valid_cond_vector = valid_preferences

-        self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid)
         self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.seh_moo.n_valid_repeats)
-        self.valid_sampling_hooks.append(self._top_k_hook)
+
+        self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid)
+        if self.cfg.task.seh_moo.log_topk:
+            self.valid_sampling_hooks.append(self._top_k_hook)

         self.algo.task = self.task

     def build_callbacks(self):
         # We use this class-based setup to be compatible with the DeterminedAI API, but no direct
         # dependency is required.
         parent = self
+        callback_dict = {}
+
+        if self.cfg.task.seh_moo.log_topk:
+
+            class TopKMetricCB:
+                def on_validation_end(self, metrics: Dict[str, Any]):
+                    top_k = parent._top_k_hook.finalize()
+                    for i in range(len(top_k)):
+                        metrics[f"topk_rewards_{i}"] = top_k[i]
+                    print("validation end", metrics)

-        class TopKMetricCB:
-            def on_validation_end(self, metrics: Dict[str, Any]):
-                top_k = parent._top_k_hook.finalize()
-                for i in range(len(top_k)):
-                    metrics[f"topk_rewards_{i}"] = top_k[i]
-                print("validation end", metrics)
+            callback_dict["topk"] = TopKMetricCB()

-        return {"topk": TopKMetricCB()}
+        return callback_dict

     def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]:
         if self.task.focus_cond is not None:
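Both registration sites are now gated on the same flag: the TopKHook is appended to valid_sampling_hooks, and the TopKMetricCB callback is built, only when log_topk is True. A self-contained sketch of the pattern, with a hypothetical stand-in for TopKHook (the real one aggregates rewards across n_valid_repeats samples per conditioning vector):

    from typing import Any, Dict, List


    class TopKHook:
        # Hypothetical stand-in: the real hook collects rewards during validation.
        def __init__(self, k: int, repeats: int, n_valid: int):
            self.k, self.repeats, self.n_valid = k, repeats, n_valid
            self.rewards: List[float] = []

        def finalize(self) -> List[float]:
            # Simplified here: just the k best rewards seen so far.
            return sorted(self.rewards, reverse=True)[: self.k]


    def build_callbacks(log_topk: bool, top_k_hook: TopKHook) -> Dict[str, Any]:
        callback_dict: Dict[str, Any] = {}
        if log_topk:

            class TopKMetricCB:
                def on_validation_end(self, metrics: Dict[str, Any]) -> None:
                    for i, r in enumerate(top_k_hook.finalize()):
                        metrics[f"topk_rewards_{i}"] = r

            callback_dict["topk"] = TopKMetricCB()
        return callback_dict  # empty when log_topk is False, so nothing extra is logged


    # With the flag off, validation runs without any top-k bookkeeping:
    assert build_callbacks(log_topk=False, top_k_hook=TopKHook(10, 128, 15)) == {}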
4 changes: 3 additions & 1 deletion src/gflownet/trainer.py
@@ -122,7 +122,9 @@ def __init__(self, config: Config):
         # config classes < default_hps < constructor (i.e. the constructor overrides the default_hps, and so on)
         self.default_cfg: Config = Config()
         self.set_default_hps(self.default_cfg)
-        assert isinstance(self.default_cfg, Config) and isinstance(config, Config)  # make sure the config is a Config object, and not the Config class itself
+        assert isinstance(self.default_cfg, Config) and isinstance(
+            config, Config
+        )  # make sure the config is a Config object, and not the Config class itself
         self.cfg = OmegaConf.merge(self.default_cfg, config)

         self.device = torch.device(self.cfg.device)
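The reflowed assert guards the OmegaConf.merge call right below it. OmegaConf also accepts a dataclass type as a structured config, so passing the Config class itself (instead of an instance) would merge cleanly and silently re-apply defaults; the isinstance check fails fast instead. A sketch of the merge semantics, with a hypothetical MiniConfig standing in for gflownet's Config:

    from dataclasses import dataclass

    from omegaconf import OmegaConf


    @dataclass
    class MiniConfig:  # hypothetical stand-in for gflownet's Config
        device: str = "cpu"
        lr: float = 1e-4


    default_cfg = MiniConfig()
    user_cfg = MiniConfig(lr=3e-4)

    # Later arguments override earlier ones, matching the precedence comment
    # in __init__: defaults first, then the user-supplied config wins.
    merged = OmegaConf.merge(default_cfg, user_cfg)
    assert merged.lr == 3e-4 and merged.device == "cpu"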
