Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support importing (some) virtual sites in Interchange.from_openmm #1081

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions openff/interchange/_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,28 @@


@pytest.fixture
def sage() -> ForceField:
    """Load the force field defined in `openff-2.0.0.offxml`."""
    offxml_source = "openff-2.0.0.offxml"

    return ForceField(offxml_source)


@pytest.fixture
def sage_unconstrained() -> ForceField:
    """Load the force field defined in `openff_unconstrained-2.0.0.offxml`."""
    offxml_source = "openff_unconstrained-2.0.0.offxml"

    return ForceField(offxml_source)


@pytest.fixture
def sage_no_switch(sage) -> ForceField:
    """Return the `sage` fixture after setting its vdW switch width to zero angstrom."""
    zero_width = Quantity(0.0, "angstrom")

    # The incoming fixture object is modified in place and handed back.
    sage["vdW"].switch_width = zero_width

    return sage


@pytest.fixture
def sage_with_tip4p() -> ForceField:
    """Load `openff-2.0.0.offxml` together with `tip4p.offxml` as a single force field."""
    # Re-build this from the existing fixtures if the following gets implemented:
    # https://github.com/openforcefield/openff-toolkit/issues/1948
    offxml_sources = ("openff-2.0.0.offxml", "tip4p.offxml")

    return ForceField(*offxml_sources)


@pytest.fixture
def sage_with_bond_charge(sage):
sage["Bonds"].add_parameter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

class TestUnsupportedCases:
@pytest.mark.filterwarnings("ignore:.*are you sure you don't want to pass positions")
def test_error_topology_mismatch(self, monkeypatch, sage_unconstrained, ethanol):
monkeypatch.setenv("INTERCHANGE_EXPERIMENTAL", "1")

def test_error_topology_mismatch(self, sage_unconstrained, ethanol):
topology = ethanol.to_topology()
topology.box_vectors = Quantity([4, 4, 4], "nanometer")

Expand All @@ -38,23 +36,39 @@ def test_error_topology_mismatch(self, monkeypatch, sage_unconstrained, ethanol)
topology=other_topology.to_openmm(),
)

def test_found_out_of_plane_virtual_site(self, tip5p, water):
    """Importing a system created from the `tip5p` fixture must raise `UnsupportedImportError`."""
    topology = water.to_topology()
    box_vectors = Quantity([4, 4, 4], "nanometer")
    topology.box_vectors = box_vectors

    openmm_system = tip5p.create_openmm_system(topology)

    # `match` is applied as a regular expression against the error message.
    expected_message = "A particle is an `outOfPlane` virtual site, which is not yet supported."

    with pytest.raises(UnsupportedImportError, match=expected_message):
        from_openmm(
            system=openmm_system,
            topology=topology.to_openmm(),
        )

def test_found_two_particle_average_virtual_site(
    self,
    sage_with_bond_charge,
    default_integrator,
):
    """Importing a simulation built with `sage_with_bond_charge` must raise `UnsupportedImportError`."""
    topology = Molecule.from_smiles("CCl").to_topology()
    interchange = sage_with_bond_charge.create_interchange(topology)
    simulation = interchange.to_openmm_simulation(integrator=default_integrator)

    # `match` is applied as a regular expression against the error message.
    expected_message = "A particle is a `TwoParticleAverage` virtual site, which is not yet supported."

    with pytest.raises(UnsupportedImportError, match=expected_message):
        from_openmm(
            system=simulation.system,
            topology=simulation.topology,
        )

def test_missing_positions_warning(self, monkeypatch, sage, water):
monkeypatch.setenv("INTERCHANGE_EXPERIMENTAL", "1")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest
from openff.toolkit import Topology

from openff.interchange import Interchange
from openff.interchange.components._packmol import solvate_topology
from openff.interchange.drivers.openmm import _get_openmm_energies, get_openmm_energies


