diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e62900b7..36b45ac5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: - name: Install python dependecies run: | python -m pip install --upgrade pip - pip install dill "e3nn-jax<0.19.2" jaxlib jax-md jaxopt pytest matplotlib + pip install dill jaxlib "jax-md>=0.2.7" jaxopt pytest matplotlib - name: Install pysages run: pip install . @@ -60,7 +60,7 @@ jobs: - name: Install python dependecies run: | python -m pip install --upgrade pip - pip install dill "e3nn-jax<0.19.2" jaxlib jax-md jaxopt pytest pylint flake8 + pip install dill jaxlib "jax-md>=0.2.7" jaxopt pytest pylint flake8 pip install -r docs/requirements.txt - name: Install pysages run: pip install . diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 72e568aa..da6558c6 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -9,6 +9,7 @@ - Pablo Zubieta - Ludwig Schneider +- [Trung Nguyen](https://github.com/ndtrung81) ## Collective Variables @@ -38,4 +39,4 @@ details. Specific contributions to this repository are listed below ## Other For other contributions such as bugfixes or performance improvements, take a look at -https://github.com/SSAGESLabs/PySAGES/graphs/contributors +https://github.com/SSAGESLabs/PySAGES/graphs/contributors. diff --git a/Dockerfile b/Dockerfile index ae3f3d4e..4b4cd3d5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ RUN python -m pip install ase gsd matplotlib "pyparsing<3" # Install JAX and JAX-MD RUN python -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -RUN python -m pip install --upgrade "e3nn-jax<0.19.2" jax-md jaxopt +RUN python -m pip install --upgrade "jax-md>=0.2.7" jaxopt COPY . /PySAGES RUN pip install /PySAGES/ diff --git a/examples/lammps/unbiased/lj.lmp b/examples/lammps/unbiased/lj.lmp new file mode 100644 index 00000000..00cf15d8 --- /dev/null +++ b/examples/lammps/unbiased/lj.lmp @@ -0,0 +1,21 @@ +# 3d Lennard-Jones melt + +units lj +atom_style atomic +atom_modify map yes + +lattice fcc 0.8442 +region box block 0 20 0 20 0 20 +create_box 1 box +create_atoms 1 box +mass 1 1.0 + +velocity all create 1.44 87287 loop geom + +pair_style lj/cut 2.5 +pair_coeff 1 1 1.0 1.0 2.5 + +neighbor 0.3 bin +neigh_modify delay 5 every 1 + +fix 1 all nve diff --git a/examples/lammps/unbiased/unbiased.py b/examples/lammps/unbiased/unbiased.py new file mode 100644 index 00000000..5c37075a --- /dev/null +++ b/examples/lammps/unbiased/unbiased.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 + +""" +Example unbiased simulation with pysages and lammps. + +For a list of possible options for running the script pass `-h` as argument from the +command line, or call `get_args(["-h"])` if the module was loaded interactively. +""" + +# %% +import argparse +import sys + +from lammps import lammps + +import pysages +from pysages.backends import SamplingContext +from pysages.colvars import Component +from pysages.methods import Unbiased + + +# %% +def generate_context(args="", script="lj.lmp", store_freq=1): + """ + Returns a lammps simulation defined by the contents of `script` using `args` as + initialization arguments. + """ + context = lammps(cmdargs=args.split()) + context.file(script) + # Allow for the retrieval of the unwrapped positions + context.command(f"fix unwrap all store/state {store_freq} xu yu zu") + return context + + +def get_args(argv): + """Process the command-line arguments to this script.""" + + available_args = [ + ("time-steps", "t", int, 1e2, "Number of simulation steps"), + ("kokkos", "k", bool, True, "Whether to use Kokkos acceleration"), + ] + parser = argparse.ArgumentParser(description="Example script to run pysages with lammps") + + for name, short, T, val, doc in available_args: + if T is bool: + action = "store_" + str(val).lower() + parser.add_argument("--" + name, "-" + short, action=action, help=doc) + else: + convert = (lambda x: int(float(x))) if T is int else T + parser.add_argument("--" + name, "-" + short, type=convert, default=T(val), help=doc) + + return parser.parse_args(argv) + + +def main(argv): + """Example simulation with pysages and lammps.""" + args = get_args(argv) + + context_args = {"store_freq": args.time_steps} + if args.kokkos: + # Passed to the lammps constructor as `cmdargs` when running the script + # with the --kokkos (or -k) option + context_args["args"] = "-k on g 1 -sf kk -pk kokkos newton on neigh half" + + # Setting the collective variable, method, and running the simulation + cvs = [Component([0], i) for i in range(3)] + method = Unbiased(cvs) + sampling_context = SamplingContext(method, generate_context, context_args=context_args) + result = pysages.run(sampling_context, args.time_steps) + + # Post-run analysis + # ----------------- + context = sampling_context.context + nlocal = sampling_context.sampler.view.local_particle_number() + snapshot = result.snapshots[0] + state = result.states[0] + + # Retrieve the pointer to the unwrapped positions, + ptr = context.extract_fix("unwrap", 1, 2) + # and make them available as a numpy ndarray + positions = context.numpy.darray(ptr, nlocal, dim=3) + # Get the map to sort the atoms since they can be reordered during the simulation + ids = context.numpy.extract_atom("id").argsort() + + # The ids of the final snapshot in pysages and lammps should be the same + assert (snapshot.ids == ids).all() + # For our example, the last value of the CV should match + # the unwrapped position of the zeroth atom + assert (state.xi.flatten() == positions[ids[0]]).all() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/pysages/backends/core.py b/pysages/backends/core.py index 4ed1d695..8d02029b 100644 --- a/pysages/backends/core.py +++ b/pysages/backends/core.py @@ -82,10 +82,12 @@ def __init__( self._backend_name = "ase" elif module_name.startswith("hoomd"): self._backend_name = "hoomd" - elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"): - self._backend_name = "openmm" elif isinstance(context, JaxMDContext): self._backend_name = "jax-md" + elif module_name.startswith("lammps"): + self._backend_name = "lammps" + elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"): + self._backend_name = "openmm" if self._backend_name is None: backends = ", ".join(supported_backends()) @@ -113,14 +115,15 @@ def __enter__(self): """ if hasattr(self.context, "__enter__"): return self.context.__enter__() + return self.context def __exit__(self, exc_type, exc_value, exc_traceback): """ Trampoline 'with statements' to the wrapped context when the backend supports it. """ if hasattr(self.context, "__exit__"): - return self.context.__exit__(exc_type, exc_value, exc_traceback) + self.context.__exit__(exc_type, exc_value, exc_traceback) def supported_backends(): - return ("ase", "hoomd", "jax-md", "openmm") + return ("ase", "hoomd", "jax-md", "lammps", "openmm") diff --git a/pysages/backends/lammps.py b/pysages/backends/lammps.py new file mode 100644 index 00000000..376254e7 --- /dev/null +++ b/pysages/backends/lammps.py @@ -0,0 +1,254 @@ +# SPDX-License-Identifier: MIT +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +This module defines the Sampler class, which is a LAMMPS fix that enables any PySAGES +SamplingMethod to be hooked to a LAMMPS simulation instance. +""" + +import importlib +import weakref +from functools import partial + +import jax +from jax import jit +from jax import numpy as np +from jax import vmap +from jax.dlpack import from_dlpack as asarray +from lammps import dlext +from lammps.dlext import ExecutionSpace, FixDLExt, LAMMPSView, has_kokkos_cuda_enabled + +from pysages.backends import snapshot as pbs +from pysages.backends.core import SamplingContext +from pysages.backends.snapshot import ( + Box, + HelperMethods, + Snapshot, + SnapshotMethods, + build_data_querier, +) +from pysages.typing import Callable, Optional +from pysages.utils import copy, identity + +kDefaultLocation = dlext.kOnHost if not hasattr(ExecutionSpace, "kOnDevice") else dlext.kOnDevice + + +class Sampler(FixDLExt): # pylint: disable=R0902 + """ + LAMMPS fix that connects PySAGES sampling methods to LAMMPS simulations. + + Parameters + ---------- + context: ``lammps.core.lammps`` + The LAMMPS simulation instance to which the PySAGES sampling + machinery will be hooked. + sampling_method: ``SamplingMethod`` + The sampling method used. + callbacks: ``Optional[Callback]`` + An optional callback. Some methods define one for logging, + but it can also be user-defined. + location: ``lammps.dlext.ExecutionSpace`` + Device where the simulation data will be retrieved. + """ + + def __init__( + self, context, sampling_method, callback: Optional[Callable], location=kDefaultLocation + ): + super().__init__(context) + + on_gpu = (location != dlext.kOnHost) & has_kokkos_cuda_enabled(context) + location = location if on_gpu else dlext.kOnHost + + self.context = context + self.location = location + self.view = LAMMPSView(context) + + helpers, restore, bias = build_helpers(context, sampling_method, on_gpu, pbs.restore) + initial_snapshot = self.take_snapshot() + _, initialize, method_update = sampling_method.build(initial_snapshot, helpers) + + self.callback = callback + self.snapshot = initial_snapshot + self.state = initialize() + self._restore = restore + self._update_box = lambda: self.snapshot.box + + def update(timestep): + self.view.synchronize() + self.snapshot = self._update_snapshot() + self.state = method_update(self.snapshot, self.state) + bias(self.snapshot, self.state) + if self.callback: + self.callback(self.snapshot, self.state, timestep) + + self.set_callback(update) + + def _partial_snapshot(self, include_masses: bool = False): + positions = asarray(dlext.positions(self.view, self.location)) + types = asarray(dlext.types(self.view, self.location)) + velocities = asarray(dlext.velocities(self.view, self.location)) + forces = asarray(dlext.forces(self.view, self.location)) + tags_map = asarray(dlext.tags_map(self.view, self.location)) + imgs = asarray(dlext.images(self.view, self.location)) + + masses = None + if include_masses: + masses = asarray(dlext.masses(self.view, self.location)) + vel_mass = (velocities, (masses, types)) + + return Snapshot(positions, vel_mass, forces, tags_map, imgs, None, None) + + def _update_snapshot(self): + s = self._partial_snapshot() + velocities, (_, types) = s.vel_mass + _, (masses, _) = self.snapshot.vel_mass + vel_mass = (velocities, (masses, types)) + box = self._update_box() + dt = self.snapshot.dt + + return Snapshot(s.positions, vel_mass, s.forces, s.ids[1:], s.images, box, dt) + + def restore(self, prev_snapshot): + """Replaces this sampler's snapshot with `prev_snapshot`.""" + self._restore(self.snapshot, prev_snapshot) + + def take_snapshot(self): + """Returns a copy of the current snapshot of the system.""" + s = self._partial_snapshot(include_masses=True) + box = Box(*get_global_box(self.context)) + dt = get_timestep(self.context) + + return Snapshot( + copy(s.positions), copy(s.vel_mass), copy(s.forces), s.ids[1:], copy(s.images), box, dt + ) + + +def build_helpers(context, sampling_method, on_gpu, restore_fn): + """ + Builds helper methods used for restoring snapshots and biasing a simulation. + """ + # Depending on the device being used we need to use either cupy or numpy + # (or numba) to generate a view of jax's DeviceArrays + if on_gpu: + cupy = importlib.import_module("cupy") + view = cupy.asarray + + def sync_forces(): + cupy.cuda.get_current_stream().synchronize() + + else: + utils = importlib.import_module(".utils", package="pysages.backends") + view = utils.view + + def sync_forces(): + pass + + # TODO: check if this can be sped up. # pylint: disable=W0511 + def bias(snapshot, state): + """Adds the computed bias to the forces.""" + if state.bias is None: + return + forces = view(snapshot.forces) + biases = view(state.bias.block_until_ready()) + forces[:, :3] += biases + sync_forces() + + snapshot_methods = build_snapshot_methods(sampling_method, on_gpu) + flags = sampling_method.snapshot_flags + restore = partial(restore_fn, view) + helpers = HelperMethods(build_data_querier(snapshot_methods, flags), get_dimension(context)) + + return helpers, restore, bias + + +def build_snapshot_methods(sampling_method, on_gpu): + """ + Builds methods for retrieving snapshot properties in a format useful for collective + variable calculations. + """ + + if sampling_method.requires_box_unwrapping: + device = jax.devices("gpu" if on_gpu else "cpu")[0] + dtype = np.int64 if dlext.kImgBitSize == 64 else np.int32 + offset = dlext.kImgMax + + with jax.default_device(device): + bits = np.asarray((0, dlext.kImgBits, dlext.kImg2Bits), dtype=dtype) + mask = np.asarray((dlext.kImgMask, dlext.kImgMask, -1), dtype=dtype) + + def unpack(image): + return (image >> bits & mask) - offset + + def positions(snapshot): + L = np.diag(snapshot.box.H) + return snapshot.positions[:, :3] + L * vmap(unpack)(snapshot.images) + + else: + + def positions(snapshot): + return snapshot.positions + + @jit + def indices(snapshot): + return snapshot.ids + + @jit + def momenta(snapshot): + V, (masses, types) = snapshot.vel_mass + M = masses[types] + return (M * V).flatten() + + @jit + def masses(snapshot): + return snapshot.vel_mass[:, 3:] + + return SnapshotMethods(jit(positions), indices, momenta, masses) + + +def get_dimension(context): + """Get the dimensionality of a LAMMPS simulation.""" + return context.extract_setting("dimension") + + +def get_global_box(context): + """Get the box and origin of a LAMMPS simulation.""" + boxlo, boxhi, xy, yz, xz, *_ = context.extract_box() + Lx = boxhi[0] - boxlo[0] + Ly = boxhi[1] - boxlo[1] + Lz = boxhi[2] - boxlo[2] + origin = boxlo + H = ((Lx, xy * Ly, xz * Lz), (0.0, Ly, yz * Lz), (0.0, 0.0, Lz)) + return H, origin + + +def get_timestep(context): + """Get the timestep of a LAMMPS simulation.""" + return context.extract_global("dt") + + +def bind(sampling_context: SamplingContext, callback: Optional[Callable], **kwargs): + """ + Defines and sets up a Sampler to perform an enhanced sampling simulation. + + This function takes a ``sampling_context`` that has its context attribute as an instance + of a LAMMPS simulation, and creates a ``Sampler`` object that connects the PySAGES + sampling method to the LAMMPS simulation. It also modifies the sampling_context's view + and run attributes to use the sampler's view and the LAMMPS run command. + """ + identity(kwargs) # we ignore the kwargs for now + + context = sampling_context.context + sampling_method = sampling_context.method + sampler = Sampler(context, sampling_method, callback) + sampling_context.view = sampler.view + sampling_context.run = lambda n, **kwargs: context.command(f"run {n}") + + # We want to support backends that also are context managers as long + # as the simulation is kept alive after exiting the context. + # Unfortunately, the default implementation of `lammps.__exit__` closes + # the lammps instance, so we need to overwrite it. + context.__exit__ = lambda *args: None + # Ensure that the lammps context is properly finalized. + weakref.finalize(context, context.finalize) + + return sampler diff --git a/pysages/methods/core.py b/pysages/methods/core.py index b76cb483..6c10fd68 100644 --- a/pysages/methods/core.py +++ b/pysages/methods/core.py @@ -16,7 +16,7 @@ from pysages.methods.restraints import canonicalize from pysages.methods.utils import ReplicasConfiguration from pysages.typing import Callable, Optional, Union -from pysages.utils import ToCPU, copy, dispatch, dispatch_table, identity +from pysages.utils import ToCPU, copy, dispatch, dispatch_table, has_method, identity # Base Classes # ============ @@ -449,10 +449,7 @@ def has_custom_run(method: type): """ Determine if ``method`` has a specialized ``run`` implementation. """ - custom_run_methods = set() - for sig in dispatch_table(dispatch)["run"].methods.keys(): - custom_run_methods.update(sig.types[0].get_types()) - return method in custom_run_methods + return has_method(dispatch_table(dispatch)["run"], method, 0) def generalize(concrete_update, helpers, jit_compile=True): diff --git a/pysages/ml/utils.py b/pysages/ml/utils.py index eb90e2a1..753ffe9d 100644 --- a/pysages/ml/utils.py +++ b/pysages/ml/utils.py @@ -63,7 +63,7 @@ def pack(params, layout): previously flatten with `unpack`. """ structure, shapes, separators = layout - partition = params.split(separators) + partition = np.split(params, separators) ps = [p.reshape(s) for (p, s) in zip(partition, shapes)] return structure.unflatten(ps) diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index 5bfcf39b..02e3d1be 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -11,6 +11,7 @@ from .compat import ( check_device_array, dispatch_table, + has_method, is_generic_subclass, solve_pos_def, try_import, diff --git a/pysages/utils/compat.py b/pysages/utils/compat.py index 95f1480c..99084900 100644 --- a/pysages/utils/compat.py +++ b/pysages/utils/compat.py @@ -56,13 +56,30 @@ def solve_pos_def(a, b): def dispatch_table(dispatch): return dispatch._functions + def has_method(fn, T, index): + types_at_index = set() + for sig in fn.methods.keys(): + types_at_index.update(sig.types[index].get_types()) + return T in types_at_index + is_generic_subclass = issubclass else: _bt = import_module("beartype.door") + _pm = import_module("plum") def dispatch_table(dispatch): return dispatch.functions + def has_method(fn, T, index): + types_at_index = set() + for sig in fn.methods: + typ = sig.types[index] + if _pm.get_origin(typ) is _pm.Union: + types_at_index.update(_pm.get_args(typ)) + else: + types_at_index.add(typ) + return T in types_at_index + def is_generic_subclass(A, B): return _bt.TypeHint(A) <= _bt.TypeHint(B) diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 541e9de5..1bcf073f 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -1,8 +1,8 @@ +import importlib import inspect import tempfile import dill as pickle -import jax_md as jmd import numpy as np import pysages @@ -14,7 +14,9 @@ def build_neighbor_list(box_size, positions, r_cutoff, capacity_multiplier): """Helper function to generate a jax-md neighbor list""" - displacement_fn, shift_fn = jmd.space.periodic(box_size) + jmd = importlib.import_module("jax_md") + + displacement_fn, _ = jmd.space.periodic(box_size) neighbor_list_fn = jmd.partition.neighbor_list( displacement_fn, box_size, @@ -23,6 +25,7 @@ def build_neighbor_list(box_size, positions, r_cutoff, capacity_multiplier): format=jmd.partition.NeighborListFormat.Dense, ) neighbors = neighbor_list_fn.allocate(positions) + return neighbors @@ -136,9 +139,11 @@ def test_pickle_methods(): "number_of_opt_it": 10, "standard_deviation": 0.125, "mesh_size": 30, - "nbrs": build_neighbor_list( - 2.0, positions=np.random.randn(20, 3), r_cutoff=1.5, capacity_multiplier=1.0 - ), + "nbrs": None, + # Disable build_neighbor_list until jax_md stabilizes + # "nbrs": build_neighbor_list( + # 2.0, positions=np.random.randn(20, 3), r_cutoff=1.5, capacity_multiplier=1.0 + # ), "fractional_coords": True, }, }