diff --git a/pysages/backends/lammps.py b/pysages/backends/lammps.py index aa07a52a..376254e7 100644 --- a/pysages/backends/lammps.py +++ b/pysages/backends/lammps.py @@ -7,6 +7,7 @@ """ import importlib +import weakref from functools import partial import jax @@ -247,5 +248,7 @@ def bind(sampling_context: SamplingContext, callback: Optional[Callable], **kwar # Unfortunately, the default implementation of `lammps.__exit__` closes # the lammps instance, so we need to overwrite it. context.__exit__ = lambda *args: None + # Ensure that the lammps context is properly finalized. + weakref.finalize(context, context.finalize) return sampler