Restructure util.serde #334

Merged · 2 commits · Sep 28, 2023
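For orientation, this PR splits the single `curated_transformers/util/serde.py` module into a `util/serde/` package. A sketch of the resulting layout and the new import paths, as inferred from the diff below:

# Layout after this PR (inferred from the diff):
#   curated_transformers/util/serde/__init__.py   - new, empty
#   curated_transformers/util/serde/checkpoint.py - checkpoint types and state-dict loaders
#   curated_transformers/util/serde/load.py       - model-loading helpers (formerly util/serde.py)
#
# Call sites now import from the submodules directly:
from curated_transformers.util.serde.checkpoint import ModelCheckpointType
from curated_transformers.util.serde.load import (
    TensorToParameterConverterT,
    load_model_from_checkpoints,
)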
2 changes: 1 addition & 1 deletion curated_transformers/models/hf_hub/mixin.py
@@ -10,7 +10,7 @@
 from ...repository.fsspec import FsspecArgs, FsspecRepository
 from ...repository.hf_hub import HfHubRepository
 from ...repository.repository import ModelRepository, Repository
-from ...util.serde import load_model_from_checkpoints
+from ...util.serde.load import load_model_from_checkpoints
 from ..module import TransformerModule

 # Only provided as typing.Self in Python 3.11+.
2 changes: 1 addition & 1 deletion curated_transformers/quantization/bnb/impl.py
@@ -6,7 +6,7 @@

 from ..._compat import has_bitsandbytes
 from ...util.pytorch import ModuleIterator, apply_to_module
-from ...util.serde import TensorToParameterConverterT
+from ...util.serde.load import TensorToParameterConverterT
 from .config import BitsAndBytesConfig, _4BitConfig, _8BitConfig

 if TYPE_CHECKING:
2 changes: 1 addition & 1 deletion curated_transformers/quantization/helpers.py
@@ -2,7 +2,7 @@

 from torch.nn import Module

-from ..util.serde import TensorToParameterConverterT
+from ..util.serde.load import TensorToParameterConverterT
 from .bnb import prepare_for_quantization as bnb_prepare_for_quantization
 from .bnb.config import BitsAndBytesConfig
 from .quantizable import Quantizable
88 changes: 2 additions & 86 deletions curated_transformers/repository/_hf.py
@@ -1,15 +1,6 @@
-from contextvars import ContextVar
-from enum import Enum
-from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Optional
+from typing import TYPE_CHECKING

-import torch
-
-from .._compat import has_safetensors
-from ..repository.file import RepositoryFile
-
-if TYPE_CHECKING:
-    import safetensors
-
+from ..util.serde.checkpoint import ModelCheckpointType

 HF_MODEL_CONFIG = "config.json"
 HF_MODEL_CHECKPOINT = "pytorch_model.bin"
@@ -22,37 +13,6 @@
 TOKENIZER_JSON = "tokenizer.json"


-class ModelCheckpointType(Enum):
-    """
-    Types of model checkpoints supported by Curated Transformers.
-    """
-
-    #: PyTorch `checkpoint <https://pytorch.org/docs/stable/generated/torch.save.html>`_.
-    PYTORCH_STATE_DICT = 0
-
-    #: Hugging Face `Safetensors <https://github.com/huggingface/safetensors>`_ checkpoint.
-    SAFE_TENSORS = 1
-
-    @property
-    def loader(
-        self,
-    ) -> Callable[[Iterable[RepositoryFile]], Iterable[Mapping[str, torch.Tensor]]]:
-        checkpoint_type_to_loader = {
-            ModelCheckpointType.PYTORCH_STATE_DICT: _load_pytorch_state_dicts_from_checkpoints,
-            ModelCheckpointType.SAFE_TENSORS: _load_safetensor_state_dicts_from_checkpoints,
-        }
-        return checkpoint_type_to_loader[self]
-
-    @property
-    def pretty_name(self) -> str:
-        if self == ModelCheckpointType.PYTORCH_STATE_DICT:
-            return "PyTorch StateDict"
-        elif self == ModelCheckpointType.SAFE_TENSORS:
-            return "SafeTensors"
-        else:
-            return ""
-
-
 PRIMARY_CHECKPOINT_FILENAMES = {
     ModelCheckpointType.PYTORCH_STATE_DICT: HF_MODEL_CHECKPOINT,
     ModelCheckpointType.SAFE_TENSORS: HF_MODEL_CHECKPOINT_SAFETENSORS,
@@ -63,47 +23,3 @@ def pretty_name(self) -> str:
 }
 # Same for both checkpoint types.
 SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY = HF_MODEL_SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY
-
-
-# When `None`, behaviour is implementation-specific.
-_MODEL_CHECKPOINT_TYPE: ContextVar[Optional[ModelCheckpointType]] = ContextVar(
-    "model_checkpoint_type", default=None
-)
-
-
-def _load_safetensor_state_dicts_from_checkpoints(
-    checkpoints: Iterable[RepositoryFile],
-) -> Iterable[Mapping[str, torch.Tensor]]:
-    if not has_safetensors:
-        raise ValueError(
-            "The `safetensors` library is required to load models from Safetensors checkpoints"
-        )
-
-    import safetensors.torch
-
-    for checkpoint in checkpoints:
-        # Prefer to load from a path when possible, since loading from a
-        # file object temporarily puts the checkpoint in memory twice.
-        if checkpoint.path is not None:
-            # Map to CPU first to support all devices.
-            state_dict = safetensors.torch.load_file(checkpoint.path, device="cpu")
-        else:
-            with checkpoint.open() as f:
-                # This has memory overhead, since Safetensors does not have
-                # support for loading from a file object and cannot use
-                # the bytes in-place.
-                checkpoint_bytes = f.read()
-                state_dict = safetensors.torch.load(checkpoint_bytes)
-        yield state_dict
-
-
-def _load_pytorch_state_dicts_from_checkpoints(
-    checkpoints: Iterable[RepositoryFile],
-) -> Iterable[Mapping[str, torch.Tensor]]:
-    for checkpoint in checkpoints:
-        with checkpoint.open() as f:
-            # Map to CPU first to support all devices.
-            state_dict = torch.load(
-                f, map_location=torch.device("cpu"), weights_only=True
-            )
-        yield state_dict
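The filename tables that remain in `_hf.py` map each checkpoint type to its primary and sharded-index filenames. As an illustration of how such a sharded index is typically consumed (a hypothetical sketch: `get_file` stands in for however a repository exposes its files; the index format is the standard Hugging Face sharded-checkpoint JSON):

import json

def shard_filenames(repo, checkpoint_type: ModelCheckpointType) -> list:
    # Hypothetical sketch: read the sharded-index JSON for this checkpoint
    # type and collect the unique shard filenames from its weight map.
    index_file = repo.get_file(SHARDED_CHECKPOINT_INDEX_FILENAMES[checkpoint_type])
    with index_file.open() as f:
        index = json.load(f)
    weight_map = index[SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY]
    return sorted(set(weight_map.values()))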
3 changes: 1 addition & 2 deletions curated_transformers/repository/repository.py
@@ -4,16 +4,15 @@
 from typing import Any, Dict, List, Optional, Tuple

 from .._compat import has_safetensors
+from ..util.serde.checkpoint import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType
 from ._hf import (
-    _MODEL_CHECKPOINT_TYPE,
     HF_MODEL_CONFIG,
     HF_TOKENIZER_CONFIG,
     PRIMARY_CHECKPOINT_FILENAMES,
     SHARDED_CHECKPOINT_INDEX_FILENAMES,
     SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY,
     SPECIAL_TOKENS_MAP,
     TOKENIZER_JSON,
-    ModelCheckpointType,
 )
 from .file import RepositoryFile
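`repository.py` now pulls `_MODEL_CHECKPOINT_TYPE` and `ModelCheckpointType` from `util.serde.checkpoint` rather than from `_hf`. The ContextVar defaults to `None`, in which case behaviour is implementation-specific; one plausible resolution policy (a sketch, not necessarily the library's exact logic) prefers Safetensors whenever the library is installed:

def _resolve_checkpoint_type() -> ModelCheckpointType:
    # An explicit override set through the ContextVar wins; otherwise fall
    # back to Safetensors if available, else PyTorch pickle checkpoints.
    override = _MODEL_CHECKPOINT_TYPE.get()
    if override is not None:
        return override
    if has_safetensors:
        return ModelCheckpointType.SAFE_TENSORS
    return ModelCheckpointType.PYTORCH_STATE_DICT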
6 changes: 2 additions & 4 deletions curated_transformers/tests/models/test_hf_hub.py
@@ -6,10 +6,8 @@
 from curated_transformers.models.bert.encoder import BERTEncoder
 from curated_transformers.repository.hf_hub import HfHubRepository
 from curated_transformers.repository.repository import ModelRepository
-from curated_transformers.util.serde import (
-    ModelCheckpointType,
-    _use_model_checkpoint_type,
-)
+from curated_transformers.util.serde.checkpoint import ModelCheckpointType
+from curated_transformers.util.serde.load import _use_model_checkpoint_type

 from ..compat import has_hf_transformers, has_safetensors
 from ..conftest import TORCH_DEVICES
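The tests use `_use_model_checkpoint_type` to pin the checkpoint format under test. Assuming it is a context manager over the `_MODEL_CHECKPOINT_TYPE` ContextVar defined in `checkpoint.py` (which its usage here suggests), a minimal sketch would be:

from contextlib import contextmanager

@contextmanager
def _use_model_checkpoint_type(checkpoint_type: ModelCheckpointType):
    # Temporarily force a checkpoint format for all loads inside the
    # `with` block, restoring the previous setting on exit.
    token = _MODEL_CHECKPOINT_TYPE.set(checkpoint_type)
    try:
        yield
    finally:
        _MODEL_CHECKPOINT_TYPE.reset(token)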
Empty file: curated_transformers/util/serde/__init__.py (new).
86 changes: 86 additions & 0 deletions curated_transformers/util/serde/checkpoint.py
@@ -0,0 +1,86 @@
+from contextvars import ContextVar
+from enum import Enum
+from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Optional
+
+import torch
+
+from ..._compat import has_safetensors
+from ...repository.file import RepositoryFile
+
+if TYPE_CHECKING:
+    import safetensors
+
+
+class ModelCheckpointType(Enum):
+    """
+    Types of model checkpoints supported by Curated Transformers.
+    """
+
+    #: PyTorch `checkpoint <https://pytorch.org/docs/stable/generated/torch.save.html>`_.
+    PYTORCH_STATE_DICT = 0
+
+    #: Hugging Face `Safetensors <https://github.com/huggingface/safetensors>`_ checkpoint.
+    SAFE_TENSORS = 1
+
+    @property
+    def loader(
+        self,
+    ) -> Callable[[Iterable[RepositoryFile]], Iterable[Mapping[str, torch.Tensor]]]:
+        checkpoint_type_to_loader = {
+            ModelCheckpointType.PYTORCH_STATE_DICT: _load_pytorch_state_dicts_from_checkpoints,
+            ModelCheckpointType.SAFE_TENSORS: _load_safetensor_state_dicts_from_checkpoints,
+        }
+        return checkpoint_type_to_loader[self]
+
+    @property
+    def pretty_name(self) -> str:
+        if self == ModelCheckpointType.PYTORCH_STATE_DICT:
+            return "PyTorch StateDict"
+        elif self == ModelCheckpointType.SAFE_TENSORS:
+            return "SafeTensors"
+        else:
+            return ""
+
+
+# When `None`, behaviour is implementation-specific.
+_MODEL_CHECKPOINT_TYPE: ContextVar[Optional[ModelCheckpointType]] = ContextVar(
+    "model_checkpoint_type", default=None
+)
+
+
+def _load_safetensor_state_dicts_from_checkpoints(
+    checkpoints: Iterable[RepositoryFile],
+) -> Iterable[Mapping[str, torch.Tensor]]:
+    if not has_safetensors:
+        raise ValueError(
+            "The `safetensors` library is required to load models from Safetensors checkpoints"
+        )
+
+    import safetensors.torch
+
+    for checkpoint in checkpoints:
+        # Prefer to load from a path when possible, since loading from a
+        # file object temporarily puts the checkpoint in memory twice.
+        if checkpoint.path is not None:
+            # Map to CPU first to support all devices.
+            state_dict = safetensors.torch.load_file(checkpoint.path, device="cpu")
+        else:
+            with checkpoint.open() as f:
+                # This has memory overhead, since Safetensors does not have
+                # support for loading from a file object and cannot use
+                # the bytes in-place.
+                checkpoint_bytes = f.read()
+                state_dict = safetensors.torch.load(checkpoint_bytes)
+        yield state_dict
+
+
+def _load_pytorch_state_dicts_from_checkpoints(
+    checkpoints: Iterable[RepositoryFile],
+) -> Iterable[Mapping[str, torch.Tensor]]:
+    for checkpoint in checkpoints:
+        with checkpoint.open() as f:
+            # Map to CPU first to support all devices.
+            state_dict = torch.load(
+                f, map_location=torch.device("cpu"), weights_only=True
+            )
+        yield state_dict
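A short usage sketch of the `loader` dispatch above; where the `RepositoryFile` objects come from (local paths, fsspec, the Hugging Face Hub) is outside the scope of this sketch:

def print_checkpoint_shapes(checkpoint_files):
    # checkpoint_files: Iterable[RepositoryFile], e.g. from a ModelRepository.
    loader = ModelCheckpointType.SAFE_TENSORS.loader
    for state_dict in loader(checkpoint_files):
        for name, tensor in state_dict.items():
            print(f"{name}: {tuple(tensor.shape)}")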
curated_transformers/util/serde.py → curated_transformers/util/serde/load.py
@@ -4,9 +4,9 @@
 import torch
 from torch.nn import Module, Parameter

-from ..repository._hf import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType
-from ..repository.file import RepositoryFile
-from .pytorch import ModuleIterator, apply_to_module
+from ...repository.file import RepositoryFile
+from ..pytorch import ModuleIterator, apply_to_module
+from .checkpoint import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType

 # Args: Parent module, module prefix, parameter name, tensor to convert, device.
 # Returns the new parameter.
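The comment above documents the callback type `TensorToParameterConverterT`: it receives the parent module, module prefix, parameter name, tensor to convert, and target device, and returns the new parameter. A hypothetical converter matching that signature, here simply downcasting weights to bfloat16:

from typing import Optional

import torch
from torch.nn import Module, Parameter

def to_bfloat16_parameter(
    module: Module,
    prefix: str,
    name: str,
    tensor: torch.Tensor,
    device: Optional[torch.device],
) -> Parameter:
    # Illustrative only: move the tensor to the target device and downcast.
    return Parameter(tensor.to(device=device, dtype=torch.bfloat16))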