Skip to content

Commit

Permalink
Add docs for the lammps backend
Browse files Browse the repository at this point in the history
Co-authored-by: Pablo Zubieta <[email protected]>
  • Loading branch information
ndtrung81 and pabloferz committed Aug 3, 2023
1 parent ab3cd79 commit 57485a9
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
3 changes: 2 additions & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

- Pablo Zubieta
- Ludwig Schneider
- [Trung Nguyen](https:/ndtrung81)

## Collective Variables

Expand Down Expand Up @@ -38,4 +39,4 @@ details. Specific contributions to this repository are listed below
## Other

For other contributions such as bugfixes or performance improvements, take a look at
https:/SSAGESLabs/PySAGES/graphs/contributors
https:/SSAGESLabs/PySAGES/graphs/contributors.
54 changes: 50 additions & 4 deletions pysages/backends/lammps.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# SPDX-License-Identifier: MIT
# See LICENSE.md and CONTRIBUTORS.md at https:/SSAGESLabs/PySAGES

# Maintainer: ndtrung
"""
This module defines the Sampler class, which is a LAMMPS fix that enables any PySAGES
SamplingMethod to be hooked to a LAMMPS simulation instance.
"""

import importlib
from functools import partial
Expand All @@ -24,12 +27,29 @@
build_data_querier,
)
from pysages.typing import Callable, Optional
from pysages.utils import copy
from pysages.utils import copy, identity

kDefaultLocation = dlext.kOnHost if not hasattr(ExecutionSpace, "kOnDevice") else dlext.kOnDevice


class Sampler(FixDLExt):
class Sampler(FixDLExt): # pylint: disable=R0902
"""
LAMMPS fix that connects PySAGES sampling methods to LAMMPS simulations.
Parameters
----------
context: ``lammps.core.lammps``
The LAMMPS simulation instance to which the PySAGES sampling
machinery will be hooked.
sampling_method: ``SamplingMethod``
The sampling method used.
callbacks: ``Optional[Callback]``
An optional callback. Some methods define one for logging,
but it can also be user-defined.
location: ``lammps.dlext.ExecutionSpace``
Device where the simulation data will be retrieved.
"""

def __init__(
self, context, sampling_method, callback: Optional[Callable], location=kDefaultLocation
):
Expand All @@ -47,9 +67,9 @@ def __init__(
_, initialize, method_update = sampling_method.build(initial_snapshot, helpers)

self.callback = callback
self.restore = lambda prev_snapshot: restore(self.snapshot, prev_snapshot)
self.snapshot = initial_snapshot
self.state = initialize()
self._restore = restore
self._update_box = lambda: self.snapshot.box

def update(timestep):
Expand Down Expand Up @@ -87,7 +107,12 @@ def _update_snapshot(self):

return Snapshot(s.positions, vel_mass, s.forces, s.ids[1:], s.images, box, dt)

def restore(self, prev_snapshot):
"""Replaces this sampler's snapshot with `prev_snapshot`."""
self._restore(self.snapshot, prev_snapshot)

def take_snapshot(self):
"""Returns a copy of the current snapshot of the system."""
s = self._partial_snapshot(include_masses=True)
box = Box(*get_global_box(self.context))
dt = get_timestep(self.context)
Expand All @@ -98,6 +123,9 @@ def take_snapshot(self):


def build_helpers(context, sampling_method, on_gpu, restore_fn):
"""
Builds helper methods used for restoring snapshots and biasing a simulation.
"""
# Depending on the device being used we need to use either cupy or numpy
# (or numba) to generate a view of jax's DeviceArrays
if on_gpu:
Expand Down Expand Up @@ -133,6 +161,11 @@ def bias(snapshot, state):


def build_snapshot_methods(sampling_method, on_gpu):
"""
Builds methods for retrieving snapshot properties in a format useful for collective
variable calculations.
"""

if sampling_method.requires_box_unwrapping:
device = jax.devices("gpu" if on_gpu else "cpu")[0]
dtype = np.int64 if dlext.kImgBitSize == 64 else np.int32
Expand Down Expand Up @@ -172,10 +205,12 @@ def masses(snapshot):


def get_dimension(context):
"""Get the dimensionality of a LAMMPS simulation."""
return context.extract_setting("dimension")


def get_global_box(context):
"""Get the box and origin of a LAMMPS simulation."""
boxlo, boxhi, xy, yz, xz, *_ = context.extract_box()
Lx = boxhi[0] - boxlo[0]
Ly = boxhi[1] - boxlo[1]
Expand All @@ -186,10 +221,21 @@ def get_global_box(context):


def get_timestep(context):
"""Get the timestep of a LAMMPS simulation."""
return context.extract_global("dt")


def bind(sampling_context: SamplingContext, callback: Optional[Callable], **kwargs):
"""
Defines and sets up a Sampler to perform an enhanced sampling simulation.
This function takes a ``sampling_context`` that has its context attribute as an instance
of a LAMMPS simulation, and creates a ``Sampler`` object that connects the PySAGES
sampling method to the LAMMPS simulation. It also modifies the sampling_context's view
and run attributes to use the sampler's view and the LAMMPS run command.
"""
identity(kwargs) # we ignore the kwargs for now

context = sampling_context.context
sampling_method = sampling_context.method
sampler = Sampler(context, sampling_method, callback)
Expand Down

0 comments on commit 57485a9

Please sign in to comment.