class TestTIP4PVirtualSites:
    """Tests for `Interchange.from_openmm` on systems that carry TIP4P-style virtual sites."""

    def test_tip4p_openmm_xml(self, water_dimer):
        """
        Prepare a TIP4P water dimer with OpenMM's style of 4-site water.

        Below is used as a guide
        https://openmm.github.io/openmm-cookbook/latest/notebooks/tutorials/Histone_methyltransferase_simulation_with_a_multisite_water_model_TIP4P-Ew.html
        """
        pytest.importorskip("openmm")

        import openmm.app
        import openmm.unit

        # `water_dimer` is used as an OpenFF topology (it is passed straight to
        # `create_interchange` in the next test), and OpenFF topologies are converted
        # with `.to_openmm()`; `.to_openmm_topology()` is the `Interchange` spelling.
        modeller = openmm.app.Modeller(
            topology=water_dimer.to_openmm(),
            positions=water_dimer.get_positions().to("nanometer").to_openmm(),
        )

        forcefield = openmm.app.ForceField("tip4pew.xml")

        # Adds the particles that the force field defines but the topology lacks.
        modeller.addExtraParticles(forcefield=forcefield)

        system = forcefield.createSystem(
            modeller.topology,
            nonbondedMethod=openmm.app.PME,
            nonbondedCutoff=1.0 * openmm.unit.nanometers,
            constraints=openmm.app.HBonds,
            rigidWater=True,
            ewaldErrorTolerance=0.0005,
        )

        imported = Interchange.from_openmm(
            topology=modeller.topology,
            system=system,
        )

        # No value is asserted; this only checks that an energy evaluation completes.
        get_openmm_energies(imported)

    def test_dimer_energy_equals(self, tip4p, water_dimer):
        """Check that a TIP4P dimer reports the same energies after a round trip through OpenMM."""
        out: Interchange = tip4p.create_interchange(water_dimer)

        roundtripped = Interchange.from_openmm(
            system=out.to_openmm_system(),
            topology=out.to_openmm_topology(collate=False),
            positions=out.get_positions(include_virtual_sites=True).to_openmm(),
            box_vectors=out.box.to_openmm(),
        )

        # NOTE(review): the right-hand side calls the private `_get_openmm_energies` with a
        # single argument - confirm that it accepts this call and returns something that
        # compares equal to the public `get_openmm_energies` result.
        assert get_openmm_energies(out) == _get_openmm_energies(roundtripped)

    def test_minimize_solvated_ligand(self, sage_with_tip4p, ethanol, default_integrator):
        """Round-trip a solvated ligand through OpenMM and check that minimizing lowers its energy."""
        topology = solvate_topology(ethanol.to_topology())

        # A force field produces an `Interchange`, which in turn produces the simulation.
        # The interchange is kept because `get_openmm_energies` evaluates an `Interchange`,
        # not an OpenMM `Simulation`.
        original = sage_with_tip4p.create_interchange(topology)

        simulation = original.to_openmm_simulation(
            integrator=default_integrator,
        )

        roundtripped = Interchange.from_openmm(
            system=simulation.system,
            topology=simulation.topology,
            positions=simulation.context.getState(getPositions=True).getPositions(),
            box_vectors=simulation.system.getDefaultPeriodicBoxVectors(),
        )

        # Compare scalar totals; the reports themselves are not ordered.
        original_energy = get_openmm_energies(original).total_energy

        # TODO: Much more validation could be done here, but if a simulation
        #       can start and minimize at all, that should catch most problems
        roundtripped.minimize()

        assert get_openmm_energies(roundtripped).total_energy < original_energy

    def test_error_index_mismatch(self, tip4p, water):
        """Check that mismatched particle and atom counts are reported as an error."""
        out: Interchange = tip4p.create_interchange(Topology.from_molecules([water, water]))

        with pytest.raises(
            ValueError,  # TODO: Make a different error
            match="The number of particles in the system and the number of atoms in the topology do not match.",
        ):
            # The system comes from `to_openmm_system()` while the topology is requested
            # with `collate=True`; this pairing is expected to raise the error above.
            Interchange.from_openmm(
                system=out.to_openmm_system(),
                topology=out.to_openmm_topology(collate=True),
            )
