From c918d3b37de7d09f1627ff8f4c395edf055d4e7e Mon Sep 17 00:00:00 2001 From: Sacha Date: Fri, 9 Aug 2024 00:54:19 +0100 Subject: [PATCH] fix: map_location is needed for cpu --- diffusion_models/utils/schemas.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/diffusion_models/utils/schemas.py b/diffusion_models/utils/schemas.py index 5074212..55974f4 100644 --- a/diffusion_models/utils/schemas.py +++ b/diffusion_models/utils/schemas.py @@ -83,16 +83,17 @@ class Checkpoint: """ @classmethod - def from_file(cls, file_path: str) -> "Checkpoint": + def from_file(cls, file_path: str, map_location: Optional[str] = None) -> "Checkpoint": """Load and instantiate a checkpoint from a file. Args: file_path: The path to the checkpoint file. + map_location: A function, torch. device, string or a dict specifying how to remap storage location. Returns: A checkpoint instance. """ - checkpoint = torch.load(f=file_path, weights_only=True) + checkpoint = torch.load(f=file_path, weights_only=True, map_location=map_location) checkpoint = cls(**checkpoint) beta_scheduler_config = BetaSchedulerConfiguration( **checkpoint.beta_scheduler_config