diff --git a/curated_transformers/models/auto_model.py b/curated_transformers/models/auto_model.py index 7f37d74f..94de95a0 100644 --- a/curated_transformers/models/auto_model.py +++ b/curated_transformers/models/auto_model.py @@ -1,13 +1,14 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, Type, TypeVar +from typing import Dict, Generic, Optional, Type, TypeVar import torch from fsspec import AbstractFileSystem from ..layers.cache import KeyValueCache from ..quantization.bnb.config import BitsAndBytesConfig -from ..util.fsspec import get_config_model_type as get_config_model_type_fsspec -from ..util.hf import get_config_model_type +from ..repository.fsspec import FsspecArgs, FsspecRepository +from ..repository.hf_hub import HfHubRepository +from ..repository.repository import ModelRepository, Repository from .albert import ALBERTEncoder from .bert import BERTEncoder from .camembert import CamemBERTEncoder @@ -33,36 +34,12 @@ class AutoModel(ABC, Generic[ModelT]): _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {} - @classmethod - def _resolve_model_cls_fsspec( - cls, - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, - ) -> Type[FromHFHub]: - model_type = get_config_model_type_fsspec( - fs, model_path, fsspec_args=fsspec_args - ) - if model_type is None: - raise ValueError( - "The model type is not defined in the model configuration." - ) - module_cls = cls._hf_model_type_to_curated.get(model_type) - if module_cls is None: - raise ValueError( - f"Unsupported model type `{model_type}` for {cls.__name__}. " - f"Supported model types: {tuple(cls._hf_model_type_to_curated.keys())}" - ) - assert issubclass(module_cls, FromHFHub) - return module_cls - @classmethod def _resolve_model_cls( cls, - name: str, - revision: str, + repo: ModelRepository, ) -> Type[FromHFHub]: - model_type = get_config_model_type(name, revision) + model_type = repo.model_type() module_cls = cls._hf_model_type_to_curated.get(model_type) if module_cls is None: raise ValueError( @@ -73,36 +50,15 @@ def _resolve_model_cls( return module_cls @classmethod - def _instantiate_model_from_fsspec( - cls, - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]], - device: Optional[torch.device], - quantization_config: Optional[BitsAndBytesConfig], - ) -> FromHFHub: - module_cls = cls._resolve_model_cls_fsspec(fs, model_path) - module = module_cls.from_fsspec( - fs=fs, - model_path=model_path, - fsspec_args=fsspec_args, - device=device, - quantization_config=quantization_config, - ) - return module - - @classmethod - def _instantiate_model_from_hf_hub( + def _instantiate_model( cls, - name: str, - revision: str, + repo: Repository, device: Optional[torch.device], quantization_config: Optional[BitsAndBytesConfig], ) -> FromHFHub: - module_cls = cls._resolve_model_cls(name, revision) - module = module_cls.from_hf_hub( - name=name, - revision=revision, + module_cls = cls._resolve_model_cls(ModelRepository(repo)) + module = module_cls.from_repo( + repo=repo, device=device, quantization_config=quantization_config, ) @@ -114,7 +70,7 @@ def from_fsspec( *, fs: AbstractFileSystem, model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, + fsspec_args: Optional[FsspecArgs] = None, device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, ) -> ModelT: @@ -135,10 +91,17 @@ def from_fsspec( :returns: Module with the parameters loaded. 
""" - raise NotImplementedError + return cls.from_repo( + repo=FsspecRepository( + fs, + path=model_path, + fsspec_args=fsspec_args, + ), + device=device, + quantization_config=quantization_config, + ) @classmethod - @abstractmethod def from_hf_hub( cls, *, @@ -161,6 +124,34 @@ def from_hf_hub( :returns: Loaded model or generator. """ + return cls.from_repo( + repo=HfHubRepository(name=name, revision=revision), + device=device, + quantization_config=quantization_config, + ) + + @classmethod + @abstractmethod + def from_repo( + cls, + *, + repo: Repository, + device: Optional[torch.device] = None, + quantization_config: Optional[BitsAndBytesConfig] = None, + ) -> ModelT: + """ + Construct and load a model or a generator from a repository. + + :param repository: + The repository to load from. + :param device: + Device on which to initialize the model. + :param quantization_config: + Configuration for loading quantized weights. + :returns: + Loaded model or generator. + """ + raise NotImplementedError @classmethod @@ -181,8 +172,9 @@ def from_hf_hub_to_cache( :param revision: Model revision. """ - module_cls = cls._resolve_model_cls(name, revision) - module_cls.from_hf_hub_to_cache(name=name, revision=revision) + repo = ModelRepository(HfHubRepository(name=name, revision=revision)) + repo.model_config() + repo.model_checkpoints() class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]): @@ -199,33 +191,14 @@ class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]): } @classmethod - def from_fsspec( + def from_repo( cls, *, - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, + repo: Repository, device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, ) -> EncoderModule[TransformerConfig]: - encoder = cls._instantiate_model_from_fsspec( - fs, model_path, fsspec_args, device, quantization_config - ) - assert isinstance(encoder, EncoderModule) - return encoder - - @classmethod - def from_hf_hub( - cls, - *, - name: str, - revision: str = "main", - device: Optional[torch.device] = None, - quantization_config: Optional[BitsAndBytesConfig] = None, - ) -> EncoderModule[TransformerConfig]: - encoder = cls._instantiate_model_from_hf_hub( - name, revision, device, quantization_config - ) + encoder = cls._instantiate_model(repo, device, quantization_config) assert isinstance(encoder, EncoderModule) return encoder @@ -245,33 +218,14 @@ class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]): } @classmethod - def from_fsspec( - cls, - *, - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, - device: Optional[torch.device] = None, - quantization_config: Optional[BitsAndBytesConfig] = None, - ) -> DecoderModule[TransformerConfig, KeyValueCache]: - decoder = cls._instantiate_model_from_fsspec( - fs, model_path, fsspec_args, device, quantization_config - ) - assert isinstance(decoder, DecoderModule) - return decoder - - @classmethod - def from_hf_hub( + def from_repo( cls, *, - name: str, - revision: str = "main", + repo: Repository, device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, ) -> DecoderModule[TransformerConfig, KeyValueCache]: - decoder = cls._instantiate_model_from_hf_hub( - name, revision, device, quantization_config - ) + decoder = cls._instantiate_model(repo, device, quantization_config) assert isinstance(decoder, DecoderModule) return decoder @@ -291,32 +245,13 @@ class 
AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]): } @classmethod - def from_fsspec( + def from_repo( cls, *, - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, + repo: Repository, device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, ) -> CausalLMModule[TransformerConfig, KeyValueCache]: - causal_lm = cls._instantiate_model_from_fsspec( - fs, model_path, fsspec_args, device, quantization_config - ) - assert isinstance(causal_lm, CausalLMModule) - return causal_lm - - @classmethod - def from_hf_hub( - cls, - *, - name: str, - revision: str = "main", - device: Optional[torch.device] = None, - quantization_config: Optional[BitsAndBytesConfig] = None, - ) -> CausalLMModule[TransformerConfig, KeyValueCache]: - causal_lm = cls._instantiate_model_from_hf_hub( - name, revision, device, quantization_config - ) + causal_lm = cls._instantiate_model(repo, device, quantization_config) assert isinstance(causal_lm, CausalLMModule) return causal_lm diff --git a/curated_transformers/models/hf_hub.py b/curated_transformers/models/hf_hub.py index d3e5f823..986c4c4c 100644 --- a/curated_transformers/models/hf_hub.py +++ b/curated_transformers/models/hf_hub.py @@ -18,12 +18,10 @@ from ..quantization import prepare_module_for_quantization from ..quantization.bnb.config import BitsAndBytesConfig -from ..util.fsspec import ( - get_model_checkpoint_files as get_model_checkpoint_files_fsspec, -) -from ..util.fsspec import get_model_config as get_model_config_fsspec -from ..util.hf import get_model_checkpoint_files, get_model_config -from ..util.serde import ModelCheckpointType, ModelFile, load_model_from_checkpoints +from ..repository.fsspec import FsspecArgs, FsspecRepository +from ..repository.hf_hub import HfHubRepository +from ..repository.repository import ModelRepository, Repository +from ..util.serde import load_model_from_checkpoints from .module import TransformerModule # Only provided as typing.Self in Python 3.11+. @@ -94,8 +92,9 @@ def from_hf_hub_to_cache( :param revision: Model revision. """ - _ = get_model_config(name, revision) - _ = get_model_checkpoint_files(name, revision) + repo = ModelRepository(HfHubRepository(name=name, revision=revision)) + repo.model_config() + repo.model_checkpoints() @classmethod def from_fsspec( @@ -103,7 +102,7 @@ def from_fsspec( *, fs: AbstractFileSystem, model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, + fsspec_args: Optional[FsspecArgs] = None, device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, ) -> Self: @@ -124,13 +123,8 @@ def from_fsspec( :returns: Module with the parameters loaded. """ - return cls._create_and_load_model( - get_config=lambda: get_model_config_fsspec( - fs, model_path, fsspec_args=fsspec_args - ), - get_checkpoint_files=lambda: get_model_checkpoint_files_fsspec( - fs, model_path, fsspec_args=fsspec_args - ), + return cls.from_repo( + repo=FsspecRepository(fs, model_path, fsspec_args), device=device, quantization_config=quantization_config, ) @@ -158,9 +152,8 @@ def from_hf_hub( :returns: Module with the parameters loaded. """ - return cls._create_and_load_model( - get_config=lambda: get_model_config(name, revision), - get_checkpoint_files=lambda: get_model_checkpoint_files(name, revision), + return cls.from_repo( + repo=HfHubRepository(name=name, revision=revision), device=device, quantization_config=quantization_config, ) @@ -182,15 +175,27 @@ def to( ... 
@classmethod - def _create_and_load_model( + def from_repo( cls: Type[Self], *, - get_config: Callable[[], Dict[Any, str]], - get_checkpoint_files: Callable[[], Tuple[List[ModelFile], ModelCheckpointType]], + repo: Repository, device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, ) -> Self: - config = get_config() + """ + Construct and load a model from a repository. + + :param repository: + The repository to load from. + :param device: + Device on which to initialize the model. + :param quantization_config: + Configuration for loading quantized weights. + :returns: + Loaded model. + """ + model_repo = ModelRepository(repo) + config = model_repo.model_config() model = cls.from_hf_config(hf_config=config, device=torch.device("meta")) # Convert the model to the expected dtype. @@ -211,7 +216,7 @@ def _create_and_load_model( tensor2param = None # Download model and convert HF parameter names to ours. - checkpoint_filenames, checkpoint_type = get_checkpoint_files() + checkpoint_filenames, checkpoint_type = model_repo.model_checkpoints() load_model_from_checkpoints( model, # type:ignore filepaths=checkpoint_filenames, diff --git a/curated_transformers/repository/__init__.py b/curated_transformers/repository/__init__.py new file mode 100644 index 00000000..d5ca674e --- /dev/null +++ b/curated_transformers/repository/__init__.py @@ -0,0 +1,16 @@ +from .file import LocalFile, RepositoryFile +from .fsspec import FsspecArgs, FsspecFile, FsspecRepository +from .hf_hub import HfHubRepository +from .repository import ModelRepository, Repository, TokenizerRepository + +__all__ = [ + "FsspecArgs", + "FsspecFile", + "FsspecRepository", + "HfHubRepository", + "LocalFile", + "ModelRepository", + "Repository", + "RepositoryFile", + "TokenizerRepository", +] diff --git a/curated_transformers/repository/_hf.py b/curated_transformers/repository/_hf.py new file mode 100644 index 00000000..0b8003d7 --- /dev/null +++ b/curated_transformers/repository/_hf.py @@ -0,0 +1,109 @@ +from contextvars import ContextVar +from enum import Enum +from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Optional + +import torch + +from .._compat import has_safetensors +from ..repository.file import RepositoryFile + +if TYPE_CHECKING: + import safetensors + + +HF_MODEL_CONFIG = "config.json" +HF_MODEL_CHECKPOINT = "pytorch_model.bin" +HF_MODEL_CHECKPOINT_SAFETENSORS = "model.safetensors" +HF_MODEL_SHARDED_CHECKPOINT_INDEX = "pytorch_model.bin.index.json" +HF_MODEL_SHARDED_CHECKPOINT_INDEX_SAFETENSORS = "model.safetensors.index.json" +HF_MODEL_SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY = "weight_map" +HF_TOKENIZER_CONFIG = "tokenizer_config.json" +SPECIAL_TOKENS_MAP = "special_tokens_map.json" +TOKENIZER_JSON = "tokenizer.json" + + +class ModelCheckpointType(Enum): + """ + Types of model checkpoints supported by Curated Transformers. + """ + + #: PyTorch `checkpoint`_. + PYTORCH_STATE_DICT = 0 + + #: Hugging Face `Safetensors `_ checkpoint. 
+ SAFE_TENSORS = 1 + + @property + def loader( + self, + ) -> Callable[[Iterable[RepositoryFile]], Iterable[Mapping[str, torch.Tensor]]]: + checkpoint_type_to_loader = { + ModelCheckpointType.PYTORCH_STATE_DICT: _load_pytorch_state_dicts_from_checkpoints, + ModelCheckpointType.SAFE_TENSORS: _load_safetensor_state_dicts_from_checkpoints, + } + return checkpoint_type_to_loader[self] + + @property + def pretty_name(self) -> str: + if self == ModelCheckpointType.PYTORCH_STATE_DICT: + return "PyTorch StateDict" + elif self == ModelCheckpointType.SAFE_TENSORS: + return "SafeTensors" + else: + return "" + + +PRIMARY_CHECKPOINT_FILENAMES = { + ModelCheckpointType.PYTORCH_STATE_DICT: HF_MODEL_CHECKPOINT, + ModelCheckpointType.SAFE_TENSORS: HF_MODEL_CHECKPOINT_SAFETENSORS, +} +SHARDED_CHECKPOINT_INDEX_FILENAMES = { + ModelCheckpointType.PYTORCH_STATE_DICT: HF_MODEL_SHARDED_CHECKPOINT_INDEX, + ModelCheckpointType.SAFE_TENSORS: HF_MODEL_SHARDED_CHECKPOINT_INDEX_SAFETENSORS, +} +# Same for both checkpoint types. +SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY = HF_MODEL_SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY + + +# When `None`, behaviour is implementation-specific. +_MODEL_CHECKPOINT_TYPE: ContextVar[Optional[ModelCheckpointType]] = ContextVar( + "model_checkpoint_type", default=None +) + + +def _load_safetensor_state_dicts_from_checkpoints( + checkpoints: Iterable[RepositoryFile], +) -> Iterable[Mapping[str, torch.Tensor]]: + if not has_safetensors: + raise ValueError( + "The `safetensors` library is required to load models from Safetensors checkpoints" + ) + + import safetensors.torch + + for checkpoint in checkpoints: + # Prefer to load from a path when possible. Since loading from a file + # temporarily puts the checkpoint in memory twice. + if checkpoint.path is not None: + # Map to CPU first to support all devices. + state_dict = safetensors.torch.load_file(checkpoint.path, device="cpu") + else: + with checkpoint.open() as f: + # This has memory overhead, since Safetensors does not have + # support for loading from a file object and cannot use + # the bytes in-place. + checkpoint_bytes = f.read() + state_dict = safetensors.torch.load(checkpoint_bytes) + yield state_dict + + +def _load_pytorch_state_dicts_from_checkpoints( + checkpoints: Iterable[RepositoryFile], +) -> Iterable[Mapping[str, torch.Tensor]]: + for checkpoint in checkpoints: + with checkpoint.open() as f: + # Map to CPU first to support all devices. + state_dict = torch.load( + f, map_location=torch.device("cpu"), weights_only=True + ) + yield state_dict diff --git a/curated_transformers/repository/file.py b/curated_transformers/repository/file.py new file mode 100644 index 00000000..42ef4113 --- /dev/null +++ b/curated_transformers/repository/file.py @@ -0,0 +1,65 @@ +from abc import ABC, abstractmethod +from typing import IO, Optional + + +class RepositoryFile(ABC): + """ + A repository file. + + Repository files can be a local path or a remote path exposed as a + file-like object. This is a common base class for such different types + of repository files. + """ + + @abstractmethod + def open(self, mode: str = "rb", encoding: Optional[str] = None) -> IO: + """ + Get the file as a file-like object. + + :param mode: + Mode to open the file with (see Python ``open``). + :param encoding: + Encoding to use when the file is opened as text. + :returns: + An I/O stream. + :raises OSError: + When the file cannot be opened. + """ + ... + + @property + @abstractmethod + def path(self) -> Optional[str]: + """ + Get the file as a local path. 
+ + :returns: + The repository file. If the file is not available as a local + path, the value of this property is ``None``. In these cases + ``open`` can be used to get the file as a file-like object. + """ + ... + + +class LocalFile(RepositoryFile): + """ + Repository file on the local machine. + """ + + def __init__(self, path: str): + """ + Construct a local file representation. + + :param path: + The path of the file on the local filesystem. + """ + super().__init__() + self._path = path + + def open(self, mode: str = "rb", encoding: Optional[str] = None) -> IO: + # Raises OSError, so we don't have to do any rewrapping. + return open(self._path, mode=mode, encoding=encoding) + + @property + def path(self) -> Optional[str]: + return self._path diff --git a/curated_transformers/repository/fsspec.py b/curated_transformers/repository/fsspec.py new file mode 100644 index 00000000..861b332b --- /dev/null +++ b/curated_transformers/repository/fsspec.py @@ -0,0 +1,111 @@ +from dataclasses import dataclass +from typing import IO, Any, Dict, Optional + +from fsspec import AbstractFileSystem + +from .repository import Repository, RepositoryFile + + +@dataclass +class FsspecArgs: + """ + Convenience wrapper for additional fsspec arguments. + """ + + kwargs: Dict[str, Any] + + def __init__(self, **kwargs): + """ + Keyword arguments are passed through to the fsspec implementation. + Construction may raise in the future when reserved arguments like + ``mode`` or ``encoding`` are used. + """ + # Future improvement: raise on args that are used by fsspec methods, + # e.g. `mode` or `encoding`. + self.kwargs = kwargs + + +class FsspecFile(RepositoryFile): + """ + Repository file on an `fsspec`_ filesystem. + + .. _fsspec: https://filesystem-spec.readthedocs.io/en/latest/ + """ + + def __init__( + self, + fs: AbstractFileSystem, + path: str, + fsspec_args: Optional[FsspecArgs] = None, + ): + """ + Construct an fsspec file representation. + + :param fs: + The filesystem. + :param path: + The path of the file on the filesystem. + :param fsspec_args: + Implementation-specific arguments to pass to fsspec filesystem + operations. + """ + super().__init__() + self._fs = fs + self._path = path + self._fsspec_args = FsspecArgs() if fsspec_args is None else fsspec_args + + def open(self, mode: str = "rb", encoding: Optional[str] = None) -> IO: + try: + return self._fs.open( + self._path, mode=mode, encoding=encoding, **self._fsspec_args.kwargs + ) + except Exception as e: + raise OSError( + f"Cannot open fsspec path {self._fs.unstrip_protocol(self._path)}" + ) from e + + @property + def path(self) -> Optional[str]: + return None + + +class FsspecRepository(Repository): + """ + Repository using a filesystem that uses the `fsspec`_ interface. + + .. _fsspec: https://filesystem-spec.readthedocs.io/en/latest/ + + :param fs: + The filesystem. + :param path: + The path of the repository within the filesystem. + :param fsspec_args: + Additional arguments that should be passed to the fsspec + implementation.
+ """ + + def __init__( + self, + fs: AbstractFileSystem, + path: str, + fsspec_args: Optional[FsspecArgs] = None, + ): + super().__init__() + self.fs = fs + self.repo_path = path + self.fsspec_args = FsspecArgs() if fsspec_args is None else fsspec_args + + def file(self, path: str) -> RepositoryFile: + full_path = f"{self.repo_path}/{path}" + if not self.fs.exists(full_path, **self.fsspec_args.kwargs): + raise FileNotFoundError(f"Cannot find file in repository: {path}") + return FsspecFile(self.fs, full_path, self.fsspec_args) + + def pretty_path(self, path: Optional[str] = None) -> str: + if not path: + return self._protocol + return f"{self._protocol}/{path}" + + @property + def _protocol(self) -> str: + return self.fs.unstrip_protocol(self.repo_path) diff --git a/curated_transformers/repository/hf_hub.py b/curated_transformers/repository/hf_hub.py new file mode 100644 index 00000000..8407c6b4 --- /dev/null +++ b/curated_transformers/repository/hf_hub.py @@ -0,0 +1,92 @@ +import warnings +from typing import Optional + +import huggingface_hub +from huggingface_hub.utils import ( + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from requests import HTTPError, ReadTimeout # type: ignore + +from ..repository.file import LocalFile, RepositoryFile +from .repository import Repository + + +class HfHubRepository(Repository): + """ + Hugging Face Hub repository. + """ + + def __init__(self, name: str, *, revision: str = "main"): + """ + :param name: + Name of the repository on Hugging Face Hub. + :param revision: + Source repository revision. Can either be a branch name + or a SHA hash of a commit. + """ + super().__init__() + self.name = name + self.revision = revision + + def file(self, path: str) -> RepositoryFile: + try: + return LocalFile( + path=hf_hub_download( + repo_id=self.name, filename=path, revision=self.revision + ) + ) + except ( + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, + ) as e: + raise FileNotFoundError(f"File not found: {self.pretty_path(path)}") from e + except Exception as e: + raise OSError(f"File could not be opened: {self.pretty_path(path)}") from e + + def pretty_path(self, path: Optional[str] = None) -> str: + if not path: + return f"{self.name} (revision: {self.revision})" + return f"{self.name}/{path} (revision: {self.revision})" + + +def hf_hub_download(repo_id: str, filename: str, revision: str) -> str: + """ + Resolve the provided filename and repository to a local file path. If the file + is not present in the cache, it will be downloaded from the Hugging Face Hub. + + :param repo_id: + Identifier of the source repository on Hugging Face Hub. + :param filename: + Name of the file in the source repository to download. + :param revision: + Source repository revision. Can either be a branch name + or a SHA hash of a commit. + :returns: + Resolved absolute filepath. + """ + + # The HF Hub library's `hf_hub_download` function will always attempt to connect to the + # remote repo and fetch its metadata even if it's locally cached (in order to update the + # out-of-date artifacts in the cache). This can occasionally lead to `HTTPError/ReadTimeout`s if the + # remote host is unreachable. Instead of failing loudly, we'll add a fallback that checks + # the local cache for the artifacts and uses them if found. + try: + resolved = huggingface_hub.hf_hub_download( + repo_id=repo_id, filename=filename, revision=revision + ) + except (HTTPError, ReadTimeout) as e: + # Attempt to check the cache. 
+ resolved = huggingface_hub.try_to_load_from_cache( + repo_id=repo_id, filename=filename, revision=revision + ) + if resolved is None or resolved is huggingface_hub._CACHED_NO_EXIST: + # Not found, rethrow. + raise e + else: + warnings.warn( + f"Couldn't reach Hugging Face Hub; using cached artifact for '{repo_id}@{revision}:{filename}'" + ) + return resolved diff --git a/curated_transformers/repository/repository.py b/curated_transformers/repository/repository.py new file mode 100644 index 00000000..5dbac924 --- /dev/null +++ b/curated_transformers/repository/repository.py @@ -0,0 +1,298 @@ +import json +from abc import ABC, abstractmethod +from json import JSONDecodeError +from typing import Any, Dict, List, Optional, Tuple + +from .._compat import has_safetensors +from ._hf import ( + _MODEL_CHECKPOINT_TYPE, + HF_MODEL_CONFIG, + HF_TOKENIZER_CONFIG, + PRIMARY_CHECKPOINT_FILENAMES, + SHARDED_CHECKPOINT_INDEX_FILENAMES, + SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY, + SPECIAL_TOKENS_MAP, + TOKENIZER_JSON, + ModelCheckpointType, +) +from .file import RepositoryFile + + +class Repository(ABC): + """ + A repository that contains a model or tokenizer. + """ + + @abstractmethod + def file(self, path: str) -> RepositoryFile: + """ + Get a repository file. + + :param path: + The path of the file within the repository. + :returns: + The file. + :raises FileNotFoundError: + When the file cannot be found. + :raises OSError: + When the file cannot be opened. + """ + ... + + def json_file(self, path: str) -> Dict[str, Any]: + """ + Get and parse a JSON file. + + :param path: + The path of the file within the repository. + :returns: + The deserialized JSON. + :raises FileNotFoundError: + When the file cannot be found. + :raises OSError: + When the file cannot be opened. + :raises json.JSONDecodeError: + When the JSON cannot be decoded. + """ + f = self.file(path) + with f.open("r", encoding="utf-8") as f: + return json.load(f) + + @abstractmethod + def pretty_path(self, path: Optional[str] = None) -> str: + """ + Get a user-consumable path representation (e.g. for error messages). + + :param path: + The path of a file within the repository. The repository path + will be returned if ``path`` is falsy. + :returns: + The path representation. + """ + ... + + +class ModelRepository(Repository): + """ + Repository wrapper that exposes some methods that are useful for working + with repositories that contain a model. + """ + + # Cached model configuration. + _model_config: Optional[Dict[str, Any]] + + def __init__(self, repo: Repository): + """ + Construct a model repository wrapper. + + :param repo: + The repository to wrap. + """ + super().__init__() + self.repo = repo + self._model_config = None + + def file(self, filename: str) -> RepositoryFile: + return self.repo.file(filename) + + def json_file(self, path: str) -> Dict[str, Any]: + return self.repo.json_file(path) + + def model_checkpoints(self) -> Tuple[List[RepositoryFile], ModelCheckpointType]: + """ + Retrieve the model checkpoints and checkpoint type. + + :returns: + A tuple of the model checkpoints and the checkpoint type. + :raises OSError: + When the checkpoint paths files could not be retrieved. + """ + + def get_checkpoint_paths( + checkpoint_type: ModelCheckpointType, + ) -> List[RepositoryFile]: + # Attempt to download a non-sharded checkpoint first. + try: + return [self.file(PRIMARY_CHECKPOINT_FILENAMES[checkpoint_type])] + except: + pass + + # Try sharded checkpoint. 
+ index_filename = SHARDED_CHECKPOINT_INDEX_FILENAMES[checkpoint_type] + + try: + index = self.json_file(index_filename) + except (FileNotFoundError, JSONDecodeError, OSError) as e: + raise OSError( + f"Index file for sharded checkpoint type {checkpoint_type.pretty_name} " + f"could not be loaded `{self.pretty_path(index_filename)}`" + ) from e + + weight_map = index.get(SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY) + if not isinstance(weight_map, dict): + raise OSError( + f"Invalid index file in sharded {checkpoint_type.pretty_name} " + f"checkpoint for model at `{self.pretty_path()}`" + ) + + checkpoint_paths = [] + for filename in sorted(set(weight_map.values())): + try: + checkpoint_paths.append(self.file(filename)) + except FileNotFoundError: + raise OSError( + f"File for sharded checkpoint type {checkpoint_type.pretty_name} " + f"could not be found at `{self.pretty_path(index_filename)}`" + ) + + return checkpoint_paths + + checkpoint_type = _MODEL_CHECKPOINT_TYPE.get() + checkpoint_paths: Optional[List[RepositoryFile]] = None + + if checkpoint_type is None: + # Precedence: Safetensors > PyTorch + if has_safetensors: + try: + checkpoint_type = ModelCheckpointType.SAFE_TENSORS + checkpoint_paths = get_checkpoint_paths(checkpoint_type) + except OSError: + pass + if checkpoint_paths is None: + checkpoint_type = ModelCheckpointType.PYTORCH_STATE_DICT + checkpoint_paths = get_checkpoint_paths(checkpoint_type) + else: + checkpoint_paths = get_checkpoint_paths(checkpoint_type) + + assert checkpoint_paths is not None + assert checkpoint_type is not None + return checkpoint_paths, checkpoint_type + + def model_config(self) -> Dict[str, Any]: + """ + Get the model configuration. The result is cached to speed up + subsequent lookups. + + :returns: + The model configuration. + :raises OSError: + When the model config cannot be opened. + """ + if self._model_config is None: + try: + self._model_config = self.json_file( + path=HF_MODEL_CONFIG, + ) + except (FileNotFoundError, JSONDecodeError, OSError) as e: + raise OSError( + "Cannot load config for the model at " + f"`{self.pretty_path(HF_MODEL_CONFIG)}`" + ) from e + + return self._model_config + + def model_type(self) -> str: + """ + Get the model type. + + :returns: + The model type. + :raises OSError: + When the model config cannot be opened. + """ + return self.model_config()["model_type"] + + def pretty_path(self, path: Optional[str] = None) -> str: + return self.repo.pretty_path(path) + + +class TokenizerRepository(Repository): + """ + Repository wrapper that exposes some methods that are useful for working + with repositories that contain a tokenizer. + """ + + _tokenizer_config: Optional[Dict[str, Any]] + + def __init__(self, repo: Repository): + """ + Construct a tokenizer repository wrapper. + + :param repo: + The repository to wrap. + """ + super().__init__() + self.repo = repo + self._tokenizer_config = None + + def file(self, path: str) -> RepositoryFile: + return self.repo.file(path) + + def json_file(self, path: str) -> Dict[str, Any]: + return self.repo.json_file(path) + + def model_type(self) -> str: + """ + Get the model type. + + :returns: + The model type. + :raises OSError: + When the model config cannot be opened. + """ + model_config = self.json_file(HF_MODEL_CONFIG) + return model_config["model_type"] + + def pretty_path(self, path: Optional[str] = None) -> str: + return self.repo.pretty_path(path) + + def special_tokens_map(self) -> Dict[str, Any]: + """ + Return the tokenizer's special tokens map. 
+ + :returns: + The special tokens map. + :raises OSError: + When the special tokens map cannot be opened. + """ + try: + return self.repo.json_file( + path=SPECIAL_TOKENS_MAP, + ) + except (FileNotFoundError, JSONDecodeError, OSError) as e: + raise OSError( + "Could not load special tokens map for the tokenizer at " + f"`{self.repo.pretty_path(SPECIAL_TOKENS_MAP)}`" + ) from e + + def tokenizer_config(self) -> Dict[str, Any]: + """ + Return the model's tokenizer configuration. The result is cached to + speed up subsequent lookups. + + :returns: + The tokenizer configuration. + :raises OSError: + When the tokenizer config cannot be opened. + """ + if self._tokenizer_config is None: + try: + self._tokenizer_config = self.repo.json_file( + path=HF_TOKENIZER_CONFIG, + ) + except (FileNotFoundError, JSONDecodeError, OSError) as e: + raise OSError( + "Couldn't find a valid config for the tokenizer at " + f"`{self.repo.pretty_path(HF_TOKENIZER_CONFIG)}`" + ) from e + + return self._tokenizer_config + + def tokenizer_json(self) -> RepositoryFile: + """ + Return the HF tokenizers' ``tokenizer.json``. + + :returns: + The tokenizer file. + """ + return self.repo.file(TOKENIZER_JSON) diff --git a/curated_transformers/tests/models/test_auto_models.py b/curated_transformers/tests/models/test_auto_models.py index 0a6c3fe8..2d1695d3 100644 --- a/curated_transformers/tests/models/test_auto_models.py +++ b/curated_transformers/tests/models/test_auto_models.py @@ -21,6 +21,7 @@ from curated_transformers.models.llama.decoder import LlamaDecoder from curated_transformers.models.mpt.causal_lm import MPTCausalLM from curated_transformers.models.mpt.decoder import MPTDecoder +from curated_transformers.repository.fsspec import FsspecArgs @pytest.fixture @@ -49,7 +50,7 @@ def test_auto_encoder_fsspec(model_encoder_map): # The default revision is 'main', but we pass it anyway to test # that the function accepts fsspec_args. encoder = AutoEncoder.from_fsspec( - fs=HfFileSystem(), model_path=name, fsspec_args={"revision": "main"} + fs=HfFileSystem(), model_path=name, fsspec_args=FsspecArgs(revision="main") ) assert isinstance(encoder, encoder_cls) diff --git a/curated_transformers/tests/models/test_hf_hub.py b/curated_transformers/tests/models/test_hf_hub.py index 0a33067a..3d222ecf 100644 --- a/curated_transformers/tests/models/test_hf_hub.py +++ b/curated_transformers/tests/models/test_hf_hub.py @@ -4,7 +4,8 @@ from huggingface_hub import _CACHED_NO_EXIST, try_to_load_from_cache from curated_transformers.models.bert.encoder import BERTEncoder -from curated_transformers.util.hf import get_model_checkpoint_files +from curated_transformers.repository.hf_hub import HfHubRepository +from curated_transformers.repository.repository import ModelRepository from curated_transformers.util.serde import ( ModelCheckpointType, _use_model_checkpoint_type, @@ -54,10 +55,11 @@ def test_checkpoint_type_without_safetensors(): # By default, we expect the torch checkpoint to be loaded # even if the safetensor checkpoints are present # (as long as the library is not installed).
- ckp_paths, ckp_type = get_model_checkpoint_files( - "explosion-testing/safetensors-test", revision="main" - ) + ckp_paths, ckp_type = ModelRepository( + HfHubRepository("explosion-testing/safetensors-test", revision="main") + ).model_checkpoints() assert len(ckp_paths) == 1 + assert ckp_paths[0].path is not None assert Path(ckp_paths[0].path).suffix == ".bin" assert ckp_type == ModelCheckpointType.PYTORCH_STATE_DICT @@ -70,9 +72,10 @@ def test_checkpoint_type_without_safetensors(): def test_checkpoint_type_with_safetensors(): # Since the safetensors library is installed, we should be # loading from those checkpoints. - ckp_paths, ckp_type = get_model_checkpoint_files( - "explosion-testing/safetensors-test", revision="main" + repo = ModelRepository( + HfHubRepository("explosion-testing/safetensors-test", revision="main") ) + ckp_paths, ckp_type = repo.model_checkpoints() assert len(ckp_paths) == 1 assert Path(ckp_paths[0].path).suffix == ".safetensors" assert ckp_type == ModelCheckpointType.SAFE_TENSORS @@ -83,9 +86,12 @@ def test_checkpoint_type_with_safetensors(): @pytest.mark.skipif(not has_safetensors, reason="requires huggingface safetensors") def test_forced_checkpoint_type(): with _use_model_checkpoint_type(ModelCheckpointType.PYTORCH_STATE_DICT): - ckp_paths, ckp_type = get_model_checkpoint_files( - "explosion-testing/safetensors-sharded-test", revision="main" + repo = ModelRepository( + HfHubRepository( + "explosion-testing/safetensors-sharded-test", revision="main" + ) ) + ckp_paths, ckp_type = repo.model_checkpoints() assert len(ckp_paths) == 3 assert all(Path(p.path).suffix == ".bin" for p in ckp_paths) assert ckp_type == ModelCheckpointType.PYTORCH_STATE_DICT @@ -93,9 +99,7 @@ def test_forced_checkpoint_type(): encoder = BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test") with _use_model_checkpoint_type(ModelCheckpointType.SAFE_TENSORS): - ckp_paths, ckp_type = get_model_checkpoint_files( - "explosion-testing/safetensors-sharded-test", revision="main" - ) + ckp_paths, ckp_type = repo.model_checkpoints() assert len(ckp_paths) == 3 assert all(Path(p.path).suffix == ".safetensors" for p in ckp_paths) assert ckp_type == ModelCheckpointType.SAFE_TENSORS diff --git a/curated_transformers/tests/tokenizers/legacy/test_camembert_tokenizer.py b/curated_transformers/tests/tokenizers/legacy/test_camembert_tokenizer.py index 4181a899..8e08b06f 100644 --- a/curated_transformers/tests/tokenizers/legacy/test_camembert_tokenizer.py +++ b/curated_transformers/tests/tokenizers/legacy/test_camembert_tokenizer.py @@ -1,11 +1,11 @@ import pytest import torch +from curated_transformers.repository.file import LocalFile from curated_transformers.tokenizers import PiecesWithIds from curated_transformers.tokenizers.legacy.camembert_tokenizer import ( CamemBERTTokenizer, ) -from curated_transformers.util.serde import LocalModelFile from ...compat import has_hf_transformers from ...utils import torch_assertclose @@ -15,7 +15,7 @@ @pytest.fixture def toy_tokenizer(test_dir): return CamemBERTTokenizer.from_files( - model_file=LocalModelFile(path=test_dir / "toy.model"), + model_file=LocalFile(path=test_dir / "toy.model"), ) diff --git a/curated_transformers/tests/tokenizers/legacy/test_xlmr_tokenizer.py b/curated_transformers/tests/tokenizers/legacy/test_xlmr_tokenizer.py index 6a3de92e..68dfee62 100644 --- a/curated_transformers/tests/tokenizers/legacy/test_xlmr_tokenizer.py +++ b/curated_transformers/tests/tokenizers/legacy/test_xlmr_tokenizer.py @@ -1,9 +1,9 @@ import pytest import torch 
+from curated_transformers.repository.file import LocalFile from curated_transformers.tokenizers import PiecesWithIds from curated_transformers.tokenizers.legacy.xlmr_tokenizer import XLMRTokenizer -from curated_transformers.util.serde import LocalModelFile from ...compat import has_hf_transformers from ...utils import torch_assertclose @@ -13,7 +13,7 @@ @pytest.fixture def toy_tokenizer(test_dir): return XLMRTokenizer.from_files( - model_file=LocalModelFile(path=test_dir / "toy.model"), + model_file=LocalFile(path=test_dir / "toy.model"), ) diff --git a/curated_transformers/tests/tokenizers/test_auto_tokenizer.py b/curated_transformers/tests/tokenizers/test_auto_tokenizer.py index 77a70152..ac297b49 100644 --- a/curated_transformers/tests/tokenizers/test_auto_tokenizer.py +++ b/curated_transformers/tests/tokenizers/test_auto_tokenizer.py @@ -1,6 +1,7 @@ import pytest from huggingface_hub import HfFileSystem +from curated_transformers.repository.fsspec import FsspecArgs from curated_transformers.tokenizers import AutoTokenizer _MODELS = [ @@ -24,7 +25,7 @@ def test_auto_tokenizer(model_revision): def test_auto_tokenizer_fsspec(model_revision): name, revision = model_revision AutoTokenizer.from_fsspec( - fs=HfFileSystem(), model_path=name, fsspec_args={"revision": revision} + fs=HfFileSystem(), model_path=name, fsspec_args=FsspecArgs(revision=revision) ) AutoTokenizer.from_hf_hub(name=name, revision=revision) diff --git a/curated_transformers/tests/tokenizers/test_tokenizer.py b/curated_transformers/tests/tokenizers/test_tokenizer.py index 4e707148..5ce938ed 100644 --- a/curated_transformers/tests/tokenizers/test_tokenizer.py +++ b/curated_transformers/tests/tokenizers/test_tokenizer.py @@ -6,7 +6,6 @@ from curated_transformers.tokenizers import Tokenizer from curated_transformers.tokenizers.chunks import InputChunks, TextChunk -from curated_transformers.util.hf import TOKENIZER_JSON from ..compat import has_hf_transformers, transformers from ..utils import torch_assertclose @@ -91,7 +90,7 @@ def test_from_dir(toy_tokenizer, toy_tokenizer_path, sample_texts): @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") def test_from_json(toy_tokenizer_path, sample_texts): - with open(toy_tokenizer_path / TOKENIZER_JSON, encoding="utf-8") as f: + with open(toy_tokenizer_path / "tokenizer.json", encoding="utf-8") as f: tokenizer = Tokenizer.from_json(f.read()) hf_tokenizer = transformers.RobertaTokenizerFast.from_pretrained( str(toy_tokenizer_path) diff --git a/curated_transformers/tokenizers/auto_tokenizer.py b/curated_transformers/tokenizers/auto_tokenizer.py index b8361996..a42f38d0 100644 --- a/curated_transformers/tokenizers/auto_tokenizer.py +++ b/curated_transformers/tokenizers/auto_tokenizer.py @@ -1,12 +1,11 @@ from typing import Any, Dict, Optional, Type, cast from fsspec import AbstractFileSystem -from huggingface_hub.utils import EntryNotFoundError -from ..util.fsspec import get_config_model_type as get_model_type_fsspec -from ..util.fsspec import get_tokenizer_config as get_tokenizer_config_fsspec -from ..util.hf import TOKENIZER_JSON, get_config_model_type, get_file_metadata -from .hf_hub import FromHFHub, get_tokenizer_config +from ..repository.fsspec import FsspecArgs, FsspecRepository +from ..repository.hf_hub import HfHubRepository +from ..repository.repository import Repository, TokenizerRepository +from .hf_hub import FromHFHub from .legacy.bert_tokenizer import BERTTokenizer from .legacy.camembert_tokenizer import CamemBERTTokenizer from 
.legacy.llama_tokenizer import LlamaTokenizer @@ -62,7 +61,9 @@ def from_hf_hub_to_cache( :param revision: Model revision. """ - tokenizer_cls = _resolve_tokenizer_class_hf_hub(name, revision) + tokenizer_cls = _resolve_tokenizer_class( + TokenizerRepository(HfHubRepository(name, revision=revision)) + ) tokenizer_cls.from_hf_hub_to_cache(name=name, revision=revision) @classmethod @@ -71,7 +72,7 @@ def from_fsspec( *, fs: AbstractFileSystem, model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, + fsspec_args: Optional[FsspecArgs] = None, ) -> TokenizerBase: """ Construct a tokenizer and load its parameters from an fsspec filesystem. @@ -86,15 +87,22 @@ def from_fsspec( :returns: The tokenizer. """ - tokenizer_cls = _resolve_tokenizer_class_fsspec( - fs=fs, model_path=model_path, fsspec_args=fsspec_args - ) - # This cast is safe, because we only return tokenizers. + return cls.from_repo(FsspecRepository(fs, model_path, fsspec_args)) + + @classmethod + def from_repo(cls, repo: Repository) -> TokenizerBase: + """ + Construct and load a tokenizer from a repository. + + :param repository: + The repository to load from. + :returns: + Loaded tokenizer. + """ + tokenizer_cls = _resolve_tokenizer_class(TokenizerRepository(repo)) return cast( TokenizerBase, - tokenizer_cls.from_fsspec( - fs=fs, model_path=model_path, fsspec_args=fsspec_args - ), + tokenizer_cls.from_repo(repo), ) @classmethod @@ -109,12 +117,7 @@ def from_hf_hub(cls, *, name: str, revision: str = "main") -> TokenizerBase: :returns: The tokenizer. """ - - tokenizer_cls = _resolve_tokenizer_class_hf_hub(name, revision) - # This cast is safe, because we only return tokenizers. - return cast( - TokenizerBase, tokenizer_cls.from_hf_hub(name=name, revision=revision) - ) + return cls.from_repo(HfHubRepository(name, revision=revision)) def _get_tokenizer_class_from_config( @@ -133,65 +136,32 @@ def _get_tokenizer_class_from_config( return HF_TOKENIZER_MAPPING.get(tokenizer_config.get("tokenizer_class", None), None) -def _resolve_tokenizer_class_fsspec( - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, +def _resolve_tokenizer_class( + repo: TokenizerRepository, ) -> Type[FromHFHub]: - fsspec_args = {} if fsspec_args is None else fsspec_args - tokenizer_cls: Optional[Type[FromHFHub]] = None - if fs.exists(f"{model_path}/{TOKENIZER_JSON}", **fsspec_args): - return Tokenizer - - if tokenizer_cls is None: - tokenizer_config = get_tokenizer_config_fsspec( - fs=fs, model_path=model_path, fsspec_args=fsspec_args - ) - if tokenizer_config is not None: - tokenizer_cls = _get_tokenizer_class_from_config(tokenizer_config) - - if tokenizer_cls is None: - model_type = get_model_type_fsspec( - fs=fs, model_path=model_path, fsspec_args=fsspec_args - ) - if model_type is not None: - tokenizer_cls = HF_MODEL_MAPPING.get(model_type, None) - - if tokenizer_cls is None: - raise ValueError(f"Cannot infer tokenizer for model at path: {model_path}") - - return tokenizer_cls - - -def _resolve_tokenizer_class_hf_hub(name: str, revision: str) -> Type[FromHFHub]: - tokenizer_cls: Optional[Type[FromHFHub]] = None + cls: Optional[Type[FromHFHub]] = None try: - # We will try to fetch metadata to avoid potentially downloading - # the tokenizer file twice (here and Tokenizer.from_hf_hub). 
- get_file_metadata(filename=TOKENIZER_JSON, name=name, revision=revision) - except EntryNotFoundError: + repo.tokenizer_json() + cls = Tokenizer + except: pass - else: - tokenizer_cls = Tokenizer - if tokenizer_cls is None: + if cls is None: try: - tokenizer_config = get_tokenizer_config(name=name, revision=revision) - except EntryNotFoundError: + tokenizer_config = repo.tokenizer_config() + cls = _get_tokenizer_class_from_config(tokenizer_config) + except: pass - else: - tokenizer_cls = _get_tokenizer_class_from_config(tokenizer_config) - if tokenizer_cls is None: + if cls is None: try: - model_type = get_config_model_type(name=name, revision=revision) - except EntryNotFoundError: + model_type = repo.model_type() + print(model_type) + cls = HF_MODEL_MAPPING.get(model_type) + except: pass - else: - tokenizer_cls = HF_MODEL_MAPPING.get(model_type, None) - if tokenizer_cls is None: - raise ValueError( - f"Cannot infer tokenizer for repository '{name}' with revision '{revision}'" - ) - return tokenizer_cls + if cls is None: + raise ValueError(f"Cannot infer tokenizer for repo: {repo.pretty_path()}") + + return cls diff --git a/curated_transformers/tokenizers/hf_hub.py b/curated_transformers/tokenizers/hf_hub.py index 2e719792..790cf8a9 100644 --- a/curated_transformers/tokenizers/hf_hub.py +++ b/curated_transformers/tokenizers/hf_hub.py @@ -4,9 +4,10 @@ from fsspec import AbstractFileSystem from huggingface_hub.utils import EntryNotFoundError -from ..util.fsspec import get_tokenizer_config as get_tokenizer_config_fsspec -from ..util.hf import get_tokenizer_config, hf_hub_download -from ..util.serde import FsspecModelFile, LocalModelFile, ModelFile +from ..repository.file import RepositoryFile +from ..repository.fsspec import FsspecArgs, FsspecRepository +from ..repository.hf_hub import HfHubRepository +from ..repository.repository import Repository, TokenizerRepository SelfFromHFHub = TypeVar("SelfFromHFHub", bound="FromHFHub") @@ -41,13 +42,12 @@ def from_hf_hub_to_cache( raise NotImplementedError @classmethod - @abstractmethod def from_fsspec( cls: Type[SelfFromHFHub], *, fs: AbstractFileSystem, model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, + fsspec_args: Optional[FsspecArgs] = None, ) -> SelfFromHFHub: """ Construct a tokenizer and load its parameters from an fsspec filesystem. @@ -62,10 +62,11 @@ def from_fsspec( :returns: The tokenizer. """ - raise NotImplementedError + return cls.from_repo( + repo=FsspecRepository(fs, model_path, fsspec_args), + ) @classmethod - @abstractmethod def from_hf_hub( cls: Type[SelfFromHFHub], *, name: str, revision: str = "main" ) -> SelfFromHFHub: @@ -79,7 +80,25 @@ def from_hf_hub( :returns: The tokenizer. """ - raise NotImplementedError + return cls.from_repo( + repo=HfHubRepository(name=name, revision=revision), + ) + + @classmethod + @abstractmethod + def from_repo( + cls: Type[SelfFromHFHub], + repo: Repository, + ) -> SelfFromHFHub: + """ + Construct and load a tokenizer from a repository. + + :param repository: + The repository to load from. + :returns: + Loaded tokenizer. + """ + ... 
SelfLegacyFromHFHub = TypeVar("SelfLegacyFromHFHub", bound="LegacyFromHFHub") @@ -103,7 +122,7 @@ class LegacyFromHFHub(FromHFHub): def _load_from_vocab_files( cls: Type[SelfLegacyFromHFHub], *, - vocab_files: Mapping[str, ModelFile], + vocab_files: Mapping[str, RepositoryFile], tokenizer_config: Optional[Dict[str, Any]], ) -> SelfLegacyFromHFHub: """ @@ -126,50 +145,28 @@ def from_hf_hub_to_cache( name: str, revision: str = "main", ): + repo = TokenizerRepository(HfHubRepository(name, revision=revision)) for _, filename in cls.vocab_files.items(): - _ = hf_hub_download(repo_id=name, filename=filename, revision=revision) + _ = repo.file(filename) try: - _ = get_tokenizer_config(name=name, revision=revision) + _ = repo.tokenizer_config() except EntryNotFoundError: pass @classmethod - def from_fsspec( + def from_repo( cls: Type[SelfLegacyFromHFHub], - *, - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, - ) -> SelfLegacyFromHFHub: - vocab_files = {} - for vocab_file, filename in cls.vocab_files.items(): - vocab_files[vocab_file] = FsspecModelFile( - fs, f"{model_path}/{filename}", fsspec_args - ) - - tokenizer_config = get_tokenizer_config_fsspec( - fs=fs, model_path=model_path, fsspec_args=fsspec_args - ) - - return cls._load_from_vocab_files( - vocab_files=vocab_files, tokenizer_config=tokenizer_config - ) - - @classmethod - def from_hf_hub( - cls: Type[SelfLegacyFromHFHub], *, name: str, revision: str = "main" + repo: Repository, ) -> SelfLegacyFromHFHub: + repo = TokenizerRepository(repo) vocab_files = {} for vocab_file, filename in cls.vocab_files.items(): - vocab_files[vocab_file] = LocalModelFile( - hf_hub_download(repo_id=name, filename=filename, revision=revision) - ) + vocab_files[vocab_file] = repo.file(filename) - # Try to get the tokenizer configuration. 
try: - tokenizer_config = get_tokenizer_config(name=name, revision=revision) - except EntryNotFoundError: + tokenizer_config = repo.tokenizer_config() + except OSError: tokenizer_config = None return cls._load_from_vocab_files( diff --git a/curated_transformers/tokenizers/legacy/bert_tokenizer.py b/curated_transformers/tokenizers/legacy/bert_tokenizer.py index 86c7accb..98aeb4d1 100644 --- a/curated_transformers/tokenizers/legacy/bert_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/bert_tokenizer.py @@ -3,7 +3,7 @@ from curated_tokenizers import WordPieceProcessor -from ...util.serde import ModelFile +from ...repository.file import RepositoryFile from .._hf_compat import clean_up_decoded_string_like_hf, tokenize_chinese_chars_bert from ..chunks import ( InputChunks, @@ -260,7 +260,7 @@ def __init__( def from_files( cls: Type[Self], *, - vocab_file: ModelFile, + vocab_file: RepositoryFile, bos_piece: str = "[CLS]", eos_piece: str = "[SEP]", unk_piece: str = "[UNK]", @@ -311,7 +311,7 @@ def eos_piece(self) -> Optional[str]: def _load_from_vocab_files( cls: Type[Self], *, - vocab_files: Mapping[str, ModelFile], + vocab_files: Mapping[str, RepositoryFile], tokenizer_config: Optional[Dict[str, Any]], ) -> Self: extra_kwargs = {} diff --git a/curated_transformers/tokenizers/legacy/camembert_tokenizer.py b/curated_transformers/tokenizers/legacy/camembert_tokenizer.py index d76f0033..b54b0642 100644 --- a/curated_transformers/tokenizers/legacy/camembert_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/camembert_tokenizer.py @@ -2,7 +2,7 @@ from curated_tokenizers import SentencePieceProcessor -from ...util.serde import ModelFile +from ...repository.file import RepositoryFile from ..hf_hub import LegacyFromHFHub from ._fairseq import FAIRSEQ_PIECE_IDS, FairSeqPostEncoder, FairSeqPreDecoder from .legacy_tokenizer import AddBosEosPreEncoder @@ -111,7 +111,7 @@ def __init__( def from_files( cls: Type[Self], *, - model_file: ModelFile, + model_file: RepositoryFile, bos_piece: str = "", eos_piece: str = "", ) -> Self: @@ -137,7 +137,7 @@ def from_files( def _load_from_vocab_files( cls: Type[Self], *, - vocab_files: Mapping[str, ModelFile], + vocab_files: Mapping[str, RepositoryFile], tokenizer_config: Optional[Dict[str, Any]], ) -> Self: return cls.from_files(model_file=vocab_files["model"]) diff --git a/curated_transformers/tokenizers/legacy/llama_tokenizer.py b/curated_transformers/tokenizers/legacy/llama_tokenizer.py index d8bffe37..7a6f2a33 100644 --- a/curated_transformers/tokenizers/legacy/llama_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/llama_tokenizer.py @@ -2,7 +2,7 @@ from curated_tokenizers import SentencePieceProcessor -from ...util.serde import ModelFile +from ...repository.file import RepositoryFile from ..hf_hub import LegacyFromHFHub from .legacy_tokenizer import AddBosEosPreEncoder from .sentencepiece_tokenizer import SentencePieceTokenizer @@ -53,7 +53,7 @@ def __init__( def from_files( cls: Type[Self], *, - model_file: ModelFile, + model_file: RepositoryFile, add_bos_piece: bool = True, add_eos_piece: bool = False, ) -> Self: @@ -79,7 +79,7 @@ def from_files( def _load_from_vocab_files( cls: Type[Self], *, - vocab_files: Mapping[str, ModelFile], + vocab_files: Mapping[str, RepositoryFile], tokenizer_config: Optional[Dict[str, Any]], ) -> Self: if tokenizer_config is None: diff --git a/curated_transformers/tokenizers/legacy/roberta_tokenizer.py b/curated_transformers/tokenizers/legacy/roberta_tokenizer.py index a3213dad..b296f7f1 100644 --- 
a/curated_transformers/tokenizers/legacy/roberta_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/roberta_tokenizer.py @@ -2,7 +2,7 @@ from curated_tokenizers import ByteBPEProcessor -from ...util.serde import ModelFile +from ...repository.file import RepositoryFile from ..hf_hub import LegacyFromHFHub from ..util import remove_pieces_from_sequence from .bbpe_tokenizer import ByteBPETokenizer @@ -87,8 +87,8 @@ def __init__( def from_files( cls: Type[Self], *, - vocab_file: ModelFile, - merges_file: ModelFile, + vocab_file: RepositoryFile, + merges_file: RepositoryFile, bos_piece: str = "", eos_piece: str = "", ) -> Self: @@ -124,7 +124,7 @@ def eos_piece(self) -> Optional[str]: def _load_from_vocab_files( cls: Type[Self], *, - vocab_files: Mapping[str, ModelFile], + vocab_files: Mapping[str, RepositoryFile], tokenizer_config: Optional[Dict[str, Any]], ) -> Self: return cls.from_files( diff --git a/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py b/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py index 75a11091..86d4eca7 100644 --- a/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py @@ -2,7 +2,7 @@ from curated_tokenizers import SentencePieceProcessor -from ...util.serde import ModelFile +from ...repository.file import RepositoryFile from ..hf_hub import LegacyFromHFHub from ._fairseq import FAIRSEQ_PIECE_IDS, FairSeqPostEncoder, FairSeqPreDecoder from .legacy_tokenizer import AddBosEosPreEncoder @@ -112,7 +112,7 @@ def __init__( def from_files( cls: Type[Self], *, - model_file: ModelFile, + model_file: RepositoryFile, ) -> Self: """ Construct a XLM-R tokenizer from a SentencePiece model. @@ -128,7 +128,7 @@ def from_files( def _load_from_vocab_files( cls: Type[Self], *, - vocab_files: Mapping[str, ModelFile], + vocab_files: Mapping[str, RepositoryFile], tokenizer_config: Optional[Dict[str, Any]], ) -> Self: return cls.from_files(model_file=vocab_files["model"]) diff --git a/curated_transformers/tokenizers/tokenizer.py b/curated_transformers/tokenizers/tokenizer.py index a69e46e5..795b016c 100644 --- a/curated_transformers/tokenizers/tokenizer.py +++ b/curated_transformers/tokenizers/tokenizer.py @@ -5,23 +5,15 @@ from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union, cast import torch -from fsspec import AbstractFileSystem -from huggingface_hub.utils import EntryNotFoundError +from fsspec.implementations.local import LocalFileSystem +from huggingface_hub import Repository from tokenizers import Tokenizer as HFTokenizer from torch import Tensor from ..layers.attention import AttentionMask -from ..util.fsspec import get_special_tokens_map as get_special_tokens_map_fsspec -from ..util.fsspec import get_tokenizer_config as get_tokenizer_config_fsspec -from ..util.hf import ( - HF_TOKENIZER_CONFIG, - SPECIAL_TOKENS_MAP, - TOKENIZER_JSON, - get_special_piece, - get_special_tokens_map, - get_tokenizer_config, - hf_hub_download, -) +from ..repository.fsspec import FsspecRepository +from ..repository.hf_hub import HfHubRepository +from ..repository.repository import Repository, TokenizerRepository from ._hf_compat import clean_up_decoded_string_like_hf from .chunks import InputChunks, MergedSpecialPieceChunk from .hf_hub import FromHFHub @@ -322,23 +314,7 @@ def from_dir(cls: Type[Self], path: Path) -> Self: :param path: Path to the tokenizer directory. 
""" - tokenizer_path = path / TOKENIZER_JSON - config_path = path / HF_TOKENIZER_CONFIG - special_tokens_map_path = path / SPECIAL_TOKENS_MAP - hf_tokenizer = HFTokenizer.from_file(str(tokenizer_path)) - config = None - if config_path.is_file(): - with open(config_path, encoding="utf-8") as f: - config = json.load(f) - special_tokens_map = None - if special_tokens_map_path.is_file(): - with open(special_tokens_map_path, encoding="utf-8") as f: - special_tokens_map = json.load(f) - return cls( - tokenizer=hf_tokenizer, - config=config, - special_tokens_map=special_tokens_map, - ) + return cls.from_repo(FsspecRepository(LocalFileSystem(), str(path))) @classmethod def from_hf_hub_to_cache( @@ -347,59 +323,35 @@ def from_hf_hub_to_cache( name: str, revision: str = "main", ): - _ = hf_hub_download(repo_id=name, filename=TOKENIZER_JSON, revision=revision) + repo = TokenizerRepository(HfHubRepository(name, revision=revision)) + repo.tokenizer_json() try: - _ = get_tokenizer_config(name=name, revision=revision) - except EntryNotFoundError: + _ = repo.tokenizer_config() + except: pass try: - _ = get_special_tokens_map(name=name, revision=revision) - except EntryNotFoundError: + _ = repo.special_tokens_map() + except: pass @classmethod - def from_fsspec( - cls: Type[Self], - *, - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Self: - tokenizer_path = f"{model_path}/tokenizer.json" - if not fs.exists(tokenizer_path, **kwargs): - raise ValueError(f"Path cannot be found: {tokenizer_path}") - with fs.open(tokenizer_path) as f: - hf_tokenizer = HFTokenizer.from_buffer(f.read()) - - config = get_tokenizer_config_fsspec(fs, model_path, fsspec_args) - special_tokens_map = get_special_tokens_map_fsspec(fs, model_path, fsspec_args) - - return cls( - tokenizer=hf_tokenizer, - config=config, - special_tokens_map=special_tokens_map, - ) - - @classmethod - def from_hf_hub(cls: Type[Self], *, name: str, revision: str = "main") -> Self: - # We cannot directly use `HFTokenizer.from_pretrained`` to instantiate the HF - # tokenizer as it doesn't fetch the serialized files using the `huggingface_hub` - # library, which prevents us from being able to download models that require - # authentication. - tokenizer_path = hf_hub_download( - repo_id=name, filename=TOKENIZER_JSON, revision=revision - ) - hf_tokenizer = HFTokenizer.from_file(tokenizer_path) + def from_repo(cls: Type[Self], repo: Repository) -> Self: + repo = TokenizerRepository(repo) + tokenizer_file = repo.tokenizer_json() + if tokenizer_file.path is not None: + hf_tokenizer = HFTokenizer.from_file(tokenizer_file.path) + else: + with tokenizer_file.open() as f: + hf_tokenizer = HFTokenizer.from_buffer(f.read()) try: - config = get_tokenizer_config(name=name, revision=revision) - except EntryNotFoundError: + config = repo.tokenizer_config() + except OSError: config = None try: - special_tokens_map = get_special_tokens_map(name=name, revision=revision) - except EntryNotFoundError: + special_tokens_map = repo.special_tokens_map() + except OSError: special_tokens_map = None return cls( tokenizer=hf_tokenizer, @@ -439,3 +391,23 @@ def from_json( def piece_to_id(self, piece: str) -> Optional[int]: return self.tokenizer.token_to_id(piece) + + +def get_special_piece( + special_tokens_map: Dict[str, Any], piece_name: str +) -> Optional[str]: + """ + Get a special piece from the special tokens map or the tokenizer + configuration. + + :param special_tokens_map: + The special tokens map. 
+ :param piece_name: + The piece to look up. + :returns: + The piece or ``None`` if this particular piece was not defined. + """ + piece = special_tokens_map.get(piece_name) + if isinstance(piece, dict): + piece = piece.get("content") + return piece diff --git a/curated_transformers/util/fsspec.py b/curated_transformers/util/fsspec.py deleted file mode 100644 index 3a27c79a..00000000 --- a/curated_transformers/util/fsspec.py +++ /dev/null @@ -1,291 +0,0 @@ -import json -import os -from typing import Any, Dict, List, Optional, Tuple - -from fsspec import AbstractFileSystem - -from .._compat import has_safetensors -from .hf import ( - HF_MODEL_CONFIG, - HF_TOKENIZER_CONFIG, - PRIMARY_CHECKPOINT_FILENAMES, - SHARDED_CHECKPOINT_INDEX_FILENAMES, - SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY, - SPECIAL_TOKENS_MAP, -) -from .serde import ( - _MODEL_CHECKPOINT_TYPE, - FsspecModelFile, - ModelCheckpointType, - ModelFile, -) - - -def get_file_metadata( - *, - fs: AbstractFileSystem, - model_path: str, - filename: str, - fsspec_args: Optional[Dict[str, Any]] = None, -) -> Optional[Dict[str, Any]]: - """ - Get a file from a model on an fsspec filesystem. - - :param fs: - The filesystem on which the model is stored. - :param model_path: - The path of the model on the filesystem. - :param filename: - The file to get metadata for. - :param fsspec_args: - Implementation-specific keyword arguments to pass to fsspec - filesystem operations. - :returns: - File metadata as a dictionary or ``None`` if the file does not - exist. - - """ - index = get_path_index(fs, model_path, fsspec_args=fsspec_args) - return index.get(filename) - - -def get_model_config( - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """ - Get the configuation of a model on an fsspec filesystem. - - :param fs: - The filesystem on which the model is stored. - :param model_path: - The path of the model on the filesystem. - :param fsspec_args: - Implementation-specific keyword arguments to pass to fsspec - filesystem operations. - :returns: - The model configuration. - """ - config = _get_and_parse_json_file( - fs, - path=f"{model_path}/{HF_MODEL_CONFIG}", - fsspec_args=fsspec_args, - ) - if config is None: - raise ValueError( - f"Cannot open model config path: {model_path}/{HF_MODEL_CONFIG}" - ) - return config - - -def get_config_model_type( - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, -) -> Optional[str]: - """ - Get the type of a model on an fsspec filesystem. - - :param fs: - The filesystem on which the model is stored. - :param model_path: - The path of the model on the filesystem. - :param fsspec_args: - Implementation-specific keyword arguments to pass to fsspec - filesystem operations. - :returns: - The model type. - """ - config = get_model_config(fs, model_path, fsspec_args=fsspec_args) - return config.get("model_type") - - -def get_tokenizer_config( - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, -) -> Optional[Dict[str, Any]]: - """ - Get the configuration of a tokenizer on an fsspec filesystem. - - :param fs: - The filesystem on which the model is stored. - :param model_path: - The path of the model on the filesystem. - :param fsspec_args: - Implementation-specific keyword arguments to pass to fsspec - filesystem operations. - :returns: - Deserialized tokenizer configuration. 
- """ - return _get_and_parse_json_file( - fs, - path=f"{model_path}/{HF_TOKENIZER_CONFIG}", - fsspec_args=fsspec_args, - ) - - -def get_special_tokens_map( - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, -) -> Optional[Dict[str, Any]]: - """ - Get the special token mapping of a tokenizer on an fsspec filesystem. - - :param fs: - The filesystem on which the model is stored. - :param model_path: - The path of the model on the filesystem. - :param fsspec_args: - Implementation-specific keyword arguments to pass to fsspec - filesystem operations. - :returns: - Deserialized special token_map. - """ - return _get_and_parse_json_file( - fs, path=f"{model_path}/{SPECIAL_TOKENS_MAP}", fsspec_args=fsspec_args - ) - - -def _get_and_parse_json_file( - fs: AbstractFileSystem, - *, - path: str, - fsspec_args: Optional[Dict[str, Any]] = None, -) -> Optional[Dict[str, Any]]: - """ - Get a JSON file from an fsspec filesystem and parse it. - - :param fs: - The filesystem on which the model is stored. - :param path: - The path of the JSON file. - :param fsspec_args: - Implementation-specific keyword arguments to pass to fsspec - filesystem operations. - :returns: - List of absolute paths to the checkpoints - and the checkpoint type. - """ - fsspec_args = {} if fsspec_args is None else fsspec_args - - if not fs.exists(path, **fsspec_args): - return None - - with fs.open(path, "r", encoding="utf-8", **fsspec_args) as f: - config = json.load(f) - return config - - -def get_model_checkpoint_files( - fs: AbstractFileSystem, - model_path: str, - fsspec_args: Optional[Dict[str, Any]] = None, -) -> Tuple[List[ModelFile], ModelCheckpointType]: - """ - Return a list of local file paths to checkpoints that belong to the model - on an fsspec filesystem. In case of non-sharded models, a single file path - is returned. In case of sharded models, multiple file paths are returned. - - :param fs: - The filesystem on which the model is stored. - :param model_path: - The path of the model on the filesystem. - :param fsspec_args: - Implementation-specific keyword arguments to pass to fsspec - filesystem operations. - :returns: - List of absolute paths to the checkpoints - and the checkpoint type. - """ - fsspec_args = {} if fsspec_args is None else fsspec_args - - def get_checkpoint_paths( - checkpoint_type: ModelCheckpointType, - ) -> List[ModelFile]: - index = get_path_index(fs, model_path, fsspec_args=fsspec_args) - - # Attempt to download a non-sharded checkpoint first. - entry = index.get(PRIMARY_CHECKPOINT_FILENAMES[checkpoint_type]) - if entry is not None: - return [FsspecModelFile(fs, entry["name"], fsspec_args)] - - # Try sharded checkpoint. - index_filename = SHARDED_CHECKPOINT_INDEX_FILENAMES[checkpoint_type] - entry = index.get(index_filename) - if entry is None: - raise ValueError( - f"Couldn't find a valid {checkpoint_type.pretty_name} checkpoint for " - f"model with path `{model_path}`. Could not open {index_filename}" - ) - - with fs.open(entry["name"], "rb", **fsspec_args) as f: - index = json.load(f) - - weight_map = index.get(SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY) - if not isinstance(weight_map, dict): - raise ValueError( - f"Invalid index file in sharded {checkpoint_type.pretty_name} " - f"checkpoint for model with path `{model_path}`" - ) - - filepaths = [] - # We shouldn't need to hold on to the weights map in the index as each checkpoint - # should contain its constituent parameter names. 
- for filename in set(weight_map.values()): - filepaths.append(f"{model_path}/{filename}") - - return [FsspecModelFile(fs, path, fsspec_args) for path in sorted(filepaths)] - - checkpoint_type = _MODEL_CHECKPOINT_TYPE.get() - checkpoint_paths: Optional[List[ModelFile]] = None - - if checkpoint_type is None: - # Precedence: Safetensors > PyTorch - if has_safetensors: - try: - checkpoint_type = ModelCheckpointType.SAFE_TENSORS - checkpoint_paths = get_checkpoint_paths(checkpoint_type) - except ValueError: - pass - if checkpoint_paths is None: - checkpoint_type = ModelCheckpointType.PYTORCH_STATE_DICT - checkpoint_paths = get_checkpoint_paths(checkpoint_type) - else: - checkpoint_paths = get_checkpoint_paths(checkpoint_type) - - assert checkpoint_paths is not None - assert checkpoint_type is not None - return checkpoint_paths, checkpoint_type - - -def get_path_index( - fs: AbstractFileSystem, - path: str, - fsspec_args: Optional[Dict[str, Any]] = None, -) -> Dict[str, Dict[str, Any]]: - """ - Get the files and their metadata of a model on an fsspec filesystem. - - :param fs: - The filesystem on which the model is stored. - :param path: - The path to return the index for. - :param fsspec_args: - Implementation-specific keyword arguments to pass to fsspec - filesystem operations. - :returns: - List of absolute paths to the checkpoints - and the checkpoint type. - """ - fsspec_args = {} if fsspec_args is None else fsspec_args - - try: - return { - os.path.basename(entry["name"]): entry - for entry in fs.ls(path, **fsspec_args) - } - except FileNotFoundError: - raise ValueError(f"Path cannot be found: {path}") diff --git a/curated_transformers/util/hf.py b/curated_transformers/util/hf.py deleted file mode 100644 index 3fafef09..00000000 --- a/curated_transformers/util/hf.py +++ /dev/null @@ -1,301 +0,0 @@ -import json -import warnings -from typing import Any, Dict, List, Optional, Tuple - -import huggingface_hub -from requests import HTTPError, ReadTimeout # type: ignore - -from .._compat import has_safetensors -from .serde import ( - _MODEL_CHECKPOINT_TYPE, - LocalModelFile, - ModelCheckpointType, - ModelFile, -) - -HF_MODEL_CONFIG = "config.json" -HF_MODEL_CHECKPOINT = "pytorch_model.bin" -HF_MODEL_CHECKPOINT_SAFETENSORS = "model.safetensors" -HF_MODEL_SHARDED_CHECKPOINT_INDEX = "pytorch_model.bin.index.json" -HF_MODEL_SHARDED_CHECKPOINT_INDEX_SAFETENSORS = "model.safetensors.index.json" -HF_MODEL_SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY = "weight_map" -HF_TOKENIZER_CONFIG = "tokenizer_config.json" -SPECIAL_TOKENS_MAP = "special_tokens_map.json" -TOKENIZER_JSON = "tokenizer.json" - -PRIMARY_CHECKPOINT_FILENAMES = { - ModelCheckpointType.PYTORCH_STATE_DICT: HF_MODEL_CHECKPOINT, - ModelCheckpointType.SAFE_TENSORS: HF_MODEL_CHECKPOINT_SAFETENSORS, -} -SHARDED_CHECKPOINT_INDEX_FILENAMES = { - ModelCheckpointType.PYTORCH_STATE_DICT: HF_MODEL_SHARDED_CHECKPOINT_INDEX, - ModelCheckpointType.SAFE_TENSORS: HF_MODEL_SHARDED_CHECKPOINT_INDEX_SAFETENSORS, -} -# Same for both checkpoint types. -SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY = HF_MODEL_SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY - - -def get_file_metadata( - *, filename: str, name: str, revision: str -) -> huggingface_hub.HfFileMetadata: - """ - Get the metadata for a file on Huggingface Hub. - - :param filename: - The file to get the metadata for. - :param name: - Model name. - :param revision: - Model revision. 
- """ - url = huggingface_hub.hf_hub_url(name, filename, revision=revision) - return huggingface_hub.get_hf_file_metadata(url) - - -def get_config_model_type(name: str, revision: str) -> str: - """ - Get the type of a model on Hugging Face Hub. - - :param name: - The model to get the type of. - :param revision: - The revision of the model. - """ - config = get_model_config(name, revision) - model_type = config.get("model_type") - if model_type is None: - raise ValueError( - f"Model type not found in Hugging Face model config for model '{name}' ({revision})" - ) - return model_type - - -def get_model_config(name: str, revision: str) -> Dict[str, Any]: - """ - Return the model's configuration. If the config is not found in the - cache, it is downloaded from Hugging Face Hub. - - :param name: - Model name. - :param revision: - Model revision. - :returns: - Model configuration. - """ - try: - path = hf_hub_download( - repo_id=name, filename=HF_MODEL_CONFIG, revision=revision - ) - except: - raise ValueError( - f"Couldn't find a valid config for model `{name}` " - f"(revision `{revision}`) on HuggingFace Model Hub" - ) - - with open(path, "r") as f: - config = json.load(f) - return config - - -def get_model_checkpoint_files( - name: str, revision: str -) -> Tuple[List[ModelFile], ModelCheckpointType]: - """ - Return a list of local file paths to checkpoints that belong to the Hugging - Face model. In case of non-sharded models, a single file path is returned. In - case of sharded models, multiple file paths are returned. - - If the checkpoints are not found in the cache, they are downloaded from - Hugging Face Hub. - - :param name: - Model name. - :param revision: - Model revision. - :returns: - List of absolute paths to the checkpoints - and the checkpoint type. - """ - - def get_checkpoint_paths( - checkpoint_type: ModelCheckpointType, - ) -> List[ModelFile]: - # Attempt to download a non-sharded checkpoint first. - try: - model_filename = hf_hub_download( - repo_id=name, - filename=PRIMARY_CHECKPOINT_FILENAMES[checkpoint_type], - revision=revision, - ) - except HTTPError: - # We'll get a 404 HTTP error for sharded models. - model_filename = None - - if model_filename is not None: - return [LocalModelFile(model_filename)] - - try: - model_index_filename = hf_hub_download( - repo_id=name, - filename=SHARDED_CHECKPOINT_INDEX_FILENAMES[checkpoint_type], - revision=revision, - ) - except HTTPError: - raise ValueError( - f"Couldn't find a valid {checkpoint_type.pretty_name} checkpoint for " - f"model `{name}` (revision `{revision}`) on HuggingFace Model Hub" - ) - - with open(model_index_filename, "r") as f: - index = json.load(f) - - weight_map = index.get(SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY) - if not isinstance(weight_map, dict): - raise ValueError( - f"Invalid index file in sharded {checkpoint_type.pretty_name} " - f"checkpoint for model `{name}`" - ) - - filepaths = [] - # We shouldn't need to hold on to the weights map in the index as each checkpoint - # should contain its constituent parameter names. 
- for filename in set(weight_map.values()): - resolved_filename = hf_hub_download( - repo_id=name, filename=filename, revision=revision - ) - filepaths.append(resolved_filename) - - return [LocalModelFile(path) for path in sorted(filepaths)] - - checkpoint_type = _MODEL_CHECKPOINT_TYPE.get() - checkpoint_paths: Optional[List[ModelFile]] = None - - if checkpoint_type is None: - # Precedence: Safetensors > PyTorch - if has_safetensors: - try: - checkpoint_type = ModelCheckpointType.SAFE_TENSORS - checkpoint_paths = get_checkpoint_paths(checkpoint_type) - except ValueError: - pass - if checkpoint_paths is None: - checkpoint_type = ModelCheckpointType.PYTORCH_STATE_DICT - checkpoint_paths = get_checkpoint_paths(checkpoint_type) - else: - checkpoint_paths = get_checkpoint_paths(checkpoint_type) - - assert checkpoint_paths is not None - assert checkpoint_type is not None - return checkpoint_paths, checkpoint_type - - -def get_special_piece( - special_tokens_map: Dict[str, Any], piece_name: str -) -> Optional[str]: - """ - Get a special piece from the special tokens map or the tokenizer - configuration. - - :param special_tokens_map: - The special tokens map. - :param piece_name: - The piece to look up. - :returns: - The piece or ``None`` if this particular piece was not defined. - """ - piece = special_tokens_map.get(piece_name) - if isinstance(piece, dict): - piece = piece.get("content") - return piece - - -def get_special_tokens_map(*, name: str, revision="main") -> Dict[str, Any]: - """ - Get a tokenizer's special token mapping. - - :param name: - Model name. - :param revision: - Model revision. - :returns: - Deserialized special token_map. - """ - return _get_and_parse_json_file( - name=name, revision=revision, filename=SPECIAL_TOKENS_MAP - ) - - -def get_tokenizer_config(*, name: str, revision="main") -> Dict[str, Any]: - """ - Get a tokenizer configuration. - - :param name: - Model name. - :param revision: - Model revision. - :returns: - Deserialized tokenizer configuration. - """ - return _get_and_parse_json_file( - name=name, revision=revision, filename=HF_TOKENIZER_CONFIG - ) - - -def _get_and_parse_json_file( - *, name: str, revision: str, filename: str -) -> Dict[str, Any]: - """ - Get and parse a JSON file. - - :param name: - Model name. - :param revision: - Model revision. - :param filename: - File to download and parse. - :returns: - Deserialized JSON file. - """ - config_path = hf_hub_download(repo_id=name, filename=filename, revision=revision) - with open(config_path, encoding="utf-8") as f: - return json.load(f) - - -def hf_hub_download(repo_id: str, filename: str, revision: str) -> str: - """ - Resolve the provided filename and repository to a local file path. If the file - is not present in the cache, it will be downloaded from the Hugging Face Hub. - - :param repo_id: - Identifier of the source repository on Hugging Face Hub. - :param filename: - Name of the file in the source repository to download. - :param revision: - Source repository revision. Can either be a branch name - or a SHA hash of a commit. - :returns: - Resolved absolute filepath. - """ - - # The HF Hub library's `hf_hub_download` function will always attempt to connect to the - # remote repo and fetch its metadata even if it's locally cached (in order to update the - # out-of-date artifacts in the cache). This can occasionally lead to `HTTPError/ReadTimeout`s if the - # remote host is unreachable. 
Instead of failing loudly, we'll add a fallback that checks - # the local cache for the artifacts and uses them if found. - try: - resolved = huggingface_hub.hf_hub_download( - repo_id=repo_id, filename=filename, revision=revision - ) - except (HTTPError, ReadTimeout) as e: - # Attempt to check the cache. - resolved = huggingface_hub.try_to_load_from_cache( - repo_id=repo_id, filename=filename, revision=revision - ) - if resolved is None or resolved is huggingface_hub._CACHED_NO_EXIST: - # Not found, rethrow. - raise e - else: - warnings.warn( - f"Couldn't reach Hugging Face Hub; using cached artifact for '{repo_id}@{revision}:{filename}'" - ) - return resolved diff --git a/curated_transformers/util/serde.py b/curated_transformers/util/serde.py index 6794e5e0..10f59045 100644 --- a/curated_transformers/util/serde.py +++ b/curated_transformers/util/serde.py @@ -1,30 +1,13 @@ -from abc import ABC, abstractmethod from contextlib import contextmanager -from contextvars import ContextVar -from enum import Enum -from typing import ( - IO, - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - Mapping, - Optional, - Set, - Union, -) +from typing import Callable, Dict, Iterable, Mapping, Optional, Set, Union import torch -from fsspec import AbstractFileSystem from torch.nn import Module, Parameter -from .._compat import has_safetensors +from ..repository._hf import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType +from ..repository.file import RepositoryFile from .pytorch import ModuleIterator, apply_to_module -if TYPE_CHECKING: - import safetensors - # Args: Parent module, module prefix, parameter name, tensor to convert, device. # Returns the new paramater. TensorToParameterConverterT = Callable[ @@ -37,137 +20,6 @@ [Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor] ] -PathOrFileDescriptor = Union[str, IO] - - -class ModelFile(ABC): - """ - Model files can be a local path or a remote path exposed as e.g. an I/O - stream. This is a common base class for such different types of model - files. - """ - - @abstractmethod - def open(self, mode: str = "rb", encoding: Optional[str] = None) -> IO: - """ - Get the model file as an I/O stream. - - :param mode: - Mode to open the file with (see Python ``open``). - :param encoding: - Encoding to use when the file is opened as text. - :returns: - An I/O stream. - """ - ... - - @property - @abstractmethod - def path(self) -> Optional[str]: - """ - Get the model file as a local path. If the model file is not - available as a local path, the value of this property is - ``None``. - """ - ... - - -class FsspecModelFile(ModelFile): - """ - Model file on an fsspec filesystem. - """ - - def __init__( - self, - fs: AbstractFileSystem, - path: str, - fsspec_args: Optional[Dict[str, Any]] = None, - ): - """ - Construct an fsspec model file representation. - - :param fs: - The filesystem. - :param path: - The path of the model file on the filesystem. - :param fsspec_args: - Implementation-specific keyword arguments to pass to fsspec - filesystem operations. - """ - super().__init__() - self._fs = fs - self._path = path - self._fsspec_args = fsspec_args - - def open(self, mode: str = "rb", encoding: Optional[str] = None) -> IO: - return self._fs.open( - self._path, mode=mode, encoding=encoding, **self._fsspec_args - ) - - @property - def path(self) -> Optional[str]: - return None - - -class LocalModelFile(ModelFile): - """ - Model file on the local host machine. - """ - - def __init__(self, path: str): - """ - Construct a local model file representation. 
- - :param path: - The path of the model file on the local filesystem. - """ - super().__init__() - self._path = path - - def open(self, mode: str = "rb", encoding: Optional[str] = None) -> IO: - return open(self._path, mode=mode, encoding=encoding) - - @property - def path(self) -> Optional[str]: - return self._path - - -class ModelCheckpointType(Enum): - """ - Types of model checkpoints supported by Curated Transformers. - """ - - #: PyTorch `checkpoint`_. - PYTORCH_STATE_DICT = 0 - - #: Hugging Face `Safetensors `_ checkpoint. - SAFE_TENSORS = 1 - - @property - def loader( - self, - ) -> Callable[[Iterable[ModelFile]], Iterable[Mapping[str, torch.Tensor]]]: - checkpoint_type_to_loader = { - ModelCheckpointType.PYTORCH_STATE_DICT: _load_pytorch_state_dicts_from_checkpoints, - ModelCheckpointType.SAFE_TENSORS: _load_safetensor_state_dicts_from_checkpoints, - } - return checkpoint_type_to_loader[self] - - @property - def pretty_name(self) -> str: - if self == ModelCheckpointType.PYTORCH_STATE_DICT: - return "PyTorch StateDict" - elif self == ModelCheckpointType.SAFE_TENSORS: - return "SafeTensors" - else: - return "" - - -# When `None`, behaviour is implementation-specific. -_MODEL_CHECKPOINT_TYPE: ContextVar[Optional[ModelCheckpointType]] = ContextVar( - "model_checkpoint_type", default=None -) - @contextmanager def _use_model_checkpoint_type( @@ -195,7 +47,7 @@ def _use_model_checkpoint_type( def load_model_from_checkpoints( model: Module, *, - filepaths: Iterable[ModelFile], + filepaths: Iterable[RepositoryFile], checkpoint_type: ModelCheckpointType, state_dict_converter: HFStateDictConverterT, tensor_to_param_converter: Optional[TensorToParameterConverterT] = None, @@ -375,41 +227,3 @@ def _validate_replacement( raise ValueError( f"Expected dtype of replacement for `{name}` to be {replaced.dtype}, but got {replacement.dtype}" ) - - -def _load_safetensor_state_dicts_from_checkpoints( - checkpoints: Iterable[ModelFile], -) -> Iterable[Mapping[str, torch.Tensor]]: - if not has_safetensors: - raise ValueError( - "The `safetensors` library is required to load models from Safetensors checkpoints" - ) - - import safetensors.torch - - for checkpoint in checkpoints: - # Prefer to load from a path when possible. Since loading from a file - # temporarily puts the checkpoint in memory twice. - if checkpoint.path is not None: - # Map to CPU first to support all devices. - state_dict = safetensors.torch.load_file(checkpoint.path, device="cpu") - else: - with checkpoint.open() as f: - # This has memory overhead, since Safetensors does not have - # support for loading from a file object and cannot use - # the bytes in-place. - checkpoint_bytes = f.read() - state_dict = safetensors.torch.load(checkpoint_bytes) - yield state_dict - - -def _load_pytorch_state_dicts_from_checkpoints( - checkpoints: Iterable[ModelFile], -) -> Iterable[Mapping[str, torch.Tensor]]: - for checkpoint in checkpoints: - with checkpoint.open() as f: - # Map to CPU first to support all devices. 
-            state_dict = torch.load(
-                f, map_location=torch.device("cpu"), weights_only=True
-            )
-        yield state_dict
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 6af19588..03cf662b 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -9,6 +9,7 @@ API
    decoders
    causal-lm
    generation
+   repositories
    tokenizers
    quantization
    utils
diff --git a/docs/source/repositories.rst b/docs/source/repositories.rst
new file mode 100644
index 00000000..83290c4d
--- /dev/null
+++ b/docs/source/repositories.rst
@@ -0,0 +1,62 @@
+Repositories
+============
+
+Models and tokenizers can be loaded from repositories using the ``from_repo``
+method. You can add your own type of repository by implementing the
+:py:class:`curated_transformers.repository.Repository` base class.
+
+This is an example repository that opens files on the local filesystem:
+
+.. code-block:: python
+
+    import os.path
+    from typing import Optional
+
+    from curated_transformers.repository import LocalFile, Repository, RepositoryFile
+
+    class LocalRepository(Repository):
+        def __init__(self, path: str):
+            super().__init__()
+            self.repo_path = path
+
+        def file(self, path: str) -> RepositoryFile:
+            full_path = f"{self.repo_path}/{path}"
+            if not os.path.isfile(full_path):
+                raise FileNotFoundError(f"File not found: {full_path}")
+            return LocalFile(path=full_path)
+
+        def pretty_path(self, path: Optional[str] = None) -> str:
+            return self.repo_path if path is None else f"{self.repo_path}/{path}"
+
+Base Classes
+------------
+
+.. autoclass:: curated_transformers.repository.Repository
+   :members:
+   :show-inheritance:
+
+.. autoclass:: curated_transformers.repository.RepositoryFile
+   :members:
+   :show-inheritance:
+
+Repositories
+------------
+
+.. autoclass:: curated_transformers.repository.FsspecRepository
+   :members:
+   :show-inheritance:
+
+.. autoclass:: curated_transformers.repository.HfHubRepository
+   :members:
+   :show-inheritance:
+
+Repository Files
+----------------
+
+.. autoclass:: curated_transformers.repository.FsspecFile
+   :members:
+   :show-inheritance:
+
+.. autoclass:: curated_transformers.repository.LocalFile
+   :members:
+   :show-inheritance:
diff --git a/docs/source/usage.rst b/docs/source/usage.rst
index f13a5182..d8c16583 100644
--- a/docs/source/usage.rst
+++ b/docs/source/usage.rst
@@ -143,6 +143,7 @@ using the ``from_fsspec`` method.
     import torch

     from curated_transformers.models import BERTEncoder
+    from curated_transformers.repository import FsspecArgs
     from fsspec.implementations.local import LocalFileSystem
     from huggingface_hub import HfFileSystem

@@ -156,7 +157,7 @@ using the ``from_fsspec`` method.
     encoder = BERTEncoder.from_fsspec(
         fs=HfFileSystem(),
         model_path="bert-base-uncased",
-        fsspec_args={"revision": "a265f773a47193eed794233aa2a0f0bb6d3eaa63"},
+        fsspec_args=FsspecArgs(revision="a265f773a47193eed794233aa2a0f0bb6d3eaa63"),
         device=torch.device("cuda", index=0),
     )
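
As a supplementary, non-authoritative sketch that ties this change together, the snippet below shows how the ``LocalRepository`` defined in the new ``repositories.rst`` example could be passed to the ``from_repo`` loading path, and how the same directory could be loaded through the built-in ``FsspecRepository`` that ``from_dir`` now delegates to. It assumes that the ``Tokenizer`` class in ``curated_transformers.tokenizers`` exposes the ``from_repo`` method shown earlier in this diff; the model directory path is a placeholder.

.. code-block:: python

    from fsspec.implementations.local import LocalFileSystem

    from curated_transformers.repository import FsspecRepository
    from curated_transformers.tokenizers import Tokenizer

    # Placeholder path; the directory should contain tokenizer.json and,
    # optionally, tokenizer_config.json and special_tokens_map.json.
    model_dir = "/srv/models/bert-base-uncased"

    # Custom repository from the repositories.rst example above.
    tokenizer = Tokenizer.from_repo(LocalRepository(model_dir))

    # The built-in fsspec repository over the local filesystem gives the same
    # result; this is what Tokenizer.from_dir now uses internally.
    tokenizer = Tokenizer.from_repo(FsspecRepository(LocalFileSystem(), model_dir))

    # Pieces can then be looked up as usual.
    cls_id = tokenizer.piece_to_id("[CLS]")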