76 changes: 59 additions & 17 deletions openff/interchange/components/toolkit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utilities for processing and interfacing with the OpenFF Toolkit."""

from typing import TYPE_CHECKING, Union
from collections import defaultdict
from typing import TYPE_CHECKING

import networkx
import numpy
Expand All @@ -10,6 +11,8 @@
from openff.toolkit.utils.collections import ValidatedList
from openff.utilities.utilities import has_package

from openff.interchange.models import ImportedVirtualSiteKey

if has_package("openmm") or TYPE_CHECKING:
import openmm.app

Expand All @@ -25,7 +28,7 @@ def _get_num_h_bonds(topology: "Topology") -> int:
return n_bonds_containing_hydrogen


def _get_14_pairs(topology_or_molecule: Union["Topology", "Molecule"]):
def _get_14_pairs(topology_or_molecule: Topology | Molecule):
"""Generate tuples of atom pairs, including symmetric duplicates."""
# TODO: A replacement of Topology.nth_degree_neighbors in the toolkit
# may implement this in the future.
Expand Down Expand Up @@ -105,35 +108,71 @@ def _check_electrostatics_handlers(force_field: "ForceField") -> bool:
return False


def _simple_topology_from_openmm(openmm_topology: "openmm.app.Topology", system: "openmm.System") -> Topology:
    """
    Convert an OpenMM Topology into an OpenFF Topology consisting **only** of so-called `_SimpleMolecule`s.

    Particles whose `element` is `None` are treated as virtual sites: they are left out of
    the returned topology's atoms and recorded as `ImportedVirtualSiteKey`s instead.

    Parameters
    ----------
    openmm_topology
        The OpenMM topology to convert.
    system
        The OpenMM system associated with `openmm_topology`; it is queried for the
        particles that each virtual site is defined by.

    Returns
    -------
    topology
        The converted topology, with two private attributes attached:
        `_molecule_virtual_site_map` (molecule index -> list of virtual site keys) and
        `_particle_map` (OpenMM particle index -> OpenFF atom index or virtual site key).

    """
    # TODO: Splice in fully-defined OpenFF `Molecule`s?

    graph = networkx.Graph()

    virtual_sites: list[ImportedVirtualSiteKey] = []

    # map indices of OpenMM system with virtual site to topology indices, or virtual site keys,
    # in associated OpenFF topology, since stripping out virtual sites offsets particle indices
    openmm_openff_particle_map: dict[int, int | ImportedVirtualSiteKey] = {}

    # TODO: This is nearly identical to Topology._openmm_topology_to_networkx.
    #  Should this method be replaced with a direct call to that?
    for atom in openmm_topology.atoms():
        if atom.element is None:
            openmm_virtual_site = system.getVirtualSite(atom.index)

            # assume ThreeParticleAverageSite for now
            virtual_site = ImportedVirtualSiteKey(
                orientation_atom_indices=[openmm_virtual_site.getParticle(i) for i in range(3)],
                name=atom.name,
                type="ThreeParticleAverageSite",
            )

            virtual_sites.append(virtual_site)

            openmm_openff_particle_map[atom.index] = virtual_site

        else:
            graph.add_node(
                atom.index,
                # `atom.element` is known not to be None in this branch
                atomic_number=atom.element.atomic_number,
                name=atom.name,
                residue_name=atom.residue.name,
                # Note that residue number is mapped to residue.id here. The use of id vs. number
                # varies in other packages and the convention for the OpenFF-OpenMM interconversion
                # is recorded at
                # https://docs.openforcefield.org/projects/toolkit/en/0.15.1/users/molecule_conversion.html
                residue_number=atom.residue.id,
                insertion_code=atom.residue.insertionCode,
                chain_id=atom.residue.chain.id,
            )

            # every virtual site seen so far shifts this atom's index down by one
            openmm_openff_particle_map[atom.index] = atom.index - len(virtual_sites)

    for bond in openmm_topology.bonds():
        graph.add_edge(
            bond.atom1.index,
            bond.atom2.index,
        )

    topology = _simple_topology_from_graph(graph)

    topology._molecule_virtual_site_map = defaultdict(list)

    # TODO: This iteration strategy scales horribly with system size - need to refactor -
    #       since looking up topology atom indices is slow. It's probably repetitive to
    #       look up the molecule index over and over again
    for particle in virtual_sites:
        # `orientation_atom_indices` holds OpenMM particle indices, which run ahead of the
        # OpenFF atom indices once virtual sites have been stripped out, so translate
        # before indexing into the OpenFF topology
        parent_atom_index = openmm_openff_particle_map[particle.orientation_atom_indices[0]]

        molecule_index = topology.molecule_index(topology.atom(parent_atom_index).molecule)

        topology._molecule_virtual_site_map[molecule_index].append(particle)

    topology._particle_map = openmm_openff_particle_map

    return topology


def _simple_topology_from_graph(graph: networkx.Graph) -> Topology:
Expand All @@ -147,7 +186,10 @@ def _simple_topology_from_graph(graph: networkx.Graph) -> Topology:
# the subgraphs are returned out of "atom order", like
# if atoms in a later molecule have lesser atom indices
# than this molecule
assert topology.n_atoms == next(iter(subgraph.nodes))
#
# Oct 2024 - need to add a test case for above?
# these values are not necessarily equal because of virtual sites
assert topology.n_atoms <= next(iter(subgraph.nodes))

topology.add_molecule(_SimpleMolecule._from_subgraph(subgraph))

Expand Down
11 changes: 11 additions & 0 deletions openff/interchange/interop/_virtual_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ def _virtual_site_parent_molecule_mapping(
return mapping


def _get_molecule_virtual_site_map(interchange: Interchange) -> defaultdict[int, list[VirtualSiteKey]]:
    """
    Group the virtual sites of an Interchange by molecule index.

    This inverts the result of `_virtual_site_parent_molecule_mapping`: each key is a
    molecule index and each value lists the virtual site keys mapped to that molecule,
    in the order that mapping yields them.
    """
    sites_by_molecule = defaultdict(list)

    parent_molecule_of = _virtual_site_parent_molecule_mapping(interchange)

    for site_key, parent_index in parent_molecule_of.items():
        sites_by_molecule[parent_index].append(site_key)

    return sites_by_molecule


def get_positions_with_virtual_sites(
interchange: Interchange,
collate: bool = False,
Expand Down
Loading
Loading