Skip to content

Commit

Permalink
Add unprocessed methods (#111)
Browse files Browse the repository at this point in the history
* add unprocessed methods

* update changelog
  • Loading branch information
lilyminium authored Apr 8, 2024
1 parent 1df57aa commit 7a51b4f
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ The rules for this file:
### Authors
- @lilyminium

### Added
- Helper methods for debugging (PR #111)

### Changed
- Moved from importlib_resources to plain importlib with 3.9+

Expand Down
15 changes: 13 additions & 2 deletions openff/nagl/nn/_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,24 @@ def __init__(
self.postprocess_layer = postprocess_layer

def forward(self, molecule: Union[DGLMolecule, DGLMoleculeBatch]) -> torch.Tensor:
x = self.pooling_layer.forward(molecule)
x = self.readout_layers.forward(x)
x = self._forward_unpostprocessed(molecule)
if self.postprocess_layer is not None:
x = self.postprocess_layer.forward(molecule, x)

return x

def _forward_unpostprocessed(
self, molecule: Union[DGLMolecule, DGLMoleculeBatch]
) -> torch.Tensor:
"""
Forward pass without postprocessing the readout modules.
This is quality-of-life method for debugging and testing.
It is *not* intended for public use.
"""
x = self.pooling_layer.forward(molecule)
x = self.readout_layers.forward(x)
return x

def copy(self, copy_weights: bool = False):
pooling = type(self.pooling_layer)()
readout = self.readout_layers.copy(copy_weights=copy_weights)
Expand Down
39 changes: 36 additions & 3 deletions openff/nagl/nn/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ def forward(
}
return readouts

def _forward_unpostprocessed(self, molecule: "DGLMoleculeOrBatch"):
"""
Forward pass without postprocessing the readout modules.
This is quality-of-life method for debugging and testing.
It is *not* intended for public use.
"""
self.convolution_module(molecule)
readouts: Dict[str, torch.Tensor] = {
readout_type: readout_module._forward_unpostprocessed(molecule)
for readout_type, readout_module in self.readout_modules.items()
}
return readouts


class GNNModel(BaseGNNModel):
def __init__(
self,
Expand Down Expand Up @@ -140,7 +154,7 @@ def compute_properties(
if as_numpy:
values = {k: v.detach().numpy().flatten() for k, v in values.items()}
return values

def compute_property(
self,
molecule: "Molecule",
Expand Down Expand Up @@ -219,8 +233,25 @@ def _compute_properties_dgl(self, molecule: "Molecule") -> "torch.Tensor":
)
return self.forward(dglmol)

def _convert_to_nagl_molecule(self, molecule: "Molecule"):
from openff.nagl.molecule._graph.molecule import GraphMolecule
if self._is_dgl:
from openff.nagl.molecule._dgl.molecule import DGLMolecule

return DGLMolecule.from_openff(
molecule,
atom_features=self.config.atom_features,
bond_features=self.config.bond_features,
)

return GraphMolecule.from_openff(
molecule,
atom_features=self.config.atom_features,
bond_features=self.config.bond_features,
)

@classmethod
def load(cls, model: str, eval_mode: bool = True):
def load(cls, model: str, eval_mode: bool = True, **kwargs):
"""
Load a model from a file.
Expand All @@ -234,6 +265,8 @@ def load(cls, model: str, eval_mode: bool = True):
This can be created using the `save` method.
eval_mode: bool
Whether to set the model to evaluation mode.
**kwargs
Additional keyword arguments to pass to `torch.load`.
Returns
-------
Expand All @@ -251,7 +284,7 @@ def load(cls, model: str, eval_mode: bool = True):
models saved with ``torch.save``, as it expects
a dictionary of hyperparameters and a state dictionary.
"""
model_kwargs = torch.load(str(model))
model_kwargs = torch.load(str(model), **kwargs)
if isinstance(model_kwargs, dict):
model = cls(**model_kwargs["hyperparameters"])
model.load_state_dict(model_kwargs["state_dict"])
Expand Down
26 changes: 26 additions & 0 deletions openff/nagl/tests/nn/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,32 @@ def test_load_and_compute(self, smiles):

assert_allclose(computed, desired, atol=1e-5)

def test_forward_unpostprocessed(self):
dgl = pytest.importorskip("dgl")
from openff.toolkit import Molecule

model = GNNModel.load(EXAMPLE_AM1BCC_MODEL, eval_mode=True)
molecule = Molecule.from_smiles("C")
nagl_mol = model._convert_to_nagl_molecule(molecule)
unpostprocessed = model._forward_unpostprocessed(nagl_mol)
computed = unpostprocessed["am1bcc_charges"].detach().cpu().numpy()
assert computed.shape == (5, 2)
expected = np.array([
[ 0.166862, 5.489722],
[-0.431665, 5.454424],
[-0.431665, 5.454424],
[-0.431665, 5.454424],
[-0.431665, 5.454424],
])
assert_allclose(computed, expected, atol=1e-5)

def test_load_model_with_kwargs(self):
GNNModel.load(
EXAMPLE_AM1BCC_MODEL,
eval_mode=True,
map_location=torch.device('cpu')
)

def test_protein_computable(self):
"""
Test that working with moderately sized protein
Expand Down

0 comments on commit 7a51b4f

Please sign in to comment.