Skip to content

Commit

Permalink
fix: map_location is needed for cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
sachahu1 committed Aug 8, 2024
1 parent 27cb480 commit c918d3b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions diffusion_models/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,17 @@ class Checkpoint:
"""

@classmethod
def from_file(cls, file_path: str) -> "Checkpoint":
def from_file(cls, file_path: str, map_location: Optional[str] = None) -> "Checkpoint":
"""Load and instantiate a checkpoint from a file.
Args:
file_path: The path to the checkpoint file.
map_location: A function, torch. device, string or a dict specifying how to remap storage location.
Returns:
A checkpoint instance.
"""
checkpoint = torch.load(f=file_path, weights_only=True)
checkpoint = torch.load(f=file_path, weights_only=True, map_location=map_location)
checkpoint = cls(**checkpoint)
beta_scheduler_config = BetaSchedulerConfiguration(
**checkpoint.beta_scheduler_config
Expand Down

0 comments on commit c918d3b

Please sign in to comment.