Rebase on curated-transformers 2.0.0 (#19)
* Update requirements and tick version to `0.3.0.dev0`

* Re-add `ScalarWeightClassifier`

* Rebase all transformer model architectures to `curated-transformers 1.0.0`

* Fix `curated-tokenizers 0.9.0` compatibility

* Set version to `1.0.0.dev0`

* `isort`

* Increment min Python version to 3.8 in CI

* Fix: do not load full state dict when loading locally

Loading the full state dict would fail, since checkpoint state dicts contain parameters that the encoder does not use (e.g. the LM
head in the case of encoder-only models); a short sketch illustrating this follows the commit metadata below.

* Update to Curated Transformers 2.0.0.dev1

* Correctness: call `from_hf_hub` on type, not instance

* isort

* CI: Bump up transformers upper bound

* Download config.json in checkpoint loader test

* tests: do not fail if start method is already set

* tests: skip multiprocessing test when method is not 'spawn'

* Update checkpoint loader docstring to mention safetensors

---------

Co-authored-by: Daniël de Kok <[email protected]>
shadeMe and danieldk authored Apr 9, 2024
1 parent 1272311 commit 4d8ccb4
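
The state-dict fix in the commit message is easiest to see with a toy example. Below is a minimal sketch, using a made-up `TinyEncoder` module rather than the actual curated-transformers encoders, of why strictly loading a full checkpoint state dict fails when the checkpoint carries parameters (such as an LM head) that an encoder-only module does not own:

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Stand-in for an encoder-only model; it has no LM head."""

    def __init__(self) -> None:
        super().__init__()
        self.dense = nn.Linear(4, 4)


encoder = TinyEncoder()

# A full training checkpoint also carries parameters the encoder never uses.
full_state_dict = {
    "dense.weight": torch.zeros(4, 4),
    "dense.bias": torch.zeros(4),
    "lm_head.weight": torch.zeros(8, 4),  # e.g. an LM head, unused by the encoder
}

try:
    # strict=True (the default) rejects unexpected keys, so this raises.
    encoder.load_state_dict(full_state_dict)
except RuntimeError as err:
    print(err)  # mentions the unexpected "lm_head.weight" key
```

This is why the new checkpoint loader in the diff below hands loading off to Curated Transformers instead of calling `load_state_dict` on the raw checkpoint.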
Showing 12 changed files with 262 additions and 118 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -3,7 +3,7 @@ name: Test
on: [push, pull_request, workflow_call]

env:
hf-transformers-pip: transformers[sentencepiece]>=3.4.0,<4.32.0
hf-transformers-pip: transformers[sentencepiece]>=4.39.0,<4.40.0

jobs:
validate:
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,5 +1,6 @@
curated-transformers>=0.1.0,<0.2.0
curated-tokenizers>=0.0.9,<0.1.0
curated-transformers>=2.0.0.dev1,<3.0.0
curated-tokenizers>=0.9.2,<1.0.0
fsspec>=2023.5.0
spacy>=4.0.0.dev2,<5.0.0
thinc>=9.0.0.dev4,<9.1.0
srsly
10 changes: 6 additions & 4 deletions setup.cfg
@@ -1,5 +1,5 @@
[metadata]
version = 0.2.2
version = 1.0.0.dev0
description = Curated transformer models for spaCy pipelines
url = https:/explosion/spacy-curated-transformers
author = Explosion
@@ -12,10 +12,11 @@ long_description_content_type = text/markdown
[options]
zip_safe = true
include_package_data = true
python_requires = >=3.6
python_requires = >=3.8
install_requires =
curated-transformers>=0.1.0,<0.2.0
curated-tokenizers>=0.0.9,<0.1.0
curated-transformers>=2.0.0.dev1,<3.0.0
curated-tokenizers>=0.9.2,<1.0.0
fsspec>=2023.5.0
spacy>=4.0.0.dev2,<5.0.0
thinc>=9.0.0.dev4,<9.1.0
torch>=1.12.0
@@ -56,6 +57,7 @@ thinc_model_loaders =
spacy-curated-transformers.HFTransformerEncoderLoader.v1 = spacy_curated_transformers.models:build_hf_transformer_encoder_loader_v1
spacy-curated-transformers.HFPieceEncoderLoader.v1 = spacy_curated_transformers.tokenization:build_hf_piece_encoder_loader_v1
spacy-curated-transformers.PyTorchCheckpointLoader.v1 = spacy_curated_transformers.models:build_pytorch_checkpoint_loader_v1
spacy-curated-transformers.PyTorchCheckpointLoader.v2 = spacy_curated_transformers.models:build_pytorch_checkpoint_loader_v2
spacy-curated-transformers.SentencepieceLoader.v1 = spacy_curated_transformers.tokenization:build_sentencepiece_encoder_loader_v1
spacy-curated-transformers.WordpieceLoader.v1 = spacy_curated_transformers.tokenization:build_wordpiece_encoder_loader_v1

169 changes: 101 additions & 68 deletions spacy_curated_transformers/models/architectures.py
@@ -3,15 +3,21 @@
from typing import Any, Callable, List, Optional, Tuple, Union, cast

import torch
from curated_transformers.models.albert import AlbertConfig, AlbertEncoder
from curated_transformers.models.bert import BertConfig, BertEncoder
from curated_transformers.models.curated_transformer import (
CuratedEncoderT,
CuratedTransformer,
from curated_transformers.layers import Activation, AttentionMask
from curated_transformers.models import (
ALBERTConfig,
ALBERTEncoder,
BERTConfig,
BERTEncoder,
CamemBERTEncoder,
EncoderModule,
FromHFHub,
ModelOutput,
RoBERTaConfig,
RoBERTaEncoder,
XLMREncoder,
)
from curated_transformers.models.hf_util import convert_pretrained_model_for_encoder
from curated_transformers.models.output import PyTorchTransformerOutput
from curated_transformers.models.roberta import RobertaConfig, RobertaEncoder
from fsspec.implementations.local import LocalFileSystem
from spacy.tokens import Doc
from spacy.util import SimpleFrozenDict
from thinc.api import (
@@ -125,33 +131,34 @@ def build_albert_transformer_model_v1(
Optional listener to wrap. Only used when replacing listeners
in downstream components.
"""
config = AlbertConfig(
config = ALBERTConfig(
embedding_width=embedding_width,
hidden_width=hidden_width,
intermediate_width=intermediate_width,
num_attention_heads=num_attention_heads,
num_hidden_groups=num_hidden_groups,
num_hidden_layers=num_hidden_layers,
n_attention_heads=num_attention_heads,
n_hidden_groups=num_hidden_groups,
n_hidden_layers=num_hidden_layers,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
vocab_size=vocab_size,
type_vocab_size=type_vocab_size,
max_position_embeddings=max_position_embeddings,
activation=Activation(hidden_act),
n_pieces=vocab_size,
n_types=type_vocab_size,
n_positions=max_position_embeddings,
model_max_length=model_max_length,
layer_norm_eps=layer_norm_eps,
padding_idx=padding_idx,
)

if torchscript:
transformer = _torchscript_encoder(
model_max_length=model_max_length, padding_idx=padding_idx
)
else:
encoder = AlbertEncoder(config)
encoder = ALBERTEncoder(config)
transformer = _pytorch_encoder(
encoder,
hidden_width=hidden_width,
model_max_length=model_max_length,
padding_idx=padding_idx,
mixed_precision=mixed_precision,
grad_scaler_config=grad_scaler_config,
)
@@ -233,32 +240,33 @@ def build_bert_transformer_model_v1(
Optional listener to wrap. Only used when replacing listeners
in downstream components.
"""
config = BertConfig(
config = BERTConfig(
embedding_width=hidden_width,
hidden_width=hidden_width,
intermediate_width=intermediate_width,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
n_attention_heads=num_attention_heads,
n_hidden_layers=num_hidden_layers,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
vocab_size=vocab_size,
type_vocab_size=type_vocab_size,
max_position_embeddings=max_position_embeddings,
activation=Activation(hidden_act),
n_pieces=vocab_size,
n_types=type_vocab_size,
n_positions=max_position_embeddings,
model_max_length=model_max_length,
layer_norm_eps=layer_norm_eps,
padding_idx=padding_idx,
)

if torchscript:
transformer = _torchscript_encoder(
model_max_length=model_max_length, padding_idx=padding_idx
)
else:
encoder = BertEncoder(config)
encoder = BERTEncoder(config)
transformer = _pytorch_encoder(
encoder,
hidden_width=hidden_width,
model_max_length=model_max_length,
padding_idx=padding_idx,
mixed_precision=mixed_precision,
grad_scaler_config=grad_scaler_config,
)
@@ -340,32 +348,33 @@ def build_camembert_transformer_model_v1(
Optional listener to wrap. Only used when replacing listeners
in downstream components.
"""
config = RobertaConfig(
config = RoBERTaConfig(
embedding_width=hidden_width,
hidden_width=hidden_width,
intermediate_width=intermediate_width,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
n_attention_heads=num_attention_heads,
n_hidden_layers=num_hidden_layers,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
vocab_size=vocab_size,
type_vocab_size=type_vocab_size,
max_position_embeddings=max_position_embeddings,
activation=Activation(hidden_act),
n_pieces=vocab_size,
n_types=type_vocab_size,
n_positions=max_position_embeddings,
model_max_length=model_max_length,
layer_norm_eps=layer_norm_eps,
padding_idx=padding_idx,
)

if torchscript:
transformer = _torchscript_encoder(
model_max_length=model_max_length, padding_idx=padding_idx
)
else:
encoder = RobertaEncoder(config)
encoder = CamemBERTEncoder(config)
transformer = _pytorch_encoder(
encoder,
hidden_width=hidden_width,
model_max_length=model_max_length,
padding_idx=padding_idx,
mixed_precision=mixed_precision,
grad_scaler_config=grad_scaler_config,
)
@@ -447,32 +456,33 @@ def build_roberta_transformer_model_v1(
Optional listener to wrap. Only used when replacing listeners
in downstream components.
"""
config = RobertaConfig(
config = RoBERTaConfig(
embedding_width=hidden_width,
hidden_width=hidden_width,
intermediate_width=intermediate_width,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
n_attention_heads=num_attention_heads,
n_hidden_layers=num_hidden_layers,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
vocab_size=vocab_size,
type_vocab_size=type_vocab_size,
max_position_embeddings=max_position_embeddings,
activation=Activation(hidden_act),
n_pieces=vocab_size,
n_types=type_vocab_size,
n_positions=max_position_embeddings,
model_max_length=model_max_length,
layer_norm_eps=layer_norm_eps,
padding_idx=padding_idx,
)

if torchscript:
transformer = _torchscript_encoder(
model_max_length=model_max_length, padding_idx=padding_idx
)
else:
encoder = RobertaEncoder(config)
encoder = RoBERTaEncoder(config)
transformer = _pytorch_encoder(
encoder,
hidden_width=hidden_width,
model_max_length=model_max_length,
padding_idx=padding_idx,
mixed_precision=mixed_precision,
grad_scaler_config=grad_scaler_config,
)
@@ -554,32 +564,33 @@ def build_xlmr_transformer_model_v1(
Optional listener to wrap. Only used when replacing listeners
in downstream components.
"""
config = RobertaConfig(
config = RoBERTaConfig(
embedding_width=hidden_width,
hidden_width=hidden_width,
intermediate_width=intermediate_width,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
n_attention_heads=num_attention_heads,
n_hidden_layers=num_hidden_layers,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
vocab_size=vocab_size,
type_vocab_size=type_vocab_size,
max_position_embeddings=max_position_embeddings,
activation=Activation(hidden_act),
n_pieces=vocab_size,
n_types=type_vocab_size,
n_positions=max_position_embeddings,
model_max_length=model_max_length,
layer_norm_eps=layer_norm_eps,
padding_idx=padding_idx,
)

if torchscript:
transformer = _torchscript_encoder(
model_max_length=model_max_length, padding_idx=padding_idx
)
else:
encoder = RobertaEncoder(config)
encoder = XLMREncoder(config)
transformer = _pytorch_encoder(
encoder,
hidden_width=hidden_width,
model_max_length=model_max_length,
padding_idx=padding_idx,
mixed_precision=mixed_precision,
grad_scaler_config=grad_scaler_config,
)
@@ -620,8 +631,6 @@ def build_transformer_model_v1(
transformer: TorchTransformerModelT,
piece_encoder: Tok2PiecesModelT,
) -> TransformerModelT:
# FIXME: do we want to make `remove_bos_eos` configurable as well or
# is it always the same post-processing?
layers = [
with_non_ws_tokens(
chain(piece_encoder, with_spans(transformer), remove_bos_eos())
@@ -668,8 +677,10 @@ def transformer_model_init(


def _pytorch_encoder(
encoder: CuratedEncoderT,
encoder: EncoderModule,
hidden_width: int,
padding_idx: int,
model_max_length: int,
*,
mixed_precision: bool = False,
grad_scaler_config: dict = SimpleFrozenDict(),
@@ -682,11 +693,11 @@ def _pytorch_encoder(
grad_scaler_config["enabled"] = mixed_precision

model = PyTorchWrapper_v2(
CuratedTransformer(encoder),
encoder,
convert_inputs=partial(
_convert_inputs,
max_model_seq_len=encoder.max_seq_len,
padding_idx=encoder.padding_idx,
max_model_seq_len=model_max_length,
padding_idx=padding_idx,
),
convert_outputs=_convert_outputs,
mixed_precision=mixed_precision,
@@ -735,18 +746,19 @@ def _convert_inputs(
span_len = span.shape[0]
Xt[i, :span_len] = span
Xt = xp2torch(Xt)
mask = AttentionMask(Xt.ne(padding_idx))

def convert_from_torch_backward(d_inputs: Any):
# No gradients for the inputs.
return [ops.alloc1f(x.shape[0]) for x in X]

output = ArgsKwargs(args=(Xt,), kwargs={})
output = ArgsKwargs(args=(Xt, mask), kwargs={})
return output, convert_from_torch_backward


def _convert_outputs(
model: Model,
inputs_outputs: Tuple[TorchTransformerInT, PyTorchTransformerOutput],
inputs_outputs: Tuple[TorchTransformerInT, ModelOutput],
is_train: bool,
) -> Tuple[TorchTransformerOutT, Callable[[List[List[Floats2d]]], ArgsKwargs]]:
model_inputs, model_outputs = inputs_outputs
@@ -790,19 +802,40 @@ def build_pytorch_checkpoint_loader_v1(*, path: Path) -> Callable[
TorchTransformerModelT,
]:
"""Construct a callback that initializes a supported transformer
model with weights from a PyTorch checkpoint.
model with weights from a PyTorch or SafeTensors checkpoint.
path (Path):
Path to the PyTorch checkpoint.
"""
return build_pytorch_checkpoint_loader_v2(path=path.parent)


def build_pytorch_checkpoint_loader_v2(*, path: Path) -> Callable[
[TorchTransformerModelT, Optional[List[Doc]], Optional[List[Doc]]],
TorchTransformerModelT,
]:
"""Construct a callback that initializes a supported transformer
model with weights from a PyTorch or SafeTensors checkpoint.
path (Path):
Path to the directory containing the checkpoint.
"""

def load(model, X=None, Y=None):
encoder = model.shims[0]._model
device = get_torch_default_device()
params = torch.load(path, map_location=device)
params = convert_pretrained_model_for_encoder(encoder, params)
encoder.load_state_dict(params)
encoder.to(device)
encoder = model.shims[0]._model
assert isinstance(encoder, FromHFHub)
from_fsspec = type(encoder).from_fsspec

# We can discard the previously initialized model entirely
# and use the Curated Transformers API to load it from the
# hub.
model.shims[0]._model = None
del encoder

fs = LocalFileSystem()
encoder = from_fsspec(fs=fs, model_path=path, device=device)
model.shims[0]._model = encoder
return model

return load
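
For illustration, here is a minimal sketch, outside spaCy and with a hypothetical checkpoint directory, of the `from_fsspec` loading path that the new `build_pytorch_checkpoint_loader_v2` builds on (shown here with `BERTEncoder`; the real loader reuses whichever encoder class the pipeline was initialized with):

```python
import torch
from fsspec.implementations.local import LocalFileSystem

from curated_transformers.models import BERTEncoder

fs = LocalFileSystem()

# Hypothetical checkpoint directory containing config.json and
# model.safetensors (or pytorch_model.bin); the v2 loader receives
# such a directory as its `path` argument.
encoder = BERTEncoder.from_fsspec(
    fs=fs,
    model_path="training/checkpoint-5000",
    device=torch.device("cpu"),
)
```

In a spaCy training config, the loader itself is referenced through the entry point registered in `setup.cfg` above, `spacy-curated-transformers.PyTorchCheckpointLoader.v2`, with `path` pointing at the checkpoint directory.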