Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LAMMPS backend #274

Merged
merged 7 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: Install python dependecies
run: |
python -m pip install --upgrade pip
pip install dill "e3nn-jax<0.19.2" jaxlib jax-md jaxopt pytest matplotlib
pip install dill jaxlib "jax-md>=0.2.7" jaxopt pytest matplotlib

- name: Install pysages
run: pip install .
Expand Down Expand Up @@ -60,7 +60,7 @@ jobs:
- name: Install python dependecies
run: |
python -m pip install --upgrade pip
pip install dill "e3nn-jax<0.19.2" jaxlib jax-md jaxopt pytest pylint flake8
pip install dill jaxlib "jax-md>=0.2.7" jaxopt pytest pylint flake8
pip install -r docs/requirements.txt
- name: Install pysages
run: pip install .
Expand Down
3 changes: 2 additions & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

- Pablo Zubieta
- Ludwig Schneider
- [Trung Nguyen](https:/ndtrung81)

## Collective Variables

Expand Down Expand Up @@ -38,4 +39,4 @@ details. Specific contributions to this repository are listed below
## Other

For other contributions such as bugfixes or performance improvements, take a look at
https:/SSAGESLabs/PySAGES/graphs/contributors
https:/SSAGESLabs/PySAGES/graphs/contributors.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ RUN python -m pip install ase gsd matplotlib "pyparsing<3"

# Install JAX and JAX-MD
RUN python -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN python -m pip install --upgrade "e3nn-jax<0.19.2" jax-md jaxopt
RUN python -m pip install --upgrade "jax-md>=0.2.7" jaxopt

COPY . /PySAGES
RUN pip install /PySAGES/
21 changes: 21 additions & 0 deletions examples/lammps/unbiased/lj.lmp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# 3d Lennard-Jones melt

units lj
atom_style atomic
atom_modify map yes

lattice fcc 0.8442
region box block 0 20 0 20 0 20
create_box 1 box
create_atoms 1 box
mass 1 1.0

velocity all create 1.44 87287 loop geom

pair_style lj/cut 2.5
pair_coeff 1 1 1.0 1.0 2.5

neighbor 0.3 bin
neigh_modify delay 5 every 1

fix 1 all nve
93 changes: 93 additions & 0 deletions examples/lammps/unbiased/unbiased.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#!/usr/bin/env python3

"""
Example unbiased simulation with pysages and lammps.

For a list of possible options for running the script pass `-h` as argument from the
command line, or call `get_args(["-h"])` if the module was loaded interactively.
"""

# %%
import argparse
import sys

from lammps import lammps

import pysages
from pysages.backends import SamplingContext
from pysages.colvars import Component
from pysages.methods import Unbiased


# %%
def generate_context(args="", script="lj.lmp", store_freq=1):
"""
Returns a lammps simulation defined by the contents of `script` using `args` as
initialization arguments.
"""
context = lammps(cmdargs=args.split())
context.file(script)
# Allow for the retrieval of the unwrapped positions
context.command(f"fix unwrap all store/state {store_freq} xu yu zu")
return context


def get_args(argv):
"""Process the command-line arguments to this script."""

available_args = [
("time-steps", "t", int, 1e2, "Number of simulation steps"),
("kokkos", "k", bool, True, "Whether to use Kokkos acceleration"),
]
parser = argparse.ArgumentParser(description="Example script to run pysages with lammps")

for name, short, T, val, doc in available_args:
if T is bool:
action = "store_" + str(val).lower()
parser.add_argument("--" + name, "-" + short, action=action, help=doc)
else:
convert = (lambda x: int(float(x))) if T is int else T
parser.add_argument("--" + name, "-" + short, type=convert, default=T(val), help=doc)

return parser.parse_args(argv)


def main(argv):
"""Example simulation with pysages and lammps."""
args = get_args(argv)

context_args = {"store_freq": args.time_steps}
if args.kokkos:
# Passed to the lammps constructor as `cmdargs` when running the script
# with the --kokkos (or -k) option
context_args["args"] = "-k on g 1 -sf kk -pk kokkos newton on neigh half"

# Setting the collective variable, method, and running the simulation
cvs = [Component([0], i) for i in range(3)]
method = Unbiased(cvs)
sampling_context = SamplingContext(method, generate_context, context_args=context_args)
result = pysages.run(sampling_context, args.time_steps)

# Post-run analysis
# -----------------
context = sampling_context.context
nlocal = sampling_context.sampler.view.local_particle_number()
snapshot = result.snapshots[0]
state = result.states[0]

# Retrieve the pointer to the unwrapped positions,
ptr = context.extract_fix("unwrap", 1, 2)
# and make them available as a numpy ndarray
positions = context.numpy.darray(ptr, nlocal, dim=3)
# Get the map to sort the atoms since they can be reordered during the simulation
ids = context.numpy.extract_atom("id").argsort()

# The ids of the final snapshot in pysages and lammps should be the same
assert (snapshot.ids == ids).all()
# For our example, the last value of the CV should match
# the unwrapped position of the zeroth atom
assert (state.xi.flatten() == positions[ids[0]]).all()


if __name__ == "__main__":
main(sys.argv[1:])
11 changes: 7 additions & 4 deletions pysages/backends/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,12 @@ def __init__(
self._backend_name = "ase"
elif module_name.startswith("hoomd"):
self._backend_name = "hoomd"
elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"):
self._backend_name = "openmm"
elif isinstance(context, JaxMDContext):
self._backend_name = "jax-md"
elif module_name.startswith("lammps"):
self._backend_name = "lammps"
elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"):
self._backend_name = "openmm"

if self._backend_name is None:
backends = ", ".join(supported_backends())
Expand Down Expand Up @@ -113,14 +115,15 @@ def __enter__(self):
"""
if hasattr(self.context, "__enter__"):
return self.context.__enter__()
return self.context

def __exit__(self, exc_type, exc_value, exc_traceback):
"""
Trampoline 'with statements' to the wrapped context when the backend supports it.
"""
if hasattr(self.context, "__exit__"):
return self.context.__exit__(exc_type, exc_value, exc_traceback)
self.context.__exit__(exc_type, exc_value, exc_traceback)


def supported_backends():
return ("ase", "hoomd", "jax-md", "openmm")
return ("ase", "hoomd", "jax-md", "lammps", "openmm")
Loading