From 2a6b9d74583e7b2837649fb3ad64c6bd8fc47dfe Mon Sep 17 00:00:00 2001 From: julienroyd Date: Mon, 11 Mar 2024 20:31:07 +0000 Subject: [PATCH 01/22] docs: added abstract method construct_batch() in GFNAlgorithm() --- src/gflownet/__init__.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/gflownet/__init__.py b/src/gflownet/__init__.py index 9f445e01..06bb6ba5 100644 --- a/src/gflownet/__init__.py +++ b/src/gflownet/__init__.py @@ -45,6 +45,27 @@ def compute_batch_losses( """ raise NotImplementedError() + def construct_batch(self, trajs, cond_info, log_rewards): + """Construct a batch from a list of trajectories and their information + + Typically calls ctx.graph_to_Data and ctx.collate to convert the trajectories into + a batch of graphs and adds the necessary attributes for training. + + Parameters + ---------- + trajs: List[List[tuple[Graph, GraphAction]]] + A list of N trajectories. + cond_info: Tensor + The conditional info that is considered for each trajectory. Shape (N, n_info) + log_rewards: Tensor + The transformed log-reward (e.g. torch.log(R(x) ** beta) ) for each trajectory. Shape (N,) + Returns + ------- + batch: gd.Batch + A (CPU) Batch object with relevant attributes added + """ + raise NotImplementedError() + def get_random_action_prob(self, it: int): if self.is_eval: return self.global_cfg.algo.valid_random_action_prob From d6277f333504e246efeb8f640f1c6ec22e5397bd Mon Sep 17 00:00:00 2001 From: julienroyd Date: Mon, 11 Mar 2024 21:15:44 +0000 Subject: [PATCH 02/22] chore: cleaning up GraphAction.relabel --- src/gflownet/envs/graph_building_env.py | 7 +------ src/gflownet/envs/test.py | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 5812ea49..35fe460b 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -83,7 +83,7 @@ def is_backward(self): class GraphAction: - def __init__(self, action: GraphActionType, source=None, target=None, value=None, attr=None, relabel=None): + def __init__(self, action: GraphActionType, source=None, target=None, value=None, attr=None): """A single graph-building action Parameters @@ -98,15 +98,12 @@ def __init__(self, action: GraphActionType, source=None, target=None, value=None the set attribute of a node/edge value: Any, optional the value (e.g. new node type) applied - relabel: int, optional - for AddNode actions, relabels the new node with that id """ self.action = action self.source = source self.target = target self.attr = attr self.value = value - self.relabel = relabel # TODO: deprecate this? def __repr__(self): attrs = ", ".join(str(i) for i in [self.source, self.target, self.attr, self.value] if i is not None) @@ -185,8 +182,6 @@ def step(self, g: Graph, action: GraphAction) -> Graph: else: assert action.source in g.nodes e = [action.source, max(g.nodes) + 1] - if action.relabel is not None: - raise ValueError("deprecated") # if kw and 'relabel' in kw: # e[1] = kw['relabel'] # for `parent` consistency, allow relabeling assert not g.has_edge(*e) diff --git a/src/gflownet/envs/test.py b/src/gflownet/envs/test.py index 55d3bfbf..b79027a6 100644 --- a/src/gflownet/envs/test.py +++ b/src/gflownet/envs/test.py @@ -95,7 +95,7 @@ def main(smi, n_steps): molg = ctx.mol_to_graph(mol) traj = generate_forward_trajectory(molg) for g, a in traj: - print(a.action, a.source, a.target, a.value, a.relabel) + print(a.action, a.source, a.target, a.value) graphs = [ctx.graph_to_Data(i) for i, _ in traj] traj_batch = ctx.collate(graphs) actions = [ctx.GraphAction_to_aidx(g, a) for g, a in zip(graphs, [i[1] for i in traj])] From 5d1adf916d637f344c4ece9fe5d2c10185faa06d Mon Sep 17 00:00:00 2001 From: julienroyd Date: Mon, 11 Mar 2024 22:01:59 +0000 Subject: [PATCH 03/22] chore: replaced Tuple[int,int,int] by ActionIndex() named-tuple. --- docs/implementation_notes.md | 2 +- src/gflownet/algo/advantage_actor_critic.py | 3 +- src/gflownet/algo/envelope_q_learning.py | 3 +- src/gflownet/algo/flow_matching.py | 11 ++-- src/gflownet/algo/graph_sampling.py | 6 +-- src/gflownet/algo/soft_q_learning.py | 3 +- src/gflownet/algo/trajectory_balance.py | 19 ++++--- src/gflownet/envs/frag_mol_env.py | 35 ++++++------ src/gflownet/envs/graph_building_env.py | 60 ++++++++++++++------- src/gflownet/envs/mol_building_env.py | 39 +++++++------- src/gflownet/envs/seq_building_env.py | 16 +++--- src/gflownet/envs/test.py | 4 +- tests/test_envs.py | 8 +-- 13 files changed, 118 insertions(+), 91 deletions(-) diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md index fa2f9a39..ad146220 100644 --- a/docs/implementation_notes.md +++ b/docs/implementation_notes.md @@ -24,7 +24,7 @@ This library is built around the idea of generating graphs. We use the `networkx Some notes: - graphs are (for now) assumed to be _undirected_. This is encoded for `torch_geometric` by duplicating the edges (contiguously) in both directions. Models still only produce one logit(-row) per edge, so the policy is still assumed to operate on undirected graphs. -- When converting from `GraphAction`s (nx) to so-called `aidx`s, the `aidx`s are encoding-bound, i.e. they point to specific rows and columns in the torch encoding. +- When converting from `GraphAction`s (nx) to `ActionIndex`s (tuple of ints), the action indexes are encoding-bound, i.e. they point to specific rows and columns in the torch encoding. ### Graph policies & graph action categoricals diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index 40c58010..c657e88e 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -118,7 +118,8 @@ def construct_batch(self, trajs, cond_info, log_rewards): """ torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] actions = [ - self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) + self.ctx.GraphAction_to_ActionIndex(g, a) + for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) ] batch = self.ctx.collate(torch_graphs) batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 7adfd68c..95cc3d44 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -272,7 +272,8 @@ def construct_batch(self, trajs, cond_info, log_rewards): """ torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] actions = [ - self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) + self.ctx.GraphAction_to_ActionIndex(g, a) + for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) ] batch = self.ctx.collate(torch_graphs) batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) diff --git a/src/gflownet/algo/flow_matching.py b/src/gflownet/algo/flow_matching.py index a1e9a393..9c291356 100644 --- a/src/gflownet/algo/flow_matching.py +++ b/src/gflownet/algo/flow_matching.py @@ -99,20 +99,23 @@ def construct_batch(self, trajs, cond_info, log_rewards): # there are invalid states that make episodes end prematurely (when those invalid states # have multiple possible parents). - # convert actions to aidx + # convert actions to ActionIndex parent_actions = [pact for parent in parents for pact, pstate in parent] - parent_actionidcs = [self.ctx.GraphAction_to_aidx(gdata, a) for gdata, a in zip(parent_graphs, parent_actions)] + parent_actionidxs = [ + self.ctx.GraphAction_to_ActionIndex(gdata, a) for gdata, a in zip(parent_graphs, parent_actions) + ] # convert state to Data state_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"][1:]] terminal_actions = [ - self.ctx.GraphAction_to_aidx(self.ctx.graph_to_Data(tj["traj"][-1][0]), tj["traj"][-1][1]) for tj in trajs + self.ctx.GraphAction_to_ActionIndex(self.ctx.graph_to_Data(tj["traj"][-1][0]), tj["traj"][-1][1]) + for tj in trajs ] # Create a batch from [*parents, *states]. This order will make it easier when computing the loss batch = self.ctx.collate(parent_graphs + state_graphs) batch.num_parents = torch.tensor([len(i) for i in parents]) batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) - batch.parent_acts = torch.tensor(parent_actionidcs) + batch.parent_acts = torch.tensor(parent_actionidxs) batch.terminal_acts = torch.tensor(terminal_actions) batch.log_rewards = log_rewards batch.cond_info = cond_info diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 0db5bcec..d1ec3a5b 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -98,7 +98,7 @@ def sample_from_model( graphs = [self.env.new() for i in range(n)] done = [False] * n # TODO: instead of padding with Stop, we could have a virtual action whose probability - # always evaluates to 1. Presently, Stop should convert to a [0,0,0] aidx, which should + # always evaluates to 1. Presently, Stop should convert to a (0,0,0) ActionIndex, which should # always be at least a valid index, and will be masked out anyways -- but this isn't ideal. # Here we have to pad the backward actions with something, since the backward actions are # evaluated at s_{t+1} not s_t. @@ -136,7 +136,7 @@ def not_done(lst): actions = sample_cat.sample() else: actions = fwd_cat.sample() - graph_actions = [self.ctx.aidx_to_GraphAction(g, a) for g, a in zip(torch_graphs, actions)] + graph_actions = [self.ctx.ActionIndex_to_GraphAction(g, a) for g, a in zip(torch_graphs, actions)] log_probs = fwd_cat.log_prob(actions) # Step each trajectory, and accumulate statistics for i, j in zip(not_done(range(n)), range(n)): @@ -273,7 +273,7 @@ def not_done(lst): ) bck_actions = bck_cat.sample() graph_bck_actions = [ - self.ctx.aidx_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions) + self.ctx.ActionIndex_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions) ] bck_logprobs = bck_cat.log_prob(bck_actions) diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index 378d0b7e..dc205981 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -114,7 +114,8 @@ def construct_batch(self, trajs, cond_info, log_rewards): """ torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] actions = [ - self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) + self.ctx.GraphAction_to_ActionIndex(g, a) + for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) ] batch = self.ctx.collate(torch_graphs) batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index fcd171b4..1a8a699e 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -13,6 +13,7 @@ from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config from gflownet.envs.graph_building_env import ( + ActionIndex, Graph, GraphAction, GraphActionCategorical, @@ -249,15 +250,15 @@ def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: Graph action: GraphAction Action leading from g to gp return_aidx: bool - If true returns of list of action indices, else a list of GraphAction + If true returns of list of ActionIndex, else a list of GraphAction Returns ------- - actions: Union[List[Tuple[int,int,int]], List[GraphAction]] + actions: Union[List[ActionIndex], List[GraphAction]] The list of idempotent actions that all lead from g to gp. """ - iaction = self.ctx.GraphAction_to_aidx(gd, action) + iaction = self.ctx.GraphAction_to_ActionIndex(gd, action) if action.action == GraphActionType.Stop: return [iaction if return_aidx else action] # Here we're looking for potential idempotent actions by looking at legal actions of the @@ -267,10 +268,10 @@ def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: Graph nz = lmask.nonzero() # Legal actions are those with a nonzero mask value actions = [iaction if return_aidx else action] for i in nz: - aidx = (iaction[0], i[0].item(), i[1].item()) + aidx = ActionIndex(action_type=iaction[0], row_idx=i[0].item(), col_idx=i[1].item()) if aidx == iaction: continue - ga = self.ctx.aidx_to_GraphAction(gd, aidx, fwd=not action.action.is_backward) + ga = self.ctx.ActionIndex_to_GraphAction(gd, aidx, fwd=not action.action.is_backward) child = self.env.step(g, ga) if nx.algorithms.is_isomorphic(child, gp, lambda a, b: a == b, lambda a, b: a == b): actions.append(aidx if return_aidx else ga) @@ -294,11 +295,13 @@ def construct_batch(self, trajs, cond_info, log_rewards): """ if self.model_is_autoregressive: torch_graphs = [self.ctx.graph_to_Data(tj["traj"][-1][0]) for tj in trajs] - actions = [self.ctx.GraphAction_to_aidx(g, i[1]) for g, tj in zip(torch_graphs, trajs) for i in tj["traj"]] + actions = [ + self.ctx.GraphAction_to_ActionIndex(g, i[1]) for g, tj in zip(torch_graphs, trajs) for i in tj["traj"] + ] else: torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] actions = [ - self.ctx.GraphAction_to_aidx(g, a) + self.ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) ] batch = self.ctx.collate(torch_graphs) @@ -308,7 +311,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): if self.cfg.do_parameterize_p_b: batch.bck_actions = torch.tensor( [ - self.ctx.GraphAction_to_aidx(g, a) + self.ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(torch_graphs, [i for tj in trajs for i in tj["bck_a"]]) ] ) diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index 71a89f0c..371ec5a7 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -9,7 +9,7 @@ import torch_geometric.data as gd from scipy import special -from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext +from gflownet.envs.graph_building_env import ActionIndex, Graph, GraphAction, GraphActionType, GraphBuildingEnvContext from gflownet.models import bengio2021flow @@ -90,14 +90,14 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu self.n_counter = NCounter() self.sorted_frags = sorted(list(enumerate(self.frags_mol)), key=lambda x: -x[1].GetNumAtoms()) - def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True): + def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = True): """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction Parameters ---------- g: gd.Data The graph object on which this action would be applied. - action_idx: Tuple[int, int, int] + aidx: ActionIndex A triple describing the type of action, and the corresponding row and column index for the corresponding Categorical matrix. @@ -105,33 +105,32 @@ def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: action: GraphAction A graph action whose type is one of Stop, AddNode, or SetEdgeAttr. """ - act_type, act_row, act_col = [int(i) for i in action_idx] if fwd: - t = self.action_type_order[act_type] + t = self.action_type_order[aidx.action_type] else: - t = self.bck_action_type_order[act_type] + t = self.bck_action_type_order[aidx.action_type] if t is GraphActionType.Stop: return GraphAction(t) elif t is GraphActionType.AddNode: - return GraphAction(t, source=act_row, value=act_col) + return GraphAction(t, source=aidx.row_idx, value=aidx.col_idx) elif t is GraphActionType.SetEdgeAttr: - a, b = g.edge_index[:, act_row * 2] # Edges are duplicated to get undirected GNN, deduplicated for logits - if act_col < self.num_stem_acts: + a, b = g.edge_index[:, aidx.row_idx * 2] # Edges are duplicated to get undirected GNN, deduplicated for logits + if aidx.col_idx < self.num_stem_acts: attr = "src_attach" - val = act_col + val = aidx.col_idx else: attr = "dst_attach" - val = act_col - self.num_stem_acts + val = aidx.col_idx - self.num_stem_acts return GraphAction(t, source=a.item(), target=b.item(), attr=attr, value=val) elif t is GraphActionType.RemoveNode: - return GraphAction(t, source=act_row) + return GraphAction(t, source=aidx.row_idx) elif t is GraphActionType.RemoveEdgeAttr: - a, b = g.edge_index[:, act_row * 2] - attr = "src_attach" if act_col == 0 else "dst_attach" + a, b = g.edge_index[:, aidx.row_idx * 2] + attr = "src_attach" if aidx.col_idx == 0 else "dst_attach" return GraphAction(t, source=a.item(), target=b.item(), attr=attr) - def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]: - """Translate a GraphAction to an index tuple + def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionIndex: + """Translate a GraphAction to an ActionIndex Parameters ---------- @@ -142,7 +141,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int Returns ------- - action_idx: Tuple[int, int, int] + action_idx: ActionIndex A triple describing the type of action, and the corresponding row and column index for the corresponding Categorical matrix. """ @@ -176,7 +175,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int col = 0 else: col = 1 - return (type_idx, int(row), int(col)) + return ActionIndex(action_type=type_idx, row_idx=int(row), col_idx=int(col)) def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 35fe460b..255c91e3 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -4,7 +4,7 @@ import re from collections import defaultdict from functools import cached_property -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union import networkx as nx import numpy as np @@ -82,6 +82,22 @@ def is_backward(self): return self.name.startswith("Remove") +class ActionIndex(NamedTuple): + """ + Represents an index for an action in the GraphBuildingEnv. + + Different types of actions lead to logit matrices of different shapes, + for exemple GraphActionType.Stop has a shape of (1, 1), while + GraphActionType.AddNode has a shape of (n, m) where n is the number of + nodes in the graph and m is the number of possible node types (idem for edge actions). + It is thus convenient to represent the action as a tuple of indices. + """ + + action_type: int # Index of the action type according to GraphActionType + row_idx: int # Index of the element the action applies to (e.g. node or edge) + col_idx: int # Index of the action variant (e.g. which attribute to set) + + class GraphAction: def __init__(self, action: GraphActionType, source=None, target=None, value=None, attr=None): """A single graph-building action @@ -90,7 +106,7 @@ def __init__(self, action: GraphActionType, source=None, target=None, value=None ---------- action: GraphActionType the action type - source: int + source: int, optional (e.g. GraphActionType.Stop has no source) the source node this action is applied on target: int, optional the target node (i.e. if specified this is an edge action) @@ -99,11 +115,11 @@ def __init__(self, action: GraphActionType, source=None, target=None, value=None value: Any, optional the value (e.g. new node type) applied """ - self.action = action - self.source = source - self.target = target - self.attr = attr - self.value = value + self.action: GraphActionType = action + self.source: Optional[int] = source + self.target: Optional[int] = target + self.attr: Optional[str] = attr + self.value: Optional[Any] = value def __repr__(self): attrs = ", ".join(str(i) for i in [self.source, self.target, self.attr, self.value] if i is not None) @@ -675,11 +691,11 @@ def logsumexp(self, x=None): # Add back max return reduction + maxl - def sample(self) -> List[Tuple[int, int, int]]: + def sample(self) -> List[ActionIndex]: """Samples this categorical Returns ------- - actions: List[Tuple[int, int, int]] + actions: List[ActionIndex] A list of indices representing [action type, element index, action index]. See constructor. """ # Use the Gumbel trick to sample categoricals @@ -719,7 +735,7 @@ def argmax( x: List[torch.Tensor], batch: List[torch.Tensor] = None, dim_size: int = None, - ) -> List[Tuple[int, int, int]]: + ) -> List[ActionIndex]: """Takes the argmax, i.e. if x are the logits, returns the most likely action. Parameters @@ -732,7 +748,7 @@ def argmax( The reduction dimension, default `self.num_graphs`. Returns ------- - actions: List[Tuple[int, int, int]] + actions: List[ActionIndex] A list of indices representing [action type, element index, action index]. See constructor. """ # scatter_max and .max create a (values, indices) pair @@ -765,18 +781,22 @@ def argmax( t = type_max_idx[i] # Subtract from the slice of that type and index, since the computed # row position is batch-wise rather graph-wise - argmaxes.append((int(t), int(row_pos[t][i] - self.slice[t][i]), int(col_max[t][1][i]))) + argmaxes.append( + ActionIndex( + action_type=int(t), row_idx=int(row_pos[t][i] - self.slice[t][i]), col_idx=int(col_max[t][1][i]) + ) + ) # It's now up to the Context class to create GraphBuildingAction instances # if it wants to convert these indices to env-compatible actions return argmaxes - def log_prob(self, actions: List[Tuple[int, int, int]], logprobs: torch.Tensor = None, batch: torch.Tensor = None): + def log_prob(self, actions: List[ActionIndex], logprobs: torch.Tensor = None, batch: torch.Tensor = None): """The log-probability of a list of action tuples, effectively indexes `logprobs` using internal slice indices. Parameters ---------- - actions: List[Tuple[int, int, int]] + actions: List[ActionIndex] A list of n action tuples denoting indices logprobs: List[Tensor] [Optional] The log-probablities to be indexed (self.logsoftmax() by default) in order (i.e. this @@ -799,7 +819,7 @@ def log_prob(self, actions: List[Tuple[int, int, int]], logprobs: torch.Tensor = # [logprobs[t][row + self.slice[t][i], col] for i, (t, row, col) in zip(batch, actions)] # but faster. - # each action is a 3-tuple, (type, row, column), where type is the index of the action type group. + # each action is a 3-tuple ActionIndex (type, row, column), where type is the index of the action type group. actions = torch.as_tensor(actions, device=self.dev, dtype=torch.long) assert actions.shape[0] == batch.shape[0] # Check there are as many actions as batch indices # To index the log probabilities efficiently, we will ravel the array, and compute the @@ -857,13 +877,13 @@ class GraphBuildingEnvContext: device: torch.device action_type_order: List[GraphActionType] - def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction: + def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = True) -> GraphAction: """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction Parameters ---------- g: gd.Data The graph to which the action is being applied - action_idx: Tuple[int, int, int] + aidx: ActionIndex The tensor indices for the corresponding action fwd: bool If True (default) then this is a forward action @@ -875,8 +895,8 @@ def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: """ raise NotImplementedError() - def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]: - """Translate a GraphAction to an action index (e.g. from a GraphActionCategorical) + def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionIndex: + """Translate a GraphAction to an ActionIndex (e.g. from a GraphActionCategorical) Parameters ---------- g: gd.Data @@ -886,7 +906,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int Returns ------- - action_idx: Tuple[int, int, int] + aidx: ActionIndex The tensor indices for the corresponding action """ raise NotImplementedError() diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 5e43dd0b..6dc5734c 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List import networkx as nx import numpy as np @@ -8,7 +8,7 @@ from rdkit.Chem import Mol from rdkit.Chem.rdchem import BondType, ChiralType -from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext +from gflownet.envs.graph_building_env import ActionIndex, Graph, GraphAction, GraphActionType, GraphBuildingEnvContext from gflownet.utils.graphs import random_walk_probs DEFAULT_CHIRAL_TYPES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CW, ChiralType.CHI_TETRAHEDRAL_CCW] @@ -158,44 +158,43 @@ def __init__( GraphActionType.RemoveEdgeAttr, ] - def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True): + def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = True): """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction""" - act_type, act_row, act_col = [int(i) for i in action_idx] if fwd: - t = self.action_type_order[act_type] + t = self.action_type_order[aidx.action_type] else: - t = self.bck_action_type_order[act_type] + t = self.bck_action_type_order[aidx.action_type] if t is GraphActionType.Stop: return GraphAction(t) elif t is GraphActionType.AddNode: - return GraphAction(t, source=act_row, value=self.atom_attr_values["v"][act_col]) + return GraphAction(t, source=aidx.row_idx, value=self.atom_attr_values["v"][aidx.col_idx]) elif t is GraphActionType.SetNodeAttr: - attr, val = self.atom_attr_logit_map[act_col] - return GraphAction(t, source=act_row, attr=attr, value=val) + attr, val = self.atom_attr_logit_map[aidx.col_idx] + return GraphAction(t, source=aidx.row_idx, attr=attr, value=val) elif t is GraphActionType.AddEdge: - a, b = g.non_edge_index[:, act_row] + a, b = g.non_edge_index[:, aidx.row_idx] return GraphAction(t, source=a.item(), target=b.item()) elif t is GraphActionType.SetEdgeAttr: # In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e. # g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one # to another we can safely divide or multiply by two. - a, b = g.edge_index[:, act_row * 2] - attr, val = self.bond_attr_logit_map[act_col] + a, b = g.edge_index[:, aidx.row_idx * 2] + attr, val = self.bond_attr_logit_map[aidx.col_idx] return GraphAction(t, source=a.item(), target=b.item(), attr=attr, value=val) elif t is GraphActionType.RemoveNode: - return GraphAction(t, source=act_row) + return GraphAction(t, source=aidx.row_idx) elif t is GraphActionType.RemoveNodeAttr: - attr = self.settable_atom_attrs[act_col] - return GraphAction(t, source=act_row, attr=attr) + attr = self.settable_atom_attrs[aidx.col_idx] + return GraphAction(t, source=aidx.row_idx, attr=attr) elif t is GraphActionType.RemoveEdge: - a, b = g.edge_index[:, act_row * 2] # see note above about edge_index + a, b = g.edge_index[:, aidx.row_idx * 2] # see note above about edge_index return GraphAction(t, source=a.item(), target=b.item()) elif t is GraphActionType.RemoveEdgeAttr: - a, b = g.edge_index[:, act_row * 2] # see note above about edge_index - attr = self.bond_attrs[act_col] + a, b = g.edge_index[:, aidx.row_idx * 2] # see note above about edge_index + attr = self.bond_attrs[aidx.col_idx] return GraphAction(t, source=a.item(), target=b.item(), attr=attr) - def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]: + def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionIndex: """Translate a GraphAction to an index tuple""" for u in [self.action_type_order, self.bck_action_type_order]: if action.action in u: @@ -252,7 +251,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int col = self.bond_attrs.index(action.attr) else: raise ValueError(f"Unknown action type {action.action}") - return (type_idx, int(row), int(col)) + return ActionIndex(action_type=type_idx, row_idx=int(row), col_idx=int(col)) def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance""" diff --git a/src/gflownet/envs/seq_building_env.py b/src/gflownet/envs/seq_building_env.py index b8189690..25c48bad 100644 --- a/src/gflownet/envs/seq_building_env.py +++ b/src/gflownet/envs/seq_building_env.py @@ -1,11 +1,12 @@ from copy import deepcopy -from typing import Any, List, Sequence, Tuple +from typing import Any, List, Sequence import torch from torch.nn.utils.rnn import pad_sequence from torch_geometric.data import Data from gflownet.envs.graph_building_env import ( + ActionIndex, Graph, GraphAction, GraphActionType, @@ -94,17 +95,16 @@ def __init__(self, alphabet: Sequence[str], num_cond_dim=0): self.num_actions = len(alphabet) + 1 # Alphabet + Stop self.num_cond_dim = num_cond_dim - def aidx_to_GraphAction(self, g: Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction: + def ActionIndex_to_GraphAction(self, g: Data, aidx: ActionIndex, fwd: bool = True) -> GraphAction: # Since there's only one "object" per timestep to act upon (in graph parlance), the row is always == 0 - act_type, _, act_col = [int(i) for i in action_idx] - t = self.action_type_order[act_type] + t = self.action_type_order[aidx.action_type] if t is GraphActionType.Stop: return GraphAction(t) elif t is GraphActionType.AddNode: - return GraphAction(t, value=act_col) - raise ValueError(action_idx) + return GraphAction(t, value=aidx.col_idx) + raise ValueError(aidx) - def GraphAction_to_aidx(self, g: Data, action: GraphAction) -> Tuple[int, int, int]: + def GraphAction_to_ActionIndex(self, g: Data, action: GraphAction) -> ActionIndex: if action.action is GraphActionType.Stop: col = 0 type_idx = self.action_type_order.index(action.action) @@ -113,7 +113,7 @@ def GraphAction_to_aidx(self, g: Data, action: GraphAction) -> Tuple[int, int, i type_idx = self.action_type_order.index(action.action) else: raise ValueError(action) - return (type_idx, 0, int(col)) + return ActionIndex(action_type=type_idx, row_idx=0, col_idx=int(col)) def graph_to_Data(self, g: Graph): s: Seq = g # type: ignore diff --git a/src/gflownet/envs/test.py b/src/gflownet/envs/test.py index b79027a6..10ced586 100644 --- a/src/gflownet/envs/test.py +++ b/src/gflownet/envs/test.py @@ -98,7 +98,7 @@ def main(smi, n_steps): print(a.action, a.source, a.target, a.value) graphs = [ctx.graph_to_Data(i) for i, _ in traj] traj_batch = ctx.collate(graphs) - actions = [ctx.GraphAction_to_aidx(g, a) for g, a in zip(graphs, [i[1] for i in traj])] + actions = [ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(graphs, [i[1] for i in traj])] # Train to overfit for i in tqdm(range(n_steps)): @@ -129,7 +129,7 @@ def main(smi, n_steps): # some probability is left on unlikely (wrong) steps print("oops, starting step over") continue - graph_action = ctx.aidx_to_GraphAction(tg, action) + graph_action = ctx.ActionIndex_to_GraphAction(tg, action) print(graph_action.action, graph_action.source, graph_action.target, graph_action.value) if graph_action.action is GraphActionType.Stop: break diff --git a/tests/test_envs.py b/tests/test_envs.py index ff8f43af..7c0677f4 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -8,7 +8,7 @@ from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.config import Config from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext -from gflownet.envs.graph_building_env import GraphBuildingEnv +from gflownet.envs.graph_building_env import ActionIndex, GraphBuildingEnv from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.models import bengio2021flow @@ -56,8 +56,8 @@ def expand(s, idx): continue nz = mask.nonzero() for i in nz: # Only expand non-masked legal actions - aidx = (at, i[0].item(), i[1].item()) - ga = ctx.aidx_to_GraphAction(gd, aidx) + aidx = ActionIndex(at, i[0].item(), i[1].item()) + ga = ctx.ActionIndex_to_GraphAction(gd, aidx) sp = env.step(s, ga) h = g2h(sp) if h in graph_cache: @@ -152,7 +152,7 @@ def _test_backwards_mask_equivalence_ipa(two_node_states, ctx): if a in c: break else: - ga = ctx.aidx_to_GraphAction(gd, a, fwd=False) + ga = ctx.ActionIndex_to_GraphAction(gd, a, fwd=False) gp = env.step(g, ga) # TODO: It is a bit weird that get_idempotent_actions is in an algo class, # probably also belongs in a graph utils file. From a6f141a8492b42d144fe603095904551ecfa034b Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 12 Mar 2024 15:42:04 +0000 Subject: [PATCH 04/22] chore: centralises masking in GraphActionCategorical(), specifically: * renamed masks in GraphActionCategorical to action_mask (to avoid confusion with done-masking, molecule validity masking, goal-conditioning masking, etc.) * made logits and masks private attributes of GraphActionCategorical * remove _mask() application from GraphTransformerFragEnvelopeQL and GraphTransformerGFN * removed masking from GraphSampler (masks applied in property setters of GraphActionCategorical._logits) --- docs/implementation_notes.md | 8 +- src/gflownet/algo/envelope_q_learning.py | 48 +++++------ src/gflownet/algo/graph_sampling.py | 35 ++++---- src/gflownet/envs/frag_mol_env.py | 5 ++ src/gflownet/envs/graph_building_env.py | 102 ++++++++++++++++------- src/gflownet/envs/test.py | 2 +- src/gflownet/models/graph_transformer.py | 31 ++++--- src/gflownet/models/seq_transformer.py | 2 +- src/gflownet/tasks/seh_frag_moo.py | 2 + tests/test_envs.py | 26 +++--- tests/test_graph_building_env.py | 2 +- 11 files changed, 147 insertions(+), 116 deletions(-) diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md index ad146220..8e6c2474 100644 --- a/docs/implementation_notes.md +++ b/docs/implementation_notes.md @@ -10,7 +10,6 @@ We separate experiment concerns in four categories: - maps graphs to torch_geometric `Data` instances - maps GraphActions to action indices - - produces action masks - communicates to the model what inputs it should expect - The Task class is responsible for computing the reward of a state, and for sampling conditioning information - The Trainer class is responsible for instanciating everything, and running the training & testing loop @@ -33,12 +32,11 @@ The code contains a specific categorical distribution type for graph actions, `G Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor. -The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution. +The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, action masks and so on; it can also be used to sample from the distribution. To expand, the logits are always 2d tensors, and there’s going to be one such tensor per “action type” that the agent is allowed to take. -Since graphs have variable number of nodes, and since each node has `n` associated possible action/logits, then the `(n_nodes, n)` tensor will vary from minibatch to minibatch. -In addition,the nodes in said logit tensor belong to different graphs in the minibatch; this is indicated by a `batch` tensor of shape `(n_nodes,)` for nodes (for e.g. edges it would be of shape `(n_edges,)`). - +Since graphs have variable number of nodes, and since each node has `n_node_actions` associated possible action/logits, then the `(n_nodes, n_node_actions)` tensor will vary from minibatch to minibatch. +In addition, the nodes in said logit tensor belong to different graphs in the minibatch; this is indicated by a `batch` tensor of shape `(n_nodes,)` for nodes (for e.g. edges it would be of shape `(n_edges,)`). Here’s an example: say we have 2 graphs in a minibatch, the first has 3 nodes, the second 2 nodes. The logits associated with AddNode will be of shape `(5, n)` (assuming there are `n` types of nodes in the problem). Say `n=2`, and `logits[AddNode] = [[1,2],[3,4],[5,6],[7,8],[9,0]]`, and `batch=[0,0,0,1,1]`. Then to compute the policy, we have to compute a softmax appropriately, i.e. the softmax for the first graph would be `softmax([1,2,3,4,5,6])` and for the second `softmax([7,8,9,0])` . This is possible thanks to `batch` and is what `GraphActionCategorical` does behind the scenes. diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 95cc3d44..c8317ba2 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -64,43 +64,35 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): src_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_emb, node_embeddings[e_row]], 1)) dst_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_emb, node_embeddings[e_col]], 1)) - def _mask(x, m): - # mask logit vector x with binary mask m - return x * m + self.mask_value * (1 - m) - - def _mask_obj(x, m): - # mask logit vector x with binary mask m - return ( - x.reshape(x.shape[0], x.shape[1] // self.num_objectives, self.num_objectives) * m[:, :, None] - + self.mask_value * (1 - m[:, :, None]) - ).reshape(x.shape) - cat = GraphActionCategorical( g, - logits=[ + raw_logits=[ F.relu(self.emb2stop(graph_embeddings)), - _mask(F.relu(self.emb2add_node(node_embeddings)), g.add_node_mask), - _mask_obj(F.relu(torch.cat([src_anchor_logits, dst_anchor_logits], 1)), g.set_edge_attr_mask), + F.relu(self.emb2add_node(node_embeddings)), + F.relu(torch.cat([src_anchor_logits, dst_anchor_logits], 1)), ], + action_masks=[1, g.add_node_mask.repeat(1, self.num_objectives), g.set_edge_attr_mask.repeat(1, self.num_objectives)], keys=[None, "x", "edge_index"], types=self.action_type_order, ) r_pred = self.emb2reward(graph_embeddings) if output_Qs: return cat, r_pred - cat.masks = [1, g.add_node_mask.cpu(), g.set_edge_attr_mask.cpu()] - # Compute the greedy policy - # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations - # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes - w = cond[:, -self.num_objectives :] - w_dot_Q = [ - (qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2) - for qi, b in zip(cat.logits, cat.batch) - ] - # Set the softmax distribution to a very low temperature to make sure only the max gets - # sampled (and we get random argmax tie breaking for free!): - cat.logits = [i * 100 for i in w_dot_Q] - return cat, r_pred + + else: + # Compute the greedy policy + # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations + # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes + w = cond[:, -self.num_objectives :] + w_dot_Q = [ + (qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2) + for qi, b in zip(cat.logits, cat.batch) + ] + cat.action_masks = [1, g.add_node_mask.cpu(), g.set_edge_attr_mask.cpu()] + # Set the softmax distribution to a very low temperature to make sure only the max gets + # sampled (and we get random argmax tie breaking for free!): + cat.logits = [i * 100 for i in w_dot_Q] + return cat, r_pred class GraphTransformerEnvelopeQL(nn.Module): @@ -134,7 +126,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): e_row, e_col = g.edge_index[:, ::2] cat = GraphActionCategorical( g, - logits=[ + raw_logits=[ self.emb2stop(graph_embeddings), self.emb2add_node(node_embeddings), self.emb2set_node_attr(node_embeddings), diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index d1ec3a5b..6b392bb9 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -79,6 +79,8 @@ def sample_from_model( Conditional information of each trajectory, shape (n, n_info) dev: torch.device Device on which data is manipulated + random_action_prob: float + Probability of taking a random action at each step Returns ------- @@ -92,17 +94,17 @@ def sample_from_model( # This will be returned data = [{"traj": [], "reward_pred": None, "is_valid": True, "is_sink": []} for i in range(n)] # Let's also keep track of trajectory statistics according to the model - fwd_logprob: List[List[Tensor]] = [[] for i in range(n)] - bck_logprob: List[List[Tensor]] = [[] for i in range(n)] + fwd_logprob: List[List[Tensor]] = [[] for _ in range(n)] + bck_logprob: List[List[Tensor]] = [[] for _ in range(n)] - graphs = [self.env.new() for i in range(n)] - done = [False] * n + graphs = [self.env.new() for _ in range(n)] + done = [False for _ in range(n)] # TODO: instead of padding with Stop, we could have a virtual action whose probability # always evaluates to 1. Presently, Stop should convert to a (0,0,0) ActionIndex, which should # always be at least a valid index, and will be masked out anyways -- but this isn't ideal. # Here we have to pad the backward actions with something, since the backward actions are # evaluated at s_{t+1} not s_t. - bck_a = [[GraphAction(GraphActionType.Stop)] for i in range(n)] + bck_a = [[GraphAction(GraphActionType.Stop)] for _ in range(n)] def not_done(lst): return [e for i, e in enumerate(lst) if not done[i]] @@ -116,19 +118,14 @@ def not_done(lst): # TODO: compute bck_cat.log_prob(bck_a) when relevant fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask]) if random_action_prob > 0: - masks = [1] * len(fwd_cat.logits) if fwd_cat.masks is None else fwd_cat.masks # Device which graphs in the minibatch will get their action randomized is_random_action = torch.tensor( self.rng.uniform(size=len(torch_graphs)) < random_action_prob, device=dev ).float() - # Set the logits to some large value if they're not masked, this way the masked - # actions have no probability of getting sampled, and there is a uniform - # distribution over the rest + # Set the logits to some large value to have a uniform distribution fwd_cat.logits = [ - # We don't multiply m by i on the right because we're assume the model forward() - # method already does that - is_random_action[b][:, None] * torch.ones_like(i) * m * 100 + i * (1 - is_random_action[b][:, None]) - for i, m, b in zip(fwd_cat.logits, masks, fwd_cat.batch) + is_random_action[b][:, None] * torch.ones_like(i) * 100 + i * (1 - is_random_action[b][:, None]) + for i, b in zip(fwd_cat.logits, fwd_cat.batch) ] if self.sample_temp != 1: sample_cat = copy.copy(fwd_cat) @@ -259,16 +256,12 @@ def not_done(lst): else: gbatch = self.ctx.collate(torch_graphs) action_types = self.ctx.bck_action_type_order - masks = [getattr(gbatch, i.mask_name) for i in action_types] + action_masks = [self.ctx.action_type_to_mask(t, g, assert_mask_exists=True) for t in action_types] bck_cat = GraphActionCategorical( gbatch, - logits=[m * 1e6 for m in masks], - keys=[ - # TODO: This is not very clean, could probably abstract this away somehow - GraphTransformerGFN._graph_part_to_key[GraphTransformerGFN._action_type_to_graph_part[t]] - for t in action_types - ], - masks=masks, + raw_logits=[torch.ones_like(m) for m in action_masks], + keys=[GraphTransformerGFN.action_type_to_key(t) for t in action_types], + action_masks=action_masks, types=action_types, ) bck_actions = bck_cat.sample() diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index 371ec5a7..442bc7b6 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -177,6 +177,11 @@ def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionI col = 1 return ActionIndex(action_type=type_idx, row_idx=int(row), col_idx=int(col)) + def action_type_to_mask(self, t: GraphActionType, g: gd.Batch, assert_mask_exists: bool = False): + if assert_mask_exists: + assert hasattr(g, t.mask_name), f"Mask {t.mask_name} not found in graph data" + return getattr(g, t.mask_name) if hasattr(g, t.mask_name) else torch.ones((1, 1), device=g.x.device) + def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance Parameters diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 255c91e3..98c3f4c3 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -471,11 +471,11 @@ class GraphActionCategorical: def __init__( self, graphs: gd.Batch, - logits: List[torch.Tensor], + raw_logits: List[torch.Tensor], keys: List[Union[str, None]], types: List[GraphActionType], deduplicate_edge_index=True, - masks: List[torch.Tensor] = None, + action_masks: List[torch.Tensor] = None, slice_dict: Optional[dict[str, torch.Tensor]] = None, ): """A multi-type Categorical compatible with generating structured actions. @@ -489,12 +489,24 @@ def __init__( provides this convenient interaction between torch_geometric Batch objects and lists of logit tensors. + Note on action-masking: + Action masks depend on the environment logic (what are allowed v.s. prohibited actions). + Thus, the action_masks should be created by the EnvContext (e.g. FragMolBuildingEnvContext) + and passed to the GraphActionCategorical as a list of tensors. However, action masks + should be applied to the logits within this class only to allow proper masking + when computing log probabilities and sampling and avoid confusion about + the state of the logits (masked or not) for external members. + For this reason, the constructor takes as input the raw (unmasked) logits and the + masks separately. The (masked) logits are cached in the _masked_logits attribute. + Both the (masked) logits and the masks are private properties, and attempts to edit the masks or the logits will + apply the masks to the raw_logits again. + Parameters ---------- graphs: Batch A Batch of graphs to which the logits correspond - logits: List[Tensor] - A list of tensors of shape `(n, m)` representing logits + raw_logits: List[Tensor] + A list of tensors of shape `(n, m)` representing raw (unmasked) logits over a variable number of graph elements (e.g. nodes) for which there are `m` possible actions. `n` should thus be equal to the sum of the number of such elements for each @@ -516,29 +528,30 @@ def __init__( If true, this means that the 'edge_index' keys have been reduced by e_i[::2] (presumably because the graphs are undirected) masks: List[Tensor], default=None - If not None, a list of broadcastable tensors that multiplicatively + If not None, a list of broadcastable tensors that mask out logits of invalid actions slice_dist: Optional[dict[str, Tensor]], default=None If not None, a map of tensors that indicate the start (and end) the graph index of each object keyed. If None, uses the `_slice_dict` attribute of the graphs. """ self.num_graphs = graphs.num_graphs - assert all([i.ndim == 2 for i in logits]) - assert len(logits) == len(types) == len(keys) - if masks is not None: - assert len(logits) == len(masks) - assert all([i.ndim == 2 for i in masks]) + assert all([i.ndim == 2 for i in raw_logits]) + assert len(raw_logits) == len(types) == len(keys) + if action_masks is not None: + assert len(raw_logits) == len(action_masks) + assert all([i.ndim == 2 for i in action_masks]) # The logits - self.logits = logits + self.raw_logits = raw_logits self.types = types self.keys = keys self.dev = dev = graphs.x.device self._epsilon = 1e-38 # TODO: mask is only used by graph_sampler, but maybe we should be more careful with it # (e.g. in a softmax and such) - # Can be set to indicate which logits are masked out (shape must match logits or have + # Can be set to indicate which raw_logits are masked out (shape must match raw_logits or have # broadcast dimensions already set) - self.masks: List[Any] = masks + self._action_masks: List[Any] = action_masks + self._apply_action_masks() # I'm extracting batches and slices in a slightly hackish way, # but I'm not aware of a proper API to torch_geometric that @@ -574,9 +587,37 @@ def __init__( self.batch[idx] = self.batch[idx][::2] self.slice[idx] = self.slice[idx].div(2, rounding_mode="floor") + @property + def logits(self): + return self._masked_logits + + @logits.setter + def logits(self, new_raw_logits): + self.raw_logits = new_raw_logits + self._apply_action_masks() + + @property + def action_masks(self): + return self._action_masks + + @action_masks.setter + def action_masks(self, new_action_masks): + self._action_masks = new_action_masks + self._apply_action_masks() + + def _apply_action_masks(self): + self._masked_logits = [self._mask(logits, mask) for logits, mask in zip(self.raw_logits, self._action_masks)] if self._action_masks is not None else self.raw_logits + + def _mask(self, x, m): + """ + mask logit vector x with binary mask m, -1000 is a tiny log-value + Note to self: we can't use torch.inf here, because inf * 0 is nan + """ + return x * m + -1000 * (1 - m) + def detach(self): new = copy.copy(self) - new.logits = [i.detach() for i in new.logits] + new._masked_logits = [i.detach() for i in new._masked_logits] if new.logprobs is not None: new.logprobs = [i.detach() for i in new.logprobs] if new.log_n is not None: @@ -585,15 +626,15 @@ def detach(self): def to(self, device): self.dev = device - self.logits = [i.to(device) for i in self.logits] + self._masked_logits = [i.to(device) for i in self._masked_logits] self.batch = [i.to(device) for i in self.batch] self.slice = [i.to(device) for i in self.slice] if self.logprobs is not None: self.logprobs = [i.to(device) for i in self.logprobs] if self.log_n is not None: self.log_n = self.log_n.to(device) - if self.masks is not None: - self.masks = [i.to(device) for i in self.masks] + if self._action_masks is not None: + self._action_masks = [i.to(device) for i in self._action_masks] return self def log_n_actions(self): @@ -602,7 +643,7 @@ def log_n_actions(self): sum( [ scatter(m.broadcast_to(i.shape).int().sum(1), b, dim=0, dim_size=self.num_graphs, reduce="sum") - for m, i, b in zip(self.masks, self.logits, self.batch) + for m, i, b in zip(self._action_masks, self._masked_logits, self.batch) ] ) .clamp(1) @@ -624,7 +665,7 @@ def _compute_batchwise_max( Parameters ---------- x: List[torch.Tensor] - A list of tensors of shape `(n, m)` (e.g. representing logits) + A list of tensors of shape `(n, m)` (e.g. representing _masked_logits) detach: bool, default=True If true, detach the tensors before computing the max batch: List[torch.Tensor], default=None @@ -659,10 +700,10 @@ def logsoftmax(self): if self.logprobs is not None: return self.logprobs # Use the `subtract by max` trick to avoid precision errors. - maxl = self._compute_batchwise_max(self.logits).values + maxl = self._compute_batchwise_max(self._masked_logits).values # substract by max then take exp # x[b, None] indexes by the batch to map back to each node/edge and adds a broadcast dim - corr_logits = [(i - maxl[b, None]) for i, b in zip(self.logits, self.batch)] + corr_logits = [(i - maxl[b, None]) for i, b in zip(self._masked_logits, self.batch)] exp_logits = [i.exp().clamp(self._epsilon) for i, b in zip(corr_logits, self.batch)] # sum corrected exponentiated logits, to get log(Z') = log(Z - max) = log(sum(exp(logits - max))) logZ = sum( @@ -676,15 +717,15 @@ def logsoftmax(self): return self.logprobs def logsumexp(self, x=None): - """Reduces `x` (the logits by default) to one scalar per graph""" + """Reduces `x` (the _masked_logits by default) to one scalar per graph""" if x is None: - x = self.logits + x = self._masked_logits # Use the `subtract by max` trick to avoid precision errors. maxl = self._compute_batchwise_max(x).values # substract by max then take exp # x[b, None] indexes by the batch to map back to each node/edge and adds a broadcast dim exp_vals = [(i - maxl[b, None]).exp().clamp(self._epsilon) for i, b in zip(x, self.batch)] - # sum corrected exponentiated logits, to get log(Z - max) = log(sum(exp(logits)) - max) + # sum corrected exponentiated _masked_logits, to get log(Z - max) = log(sum(exp(_masked_logits)) - max) reduction = sum( [scatter(i, b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1) for i, b in zip(exp_vals, self.batch)] ).log() @@ -707,11 +748,11 @@ def sample(self) -> List[ActionIndex]: # mutually exclusive). # Uniform noise - u = [torch.rand(i.shape, device=self.dev) for i in self.logits] + u = [torch.rand(i.shape, device=self.dev) for i in self._masked_logits] # Gumbel noise - gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self.logits, u)] + gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self._masked_logits, u)] - if self.masks is not None: + if self._action_masks is not None: gumbel_safe = [ torch.where( mask == 1, @@ -723,7 +764,7 @@ def sample(self) -> List[ActionIndex]: ), torch.finfo(x.dtype).min, ) - for x, mask in zip(gumbel, self.masks) + for x, mask in zip(gumbel, self._action_masks) ] else: gumbel_safe = gumbel @@ -872,7 +913,8 @@ def entropy(self, logprobs=None): class GraphBuildingEnvContext: - """A context class defines what the graphs are, how they map to and from data""" + """A context class defines what the graphs are, how they map to and from data + """ device: torch.device action_type_order: List[GraphActionType] @@ -913,6 +955,8 @@ def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionI def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance + The logic to build masks for prohibited actions can be implemented here, + packed in the data object and used in the GraphActionCategorical. Parameters ---------- g: Graph diff --git a/src/gflownet/envs/test.py b/src/gflownet/envs/test.py index 10ced586..d9c4da4b 100644 --- a/src/gflownet/envs/test.py +++ b/src/gflownet/envs/test.py @@ -63,7 +63,7 @@ def forward(self, g: gd.Batch): e_row, e_col = g.edge_index[:, ::2] cat = GraphActionCategorical( g, - logits=[ + raw_logits=[ self.emb2stop(glob), self.emb2add_node(o), self.emb2add_node_attr(o), diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 8c3993f0..76a627f0 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -1,13 +1,16 @@ from itertools import chain +from typing import Dict + import torch +from torch import Tensor import torch.nn as nn import torch_geometric.data as gd import torch_geometric.nn as gnn from torch_geometric.utils import add_self_loops from gflownet.config import Config -from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType +from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType, GraphBuildingEnvContext def mlp(n_in, n_hid, n_out, n_layer, act=nn.LeakyReLU): @@ -47,6 +50,9 @@ def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, n The number of Transformer layers. num_heads: int The number of Transformer heads per layer. + num_noise: int + The number of noise features to add to the node features. + This can be used as a simple positional encoding mechanism. ln_type: str The location of Layer Norm in the transformer, either 'pre' or 'post', default 'pre'. (apparently, before is better than after, see https://arxiv.org/pdf/2002.04745.pdf) @@ -165,9 +171,11 @@ class GraphTransformerGFN(nn.Module): "edge": "edge_index", } + action_type_to_key = lambda action_type: GraphTransformerGFN._graph_part_to_key.get(GraphTransformerGFN._action_type_to_graph_part.get(action_type)) + def __init__( self, - env_ctx, + env_ctx: GraphBuildingEnvContext, cfg: Config, num_graph_out=1, do_bck=False, @@ -183,6 +191,7 @@ def __init__( num_heads=cfg.model.graph_transformer.num_heads, ln_type=cfg.model.graph_transformer.ln_type, ) + self.env_ctx = env_ctx num_emb = cfg.model.num_emb num_final = num_emb num_glob_final = num_emb * 2 @@ -223,24 +232,12 @@ def __init__( # TODO: flag for this self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2) - def _action_type_to_mask(self, t, g): - return getattr(g, t.mask_name) if hasattr(g, t.mask_name) else torch.ones((1, 1), device=g.x.device) - - def _action_type_to_logit(self, t, emb, g): - logits = self.mlps[t.cname](emb[self._action_type_to_graph_part[t]]) - return self._mask(logits, self._action_type_to_mask(t, g)) - - def _mask(self, x, m): - # mask logit vector x with binary mask m, -1000 is a tiny log-value - # Note to self: we can't use torch.inf here, because inf * 0 is nan (but also see issue #99) - return x * m + -1000 * (1 - m) - - def _make_cat(self, g, emb, action_types): + def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[GraphActionType]): return GraphActionCategorical( g, - logits=[self._action_type_to_logit(t, emb, g) for t in action_types], + raw_logits=[self.mlps[t.cname](emb[self._action_type_to_graph_part[t]]) for t in action_types], keys=[self._action_type_to_key[t] for t in action_types], - masks=[self._action_type_to_mask(t, g) for t in action_types], + action_masks=[self.env_ctx.action_type_to_mask(t, g) for t in action_types], types=action_types, ) diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index 6916366a..b1a4173a 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -108,7 +108,7 @@ def forward(self, xs: SeqBatch, cond, batched=False): return ( GraphActionCategorical( xs, - logits=[stop_logits, add_node_logits], + raw_logits=[stop_logits, add_node_logits], keys=[None, None], types=self.ctx.action_type_order, slice_dict={}, diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index ef8def85..d3880b5c 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -393,7 +393,9 @@ def main(): config.desc = "debug_seh_frag_moo" config.log_dir = "./logs/debug_run_sfm" config.device = "cuda" if torch.cuda.is_available() else "cpu" + config.num_workers = 0 config.print_every = 1 + config.algo.num_from_policy = 2 config.validate_every = 1 config.num_final_gen_steps = 5 config.num_training_steps = 3 diff --git a/tests/test_envs.py b/tests/test_envs.py index 7c0677f4..bd3959f3 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -50,8 +50,8 @@ def g2h(g): def expand(s, idx): # Recursively expand all children of s gd = ctx.graph_to_Data(s) - masks = [getattr(gd, gat.mask_name) for gat in ctx.action_type_order] - for at, mask in enumerate(masks): + action_masks = [getattr(gd, gat.mask_name) for gat in ctx.action_type_order] + for at, mask in enumerate(action_masks): if at == 0: # Ignore Stop action continue nz = mask.nonzero() @@ -106,8 +106,8 @@ def two_node_states_atoms(request): return data -def _test_backwards_mask_equivalence(two_node_states, ctx): - """This tests that FragMolBuildingEnvContext implements backwards masks correctly. It treats +def _test_backwards_action_mask_equivalence(two_node_states, ctx): + """This tests that FragMolBuildingEnvContext implements backwards action masks correctly. It treats GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is a different number of actions leading to the parents of any state. """ @@ -124,7 +124,7 @@ def _test_backwards_mask_equivalence(two_node_states, ctx): raise ValueError() -def _test_backwards_mask_equivalence_ipa(two_node_states, ctx): +def _test_backwards_action_mask_equivalence_ipa(two_node_states, ctx): """This tests that FragMolBuildingEnvContext implements backwards masks correctly. It treats GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is a different number of actions leading to the parents of any state. @@ -162,17 +162,17 @@ def _test_backwards_mask_equivalence_ipa(two_node_states, ctx): raise ValueError() -def test_backwards_mask_equivalence_frag(two_node_states_frags): - _test_backwards_mask_equivalence(two_node_states_frags, get_frag_env_ctx()) +def test_backwards_action_mask_equivalence_frag(two_node_states_frags): + _test_backwards_action_mask_equivalence(two_node_states_frags, get_frag_env_ctx()) -def test_backwards_mask_equivalence_ipa_frag(two_node_states_frags): - _test_backwards_mask_equivalence_ipa(two_node_states_frags, get_frag_env_ctx()) +def test_backwards_action_mask_equivalence_ipa_frag(two_node_states_frags): + _test_backwards_action_mask_equivalence_ipa(two_node_states_frags, get_frag_env_ctx()) -def test_backwards_mask_equivalence_atom(two_node_states_atoms): - _test_backwards_mask_equivalence(two_node_states_atoms, get_atom_env_ctx()) +def test_backwards_action_mask_equivalence_atom(two_node_states_atoms): + _test_backwards_action_mask_equivalence(two_node_states_atoms, get_atom_env_ctx()) -def test_backwards_mask_equivalence_ipa_atom(two_node_states_atoms): - _test_backwards_mask_equivalence_ipa(two_node_states_atoms, get_atom_env_ctx()) +def test_backwards_action_mask_equivalence_ipa_atom(two_node_states_atoms): + _test_backwards_action_mask_equivalence_ipa(two_node_states_atoms, get_atom_env_ctx()) diff --git a/tests/test_graph_building_env.py b/tests/test_graph_building_env.py index 6839f291..e9184cbd 100644 --- a/tests/test_graph_building_env.py +++ b/tests/test_graph_building_env.py @@ -17,7 +17,7 @@ def make_test_cat(): cat = GraphActionCategorical( # Let's use arange to have different logit values batch, - logits=[ + raw_logits=[ torch.arange(3).reshape((3, 1)).float(), torch.arange(6 * 4).reshape((6, 4)).float(), torch.arange(2 * 3).reshape((2, 3)).float(), From 4f775075c47b12d6b483d1b57251da2b866c2ad3 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 28 Mar 2024 16:16:50 -0600 Subject: [PATCH 05/22] feat: now mask logits by setting them to -inf (should avoid silent un-masking i.e. masked_logit * 0. -> 0. (un-masked)) --- src/gflownet/envs/graph_building_env.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 98c3f4c3..e33a40fb 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -609,11 +609,7 @@ def _apply_action_masks(self): self._masked_logits = [self._mask(logits, mask) for logits, mask in zip(self.raw_logits, self._action_masks)] if self._action_masks is not None else self.raw_logits def _mask(self, x, m): - """ - mask logit vector x with binary mask m, -1000 is a tiny log-value - Note to self: we can't use torch.inf here, because inf * 0 is nan - """ - return x * m + -1000 * (1 - m) + return x.masked_fill(m == 0, -torch.inf) def detach(self): new = copy.copy(self) @@ -751,25 +747,8 @@ def sample(self) -> List[ActionIndex]: u = [torch.rand(i.shape, device=self.dev) for i in self._masked_logits] # Gumbel noise gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self._masked_logits, u)] - - if self._action_masks is not None: - gumbel_safe = [ - torch.where( - mask == 1, - torch.maximum( - x, - torch.nextafter( - torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype), torch.tensor(0.0, dtype=x.dtype) - ).to(x.device), - ), - torch.finfo(x.dtype).min, - ) - for x, mask in zip(gumbel, self._action_masks) - ] - else: - gumbel_safe = gumbel # Take the argmax - return self.argmax(x=gumbel_safe) + return self.argmax(x=gumbel) def argmax( self, From 10174fcb639c2ed90ba0196ff7616e57d423591c Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 28 Mar 2024 17:02:32 -0600 Subject: [PATCH 06/22] fix: wrong variable --- src/gflownet/algo/graph_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 6b392bb9..296c212d 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -256,7 +256,7 @@ def not_done(lst): else: gbatch = self.ctx.collate(torch_graphs) action_types = self.ctx.bck_action_type_order - action_masks = [self.ctx.action_type_to_mask(t, g, assert_mask_exists=True) for t in action_types] + action_masks = [self.ctx.action_type_to_mask(t, gbatch, assert_mask_exists=True) for t in action_types] bck_cat = GraphActionCategorical( gbatch, raw_logits=[torch.ones_like(m) for m in action_masks], From c6afaa101cdd7c8c83032e6d1b4db74aef50e3b6 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 28 Mar 2024 17:02:40 -0600 Subject: [PATCH 07/22] chore: tox --- src/gflownet/algo/envelope_q_learning.py | 8 ++++++-- src/gflownet/envs/frag_mol_env.py | 12 ++++++++---- src/gflownet/envs/graph_building_env.py | 23 +++++++++++++---------- src/gflownet/models/graph_transformer.py | 11 ++++++----- 4 files changed, 33 insertions(+), 21 deletions(-) diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index c8317ba2..8f836657 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -71,14 +71,18 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): F.relu(self.emb2add_node(node_embeddings)), F.relu(torch.cat([src_anchor_logits, dst_anchor_logits], 1)), ], - action_masks=[1, g.add_node_mask.repeat(1, self.num_objectives), g.set_edge_attr_mask.repeat(1, self.num_objectives)], + action_masks=[ + 1, + g.add_node_mask.repeat(1, self.num_objectives), + g.set_edge_attr_mask.repeat(1, self.num_objectives), + ], keys=[None, "x", "edge_index"], types=self.action_type_order, ) r_pred = self.emb2reward(graph_embeddings) if output_Qs: return cat, r_pred - + else: # Compute the greedy policy # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index 442bc7b6..c99196b9 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -114,7 +114,9 @@ def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = elif t is GraphActionType.AddNode: return GraphAction(t, source=aidx.row_idx, value=aidx.col_idx) elif t is GraphActionType.SetEdgeAttr: - a, b = g.edge_index[:, aidx.row_idx * 2] # Edges are duplicated to get undirected GNN, deduplicated for logits + a, b = g.edge_index[ + :, aidx.row_idx * 2 + ] # Edges are duplicated to get undirected GNN, deduplicated for logits if aidx.col_idx < self.num_stem_acts: attr = "src_attach" val = aidx.col_idx @@ -177,10 +179,12 @@ def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionI col = 1 return ActionIndex(action_type=type_idx, row_idx=int(row), col_idx=int(col)) - def action_type_to_mask(self, t: GraphActionType, g: gd.Batch, assert_mask_exists: bool = False): + def action_type_to_mask(self, t: GraphActionType, gbatch: gd.Batch, assert_mask_exists: bool = False): if assert_mask_exists: - assert hasattr(g, t.mask_name), f"Mask {t.mask_name} not found in graph data" - return getattr(g, t.mask_name) if hasattr(g, t.mask_name) else torch.ones((1, 1), device=g.x.device) + assert hasattr(gbatch, t.mask_name), f"Mask {t.mask_name} not found in graph data" + return ( + getattr(gbatch, t.mask_name) if hasattr(gbatch, t.mask_name) else torch.ones((1, 1), device=gbatch.x.device) + ) def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index e33a40fb..6a301fc7 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -493,11 +493,11 @@ def __init__( Action masks depend on the environment logic (what are allowed v.s. prohibited actions). Thus, the action_masks should be created by the EnvContext (e.g. FragMolBuildingEnvContext) and passed to the GraphActionCategorical as a list of tensors. However, action masks - should be applied to the logits within this class only to allow proper masking + should be applied to the logits within this class only to allow proper masking when computing log probabilities and sampling and avoid confusion about - the state of the logits (masked or not) for external members. - For this reason, the constructor takes as input the raw (unmasked) logits and the - masks separately. The (masked) logits are cached in the _masked_logits attribute. + the state of the logits (masked or not) for external members. + For this reason, the constructor takes as input the raw (unmasked) logits and the + masks separately. The (masked) logits are cached in the _masked_logits attribute. Both the (masked) logits and the masks are private properties, and attempts to edit the masks or the logits will apply the masks to the raw_logits again. @@ -590,23 +590,27 @@ def __init__( @property def logits(self): return self._masked_logits - + @logits.setter def logits(self, new_raw_logits): self.raw_logits = new_raw_logits self._apply_action_masks() - + @property def action_masks(self): return self._action_masks - + @action_masks.setter def action_masks(self, new_action_masks): self._action_masks = new_action_masks self._apply_action_masks() def _apply_action_masks(self): - self._masked_logits = [self._mask(logits, mask) for logits, mask in zip(self.raw_logits, self._action_masks)] if self._action_masks is not None else self.raw_logits + self._masked_logits = ( + [self._mask(logits, mask) for logits, mask in zip(self.raw_logits, self._action_masks)] + if self._action_masks is not None + else self.raw_logits + ) def _mask(self, x, m): return x.masked_fill(m == 0, -torch.inf) @@ -892,8 +896,7 @@ def entropy(self, logprobs=None): class GraphBuildingEnvContext: - """A context class defines what the graphs are, how they map to and from data - """ + """A context class defines what the graphs are, how they map to and from data""" device: torch.device action_type_order: List[GraphActionType] diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 76a627f0..f0b3981b 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -1,16 +1,15 @@ from itertools import chain - from typing import Dict import torch -from torch import Tensor import torch.nn as nn import torch_geometric.data as gd import torch_geometric.nn as gnn +from torch import Tensor from torch_geometric.utils import add_self_loops from gflownet.config import Config -from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType, GraphBuildingEnvContext +from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType def mlp(n_in, n_hid, n_out, n_layer, act=nn.LeakyReLU): @@ -171,11 +170,13 @@ class GraphTransformerGFN(nn.Module): "edge": "edge_index", } - action_type_to_key = lambda action_type: GraphTransformerGFN._graph_part_to_key.get(GraphTransformerGFN._action_type_to_graph_part.get(action_type)) + action_type_to_key = lambda action_type: GraphTransformerGFN._graph_part_to_key.get( # noqa: E731 + GraphTransformerGFN._action_type_to_graph_part.get(action_type) + ) def __init__( self, - env_ctx: GraphBuildingEnvContext, + env_ctx, cfg: Config, num_graph_out=1, do_bck=False, From b29fab49c88159d4ffd605ad357afdcccfcf17d7 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 28 Mar 2024 17:12:55 -0600 Subject: [PATCH 08/22] fix: in test --- tests/test_envs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index bd3959f3..7986abd9 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -145,14 +145,14 @@ def _test_backwards_action_mask_equivalence_ipa(two_node_states, ctx): for u, k in enumerate(ctx.bck_action_type_order): m = getattr(gd, k.mask_name) for a in m.nonzero(): - a = (u, a[0].item(), a[1].item()) + aidx = ActionIndex(u, a[0].item(), a[1].item()) for c in equivalence_classes: # Here `a` could have been added in another equivalence class by # get_idempotent_actions. If so, no need to check it. if a in c: break else: - ga = ctx.ActionIndex_to_GraphAction(gd, a, fwd=False) + ga = ctx.ActionIndex_to_GraphAction(gd, aidx, fwd=False) gp = env.step(g, ga) # TODO: It is a bit weird that get_idempotent_actions is in an algo class, # probably also belongs in a graph utils file. From c0d67e1849dc6bc90d0e37df2755bb78c4e3dec4 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 28 Mar 2024 18:01:15 -0600 Subject: [PATCH 09/22] minor: added assert for safety --- src/gflownet/envs/graph_building_env.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 6a301fc7..1b70c6e3 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -613,7 +613,8 @@ def _apply_action_masks(self): ) def _mask(self, x, m): - return x.masked_fill(m == 0, -torch.inf) + assert m.dtype == torch.float + return x.masked_fill(m == 0., -torch.inf) def detach(self): new = copy.copy(self) From b3ee035f36b8be5ad53cd71403d083426dfe4b07 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Mon, 1 Apr 2024 08:38:49 -0600 Subject: [PATCH 10/22] debug: reverted to multiplicative masking --- src/gflownet/envs/graph_building_env.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 1b70c6e3..fba86d9d 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -613,8 +613,12 @@ def _apply_action_masks(self): ) def _mask(self, x, m): + """ + mask logit vector x with binary mask m, -1000 is a tiny log-value + Note to self: we can't use torch.inf here, because inf * 0 is nan + """ assert m.dtype == torch.float - return x.masked_fill(m == 0., -torch.inf) + return x * m + -1000 * (1 - m) def detach(self): new = copy.copy(self) @@ -752,8 +756,25 @@ def sample(self) -> List[ActionIndex]: u = [torch.rand(i.shape, device=self.dev) for i in self._masked_logits] # Gumbel noise gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self._masked_logits, u)] + + if self._action_masks is not None: + gumbel_safe = [ + torch.where( + mask == 1, + torch.maximum( + x, + torch.nextafter( + torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype), torch.tensor(0.0, dtype=x.dtype) + ).to(x.device), + ), + torch.finfo(x.dtype).min, + ) + for x, mask in zip(gumbel, self._action_masks) + ] + else: + gumbel_safe = gumbel # Take the argmax - return self.argmax(x=gumbel) + return self.argmax(x=gumbel_safe) def argmax( self, From b37969406acae6ccaa286eca395239ad1d58060a Mon Sep 17 00:00:00 2001 From: julienroyd Date: Mon, 1 Apr 2024 10:33:35 -0600 Subject: [PATCH 11/22] fix: corrected variable name --- tests/test_envs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index 7986abd9..b371e03f 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -144,12 +144,12 @@ def _test_backwards_action_mask_equivalence_ipa(two_node_states, ctx): equivalence_classes = [] for u, k in enumerate(ctx.bck_action_type_order): m = getattr(gd, k.mask_name) - for a in m.nonzero(): - aidx = ActionIndex(u, a[0].item(), a[1].item()) + for aidx in m.nonzero(): + aidx = ActionIndex(u, aidx[0].item(), aidx[1].item()) for c in equivalence_classes: # Here `a` could have been added in another equivalence class by # get_idempotent_actions. If so, no need to check it. - if a in c: + if aidx in c: break else: ga = ctx.ActionIndex_to_GraphAction(gd, aidx, fwd=False) From 70a4ec309ba53e119225332dd74055b4becb9fe2 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Mon, 1 Apr 2024 13:49:11 -0600 Subject: [PATCH 12/22] fix: detach tensors entering the buffer and sending to cpu --- src/gflownet/data/replay_buffer.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py index f0568dd2..0ff3223f 100644 --- a/src/gflownet/data/replay_buffer.py +++ b/src/gflownet/data/replay_buffer.py @@ -8,6 +8,10 @@ class ReplayBuffer(object): def __init__(self, cfg: Config, rng: np.random.Generator = None): + """ + Replay buffer for storing and sampling arbitrary data (e.g. transitions or trajectories) + In self.push(), the buffer detaches any torch tensor and sends it to the CPU. + """ self.capacity = cfg.replay.capacity self.warmup = cfg.replay.warmup assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity" @@ -23,6 +27,7 @@ def push(self, *args): assert self._input_size == len(args), "ReplayBuffer input size must be constant" if len(self.buffer) < self.capacity: self.buffer.append(None) + args = detach_and_cpu(list(args)) self.buffer[self.position] = args self.position = (self.position + 1) % self.capacity @@ -42,3 +47,15 @@ def sample(self, batch_size): def __len__(self): return len(self.buffer) + + +def detach_and_cpu(x): + if isinstance(x, torch.Tensor): + x = x.detach().cpu() + elif isinstance(x, dict): + for k in x.keys(): + x[k] = detach_and_cpu(x[k]) + elif isinstance(x, list): + for i in range(len(x)): + x[i] = detach_and_cpu(x[i]) + return x From 5443f6f2dcc42b1c5c3f23241a21bee641366125 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Mon, 1 Apr 2024 15:16:04 -0600 Subject: [PATCH 13/22] minor: added case 'tuple' --- src/gflownet/data/replay_buffer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py index 0ff3223f..1869d8ed 100644 --- a/src/gflownet/data/replay_buffer.py +++ b/src/gflownet/data/replay_buffer.py @@ -53,9 +53,9 @@ def detach_and_cpu(x): if isinstance(x, torch.Tensor): x = x.detach().cpu() elif isinstance(x, dict): - for k in x.keys(): - x[k] = detach_and_cpu(x[k]) + x = {k: detach_and_cpu(v) for k, v in x.items()} elif isinstance(x, list): - for i in range(len(x)): - x[i] = detach_and_cpu(x[i]) + x = [detach_and_cpu(v) for v in x] + elif isinstance(x, tuple): + x = tuple(detach_and_cpu(v) for v in x) return x From f92690a258c3315909088bf17bc07f169aa97816 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Mon, 1 Apr 2024 16:00:10 -0600 Subject: [PATCH 14/22] fix: added detach and cpu() at the begining of create_batch() --- src/gflownet/data/data_source.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index b8a2c7e1..90a2848f 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -7,7 +7,7 @@ from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config -from gflownet.data.replay_buffer import ReplayBuffer +from gflownet.data.replay_buffer import ReplayBuffer, detach_and_cpu from gflownet.envs.graph_building_env import GraphBuildingEnvContext from gflownet.utils.misc import get_worker_rng @@ -214,6 +214,7 @@ def call_sampling_hooks(self, trajs): return batch_info def create_batch(self, trajs, batch_info): + trajs = detach_and_cpu(trajs) ci = torch.stack([t["cond_info"]["encoding"] for t in trajs]) log_rewards = torch.stack([t["log_reward"] for t in trajs]) batch = self.algo.construct_batch(trajs, ci, log_rewards) From 15031a319a02cc0bc322fb3e00b42eb970a6f661 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Mon, 1 Apr 2024 16:11:42 -0600 Subject: [PATCH 15/22] minor: removed type cast, now support tuples --- src/gflownet/data/replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py index 1869d8ed..df14cc68 100644 --- a/src/gflownet/data/replay_buffer.py +++ b/src/gflownet/data/replay_buffer.py @@ -27,7 +27,7 @@ def push(self, *args): assert self._input_size == len(args), "ReplayBuffer input size must be constant" if len(self.buffer) < self.capacity: self.buffer.append(None) - args = detach_and_cpu(list(args)) + args = detach_and_cpu(args) self.buffer[self.position] = args self.position = (self.position + 1) % self.capacity From f5abfc18027fb0e84b0ca908e2323659129690a3 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Mon, 1 Apr 2024 20:57:48 -0600 Subject: [PATCH 16/22] feat: added StrictDataClass to prevent creating new config attributes outside the class definition and avoid silent bugs due to unintended default configs --- src/gflownet/algo/config.py | 14 ++++++++------ src/gflownet/config.py | 5 +++-- src/gflownet/data/config.py | 4 +++- src/gflownet/models/config.py | 8 +++++--- src/gflownet/tasks/config.py | 12 +++++++----- src/gflownet/utils/config.py | 12 +++++++----- src/gflownet/utils/misc.py | 12 ++++++++++++ 7 files changed, 45 insertions(+), 22 deletions(-) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index f6818e15..4dd9cbfe 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -2,6 +2,8 @@ from enum import IntEnum from typing import Optional +from gflownet.utils.misc import StrictDataClass + class Backward(IntEnum): """ @@ -53,7 +55,7 @@ class LossFN(IntEnum): @dataclass -class TBConfig: +class TBConfig(StrictDataClass): """Trajectory Balance config. Attributes @@ -113,7 +115,7 @@ class TBConfig: @dataclass -class MOQLConfig: +class MOQLConfig(StrictDataClass): gamma: float = 1 num_omega_samples: int = 32 num_objectives: int = 2 @@ -122,14 +124,14 @@ class MOQLConfig: @dataclass -class A2CConfig: +class A2CConfig(StrictDataClass): entropy: float = 0.01 gamma: float = 1 penalty: float = -10 @dataclass -class FMConfig: +class FMConfig(StrictDataClass): epsilon: float = 1e-38 balanced_loss: bool = False leaf_coef: float = 10 @@ -137,14 +139,14 @@ class FMConfig: @dataclass -class SQLConfig: +class SQLConfig(StrictDataClass): alpha: float = 0.01 gamma: float = 1 penalty: float = -10 @dataclass -class AlgoConfig: +class AlgoConfig(StrictDataClass): """Generic configuration for algorithms Attributes diff --git a/src/gflownet/config.py b/src/gflownet/config.py index e8238e97..b66f238d 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -8,10 +8,11 @@ from gflownet.models.config import ModelConfig from gflownet.tasks.config import TasksConfig from gflownet.utils.config import ConditionalsConfig +from gflownet.utils.misc import StrictDataClass @dataclass -class OptimizerConfig: +class OptimizerConfig(StrictDataClass): """Generic configuration for optimizers Attributes @@ -45,7 +46,7 @@ class OptimizerConfig: @dataclass -class Config: +class Config(StrictDataClass): """Base configuration for training Attributes diff --git a/src/gflownet/data/config.py b/src/gflownet/data/config.py index ce1bac7e..1e1b7f98 100644 --- a/src/gflownet/data/config.py +++ b/src/gflownet/data/config.py @@ -1,9 +1,11 @@ from dataclasses import dataclass from typing import Optional +from gflownet.utils.misc import StrictDataClass + @dataclass -class ReplayConfig: +class ReplayConfig(StrictDataClass): """Replay buffer configuration Attributes diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py index acce656a..329c74c5 100644 --- a/src/gflownet/models/config.py +++ b/src/gflownet/models/config.py @@ -1,9 +1,11 @@ from dataclasses import dataclass, field from enum import Enum +from gflownet.utils.misc import StrictDataClass + @dataclass -class GraphTransformerConfig: +class GraphTransformerConfig(StrictDataClass): num_heads: int = 2 ln_type: str = "pre" num_mlp_layers: int = 0 @@ -15,13 +17,13 @@ class SeqPosEnc(int, Enum): @dataclass -class SeqTransformerConfig: +class SeqTransformerConfig(StrictDataClass): num_heads: int = 2 posenc: SeqPosEnc = SeqPosEnc.Rotary @dataclass -class ModelConfig: +class ModelConfig(StrictDataClass): """Generic configuration for models Attributes diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 3c8a0fab..8f2b30c9 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -1,14 +1,16 @@ from dataclasses import dataclass, field from typing import List +from gflownet.utils.misc import StrictDataClass + @dataclass -class SEHTaskConfig: +class SEHTaskConfig(StrictDataClass): reduced_frag: bool = False @dataclass -class SEHMOOTaskConfig: +class SEHMOOTaskConfig(StrictDataClass): """Config for the SEHMOOTask Attributes @@ -31,13 +33,13 @@ class SEHMOOTaskConfig: @dataclass -class QM9TaskConfig: +class QM9TaskConfig(StrictDataClass): h5_path: str = "./data/qm9/qm9.h5" # see src/gflownet/data/qm9.py model_path: str = "./data/qm9/qm9_model.pt" @dataclass -class QM9MOOTaskConfig: +class QM9MOOTaskConfig(StrictDataClass): """ Config for the QM9MooTask @@ -61,7 +63,7 @@ class QM9MOOTaskConfig: @dataclass -class TasksConfig: +class TasksConfig(StrictDataClass): qm9: QM9TaskConfig = field(default_factory=QM9TaskConfig) qm9_moo: QM9MOOTaskConfig = field(default_factory=QM9MOOTaskConfig) seh: SEHTaskConfig = field(default_factory=SEHTaskConfig) diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py index 5ee5369a..ca7122ee 100644 --- a/src/gflownet/utils/config.py +++ b/src/gflownet/utils/config.py @@ -1,9 +1,11 @@ from dataclasses import dataclass, field from typing import Any, List, Optional +from gflownet.utils.misc import StrictDataClass + @dataclass -class TempCondConfig: +class TempCondConfig(StrictDataClass): """Config for the temperature conditional. Attributes @@ -28,13 +30,13 @@ class TempCondConfig: @dataclass -class MultiObjectiveConfig: +class MultiObjectiveConfig(StrictDataClass): num_objectives: int = 2 # TODO: Change that as it can conflict with cfg.task.seh_moo.num_objectives num_thermometer_dim: int = 16 @dataclass -class WeightedPreferencesConfig: +class WeightedPreferencesConfig(StrictDataClass): """Config for the weighted preferences conditional. Attributes @@ -51,7 +53,7 @@ class WeightedPreferencesConfig: @dataclass -class FocusRegionConfig: +class FocusRegionConfig(StrictDataClass): """Config for the focus region conditional. Attributes @@ -71,7 +73,7 @@ class FocusRegionConfig: @dataclass -class ConditionalsConfig: +class ConditionalsConfig(StrictDataClass): valid_sample_cond_info: bool = True temperature: TempCondConfig = field(default_factory=TempCondConfig) moo: MultiObjectiveConfig = field(default_factory=MultiObjectiveConfig) diff --git a/src/gflownet/utils/misc.py b/src/gflownet/utils/misc.py index 7ec5bdba..8497a86b 100644 --- a/src/gflownet/utils/misc.py +++ b/src/gflownet/utils/misc.py @@ -54,3 +54,15 @@ def set_main_process_device(device): def get_worker_device(): worker_info = torch.utils.data.get_worker_info() return _main_process_device[0] if worker_info is None else torch.device("cpu") + + +class StrictDataClass: + """ + A dataclass that raises an error if any field is created outside of the __init__ method. + """ + + def __setattr__(self, name, value): + if hasattr(self, name) or name in self.__annotations__: + super().__setattr__(name, value) + else: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") From d7d6997e16bdb36015cfb3ec08069d6915e3c3a2 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 2 Apr 2024 13:25:57 -0600 Subject: [PATCH 17/22] minor: better error message --- src/gflownet/utils/misc.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/gflownet/utils/misc.py b/src/gflownet/utils/misc.py index 8497a86b..f9d32df3 100644 --- a/src/gflownet/utils/misc.py +++ b/src/gflownet/utils/misc.py @@ -65,4 +65,8 @@ def __setattr__(self, name, value): if hasattr(self, name) or name in self.__annotations__: super().__setattr__(name, value) else: - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'." + f" '{type(self).__name__}' is a StrictDataClass object." + f" Attributes can only be defined in the class definition." + ) From ad0db6ca0fd8c80ca4005bdca92020eef439a700 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 4 Apr 2024 07:36:11 -0600 Subject: [PATCH 18/22] fix: made focus_dir and preferences accessible at the batch level --- src/gflownet/data/data_source.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 90a2848f..cc868758 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -221,10 +221,10 @@ def create_batch(self, trajs, batch_info): batch.num_online = sum(t.get("is_online", 0) for t in trajs) batch.num_offline = len(trajs) - batch.num_online batch.extra_info = batch_info - if "preferences" in trajs[0]: - batch.preferences = torch.stack([t["preferences"] for t in trajs]) - if "focus_dir" in trajs[0]: - batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) + if "preferences" in trajs[0]['cond_info'].keys(): + batch.preferences = torch.stack([t['cond_info']["preferences"] for t in trajs]) + if "focus_dir" in trajs[0]['cond_info'].keys(): + batch.focus_dir = torch.stack([t['cond_info']["focus_dir"] for t in trajs]) if self.ctx.has_n() and self.cfg.algo.tb.do_predict_n: log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] From 5ab80df44d18c05c3747123467cd0e364e9a4e52 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 4 Apr 2024 07:41:29 -0600 Subject: [PATCH 19/22] tox --- src/gflownet/data/data_source.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index cc868758..782f0750 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -221,10 +221,10 @@ def create_batch(self, trajs, batch_info): batch.num_online = sum(t.get("is_online", 0) for t in trajs) batch.num_offline = len(trajs) - batch.num_online batch.extra_info = batch_info - if "preferences" in trajs[0]['cond_info'].keys(): - batch.preferences = torch.stack([t['cond_info']["preferences"] for t in trajs]) - if "focus_dir" in trajs[0]['cond_info'].keys(): - batch.focus_dir = torch.stack([t['cond_info']["focus_dir"] for t in trajs]) + if "preferences" in trajs[0]["cond_info"].keys(): + batch.preferences = torch.stack([t["cond_info"]["preferences"] for t in trajs]) + if "focus_dir" in trajs[0]["cond_info"].keys(): + batch.focus_dir = torch.stack([t["cond_info"]["focus_dir"] for t in trajs]) if self.ctx.has_n() and self.cfg.algo.tb.do_predict_n: log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] From 215e85723234ddabe965852b05cd8182bb4b6ce1 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 4 Apr 2024 08:10:00 -0600 Subject: [PATCH 20/22] revert: back from multiplicative masking to -inf This reverts commit b3ee035f36b8be5ad53cd71403d083426dfe4b07. --- src/gflownet/envs/graph_building_env.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index fba86d9d..1b70c6e3 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -613,12 +613,8 @@ def _apply_action_masks(self): ) def _mask(self, x, m): - """ - mask logit vector x with binary mask m, -1000 is a tiny log-value - Note to self: we can't use torch.inf here, because inf * 0 is nan - """ assert m.dtype == torch.float - return x * m + -1000 * (1 - m) + return x.masked_fill(m == 0., -torch.inf) def detach(self): new = copy.copy(self) @@ -756,25 +752,8 @@ def sample(self) -> List[ActionIndex]: u = [torch.rand(i.shape, device=self.dev) for i in self._masked_logits] # Gumbel noise gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self._masked_logits, u)] - - if self._action_masks is not None: - gumbel_safe = [ - torch.where( - mask == 1, - torch.maximum( - x, - torch.nextafter( - torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype), torch.tensor(0.0, dtype=x.dtype) - ).to(x.device), - ), - torch.finfo(x.dtype).min, - ) - for x, mask in zip(gumbel, self._action_masks) - ] - else: - gumbel_safe = gumbel # Take the argmax - return self.argmax(x=gumbel_safe) + return self.argmax(x=gumbel) def argmax( self, From 57da20ef838f9fc531b92eba0a47741054b5ac91 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 4 Apr 2024 11:07:40 -0600 Subject: [PATCH 21/22] tox --- src/gflownet/envs/graph_building_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 1b70c6e3..444f7015 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -614,7 +614,7 @@ def _apply_action_masks(self): def _mask(self, x, m): assert m.dtype == torch.float - return x.masked_fill(m == 0., -torch.inf) + return x.masked_fill(m == 0.0, -torch.inf) def detach(self): new = copy.copy(self) From d27a265cda4d807c9f189420caed037d0765ea6c Mon Sep 17 00:00:00 2001 From: julienroyd Date: Thu, 11 Apr 2024 13:46:24 -0600 Subject: [PATCH 22/22] refactor: moving action_type_to_mask in graph_building_env.py --- src/gflownet/algo/graph_sampling.py | 10 ++++++++-- src/gflownet/envs/frag_mol_env.py | 7 ------- src/gflownet/envs/graph_building_env.py | 6 ++++++ src/gflownet/models/graph_transformer.py | 4 ++-- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 296c212d..2776a22e 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -6,7 +6,13 @@ import torch.nn as nn from torch import Tensor -from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionCategorical, GraphActionType +from gflownet.envs.graph_building_env import ( + Graph, + GraphAction, + GraphActionCategorical, + GraphActionType, + action_type_to_mask, +) from gflownet.models.graph_transformer import GraphTransformerGFN @@ -256,7 +262,7 @@ def not_done(lst): else: gbatch = self.ctx.collate(torch_graphs) action_types = self.ctx.bck_action_type_order - action_masks = [self.ctx.action_type_to_mask(t, gbatch, assert_mask_exists=True) for t in action_types] + action_masks = [action_type_to_mask(t, gbatch, assert_mask_exists=True) for t in action_types] bck_cat = GraphActionCategorical( gbatch, raw_logits=[torch.ones_like(m) for m in action_masks], diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index c99196b9..8317e7dc 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -179,13 +179,6 @@ def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionI col = 1 return ActionIndex(action_type=type_idx, row_idx=int(row), col_idx=int(col)) - def action_type_to_mask(self, t: GraphActionType, gbatch: gd.Batch, assert_mask_exists: bool = False): - if assert_mask_exists: - assert hasattr(gbatch, t.mask_name), f"Mask {t.mask_name} not found in graph data" - return ( - getattr(gbatch, t.mask_name) if hasattr(gbatch, t.mask_name) else torch.ones((1, 1), device=gbatch.x.device) - ) - def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance Parameters diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 444f7015..20b109ef 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -1015,3 +1015,9 @@ def log_n(self, g) -> float: def traj_log_n(self, traj): return [self.log_n(g) for g, _ in traj] + + +def action_type_to_mask(t: GraphActionType, gbatch: gd.Batch, assert_mask_exists: bool = False): + if assert_mask_exists: + assert hasattr(gbatch, t.mask_name), f"Mask {t.mask_name} not found in graph data" + return getattr(gbatch, t.mask_name) if hasattr(gbatch, t.mask_name) else torch.ones((1, 1), device=gbatch.x.device) diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index f0b3981b..d0e29e72 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -9,7 +9,7 @@ from torch_geometric.utils import add_self_loops from gflownet.config import Config -from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType +from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType, action_type_to_mask def mlp(n_in, n_hid, n_out, n_layer, act=nn.LeakyReLU): @@ -238,7 +238,7 @@ def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[Grap g, raw_logits=[self.mlps[t.cname](emb[self._action_type_to_graph_part[t]]) for t in action_types], keys=[self._action_type_to_key[t] for t in action_types], - action_masks=[self.env_ctx.action_type_to_mask(t, g) for t in action_types], + action_masks=[action_type_to_mask(t, g) for t in action_types], types=action_types, )