Rebase on curated-transformers 2.0.0 #19

Merged: 17 commits, Apr 9, 2024
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -3,7 +3,7 @@ name: Test
on: [push, pull_request, workflow_call]

env:
- hf-transformers-pip: transformers[sentencepiece]>=3.4.0,<4.32.0
+ hf-transformers-pip: transformers[sentencepiece]>=4.39.0,<4.40.0

jobs:
validate:
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,5 +1,6 @@
- curated-transformers>=0.1.0,<0.2.0
- curated-tokenizers>=0.0.9,<0.1.0
+ curated-transformers>=2.0.0.dev1,<3.0.0
+ curated-tokenizers>=0.9.2,<1.0.0
+ fsspec>=2023.5.0
spacy>=4.0.0.dev2,<5.0.0
thinc>=9.0.0.dev4,<9.1.0
srsly
10 changes: 6 additions & 4 deletions setup.cfg
@@ -1,5 +1,5 @@
[metadata]
- version = 0.2.2
+ version = 1.0.0.dev0
description = Curated transformer models for spaCy pipelines
url = https://github.com/explosion/spacy-curated-transformers
author = Explosion
@@ -12,10 +12,11 @@ long_description_content_type = text/markdown
[options]
zip_safe = true
include_package_data = true
- python_requires = >=3.6
+ python_requires = >=3.8
install_requires =
- curated-transformers>=0.1.0,<0.2.0
- curated-tokenizers>=0.0.9,<0.1.0
+ curated-transformers>=2.0.0.dev1,<3.0.0
+ curated-tokenizers>=0.9.2,<1.0.0
+ fsspec>=2023.5.0
spacy>=4.0.0.dev2,<5.0.0
thinc>=9.0.0.dev4,<9.1.0
torch>=1.12.0
@@ -56,6 +57,7 @@ thinc_model_loaders =
spacy-curated-transformers.HFTransformerEncoderLoader.v1 = spacy_curated_transformers.models:build_hf_transformer_encoder_loader_v1
spacy-curated-transformers.HFPieceEncoderLoader.v1 = spacy_curated_transformers.tokenization:build_hf_piece_encoder_loader_v1
spacy-curated-transformers.PyTorchCheckpointLoader.v1 = spacy_curated_transformers.models:build_pytorch_checkpoint_loader_v1
+ spacy-curated-transformers.PyTorchCheckpointLoader.v2 = spacy_curated_transformers.models:build_pytorch_checkpoint_loader_v2
spacy-curated-transformers.SentencepieceLoader.v1 = spacy_curated_transformers.tokenization:build_sentencepiece_encoder_loader_v1
spacy-curated-transformers.WordpieceLoader.v1 = spacy_curated_transformers.tokenization:build_wordpiece_encoder_loader_v1
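The new PyTorchCheckpointLoader.v2 entry point sits alongside v1, so both loaders can be referenced by name from the model_loaders registry. A hedged sketch of the difference, based on the docstrings later in this diff; the paths and the pytorch_model.bin file name are illustrative assumptions, not taken from this PR:

from pathlib import Path

from spacy_curated_transformers.models import (
    build_pytorch_checkpoint_loader_v1,
    build_pytorch_checkpoint_loader_v2,
)

# v1 still takes the path of a single checkpoint file (illustrative file name).
load_v1 = build_pytorch_checkpoint_loader_v1(path=Path("checkpoint/pytorch_model.bin"))
# v2 takes the directory that contains the checkpoint.
load_v2 = build_pytorch_checkpoint_loader_v2(path=Path("checkpoint"))
# Both return callbacks with the signature (model, X=None, Y=None) -> model that
# spaCy/Thinc invoke during initialization to fill in the encoder weights.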

169 changes: 101 additions & 68 deletions spacy_curated_transformers/models/architectures.py
@@ -3,15 +3,21 @@
from typing import Any, Callable, List, Optional, Tuple, Union, cast

import torch
- from curated_transformers.models.albert import AlbertConfig, AlbertEncoder
- from curated_transformers.models.bert import BertConfig, BertEncoder
- from curated_transformers.models.curated_transformer import (
- CuratedEncoderT,
- CuratedTransformer,
+ from curated_transformers.layers import Activation, AttentionMask
+ from curated_transformers.models import (
+ ALBERTConfig,
+ ALBERTEncoder,
+ BERTConfig,
+ BERTEncoder,
+ CamemBERTEncoder,
+ EncoderModule,
+ FromHFHub,
+ ModelOutput,
+ RoBERTaConfig,
+ RoBERTaEncoder,
+ XLMREncoder,
)
- from curated_transformers.models.hf_util import convert_pretrained_model_for_encoder
- from curated_transformers.models.output import PyTorchTransformerOutput
- from curated_transformers.models.roberta import RobertaConfig, RobertaEncoder
+ from fsspec.implementations.local import LocalFileSystem
from spacy.tokens import Doc
from spacy.util import SimpleFrozenDict
from thinc.api import (
@@ -125,33 +131,34 @@ def build_albert_transformer_model_v1(
Optional listener to wrap. Only used when replacing listeners
in downstream components.
"""
- config = AlbertConfig(
+ config = ALBERTConfig(
embedding_width=embedding_width,
hidden_width=hidden_width,
intermediate_width=intermediate_width,
- num_attention_heads=num_attention_heads,
- num_hidden_groups=num_hidden_groups,
- num_hidden_layers=num_hidden_layers,
+ n_attention_heads=num_attention_heads,
+ n_hidden_groups=num_hidden_groups,
+ n_hidden_layers=num_hidden_layers,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
- hidden_act=hidden_act,
- vocab_size=vocab_size,
- type_vocab_size=type_vocab_size,
- max_position_embeddings=max_position_embeddings,
+ activation=Activation(hidden_act),
+ n_pieces=vocab_size,
+ n_types=type_vocab_size,
+ n_positions=max_position_embeddings,
model_max_length=model_max_length,
layer_norm_eps=layer_norm_eps,
- padding_idx=padding_idx,
)

if torchscript:
transformer = _torchscript_encoder(
model_max_length=model_max_length, padding_idx=padding_idx
)
else:
- encoder = AlbertEncoder(config)
+ encoder = ALBERTEncoder(config)
transformer = _pytorch_encoder(
encoder,
hidden_width=hidden_width,
+ model_max_length=model_max_length,
+ padding_idx=padding_idx,
mixed_precision=mixed_precision,
grad_scaler_config=grad_scaler_config,
)
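All five builders apply the same keyword renames when moving to the curated-transformers 2.x configs. A minimal, self-contained sketch of the mapping, using illustrative hyperparameter values that are not taken from this PR:

from curated_transformers.layers import Activation
from curated_transformers.models import ALBERTConfig, ALBERTEncoder

config = ALBERTConfig(
    embedding_width=128,
    hidden_width=768,
    intermediate_width=3072,
    n_attention_heads=12,  # was num_attention_heads
    n_hidden_groups=1,  # was num_hidden_groups
    n_hidden_layers=12,  # was num_hidden_layers
    attention_probs_dropout_prob=0.0,
    hidden_dropout_prob=0.0,
    activation=Activation("gelu"),  # was hidden_act="gelu"
    n_pieces=30000,  # was vocab_size
    n_types=2,  # was type_vocab_size
    n_positions=512,  # was max_position_embeddings
    model_max_length=512,
    layer_norm_eps=1e-12,
)
encoder = ALBERTEncoder(config)  # padding_idx is no longer a config argument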
@@ -233,32 +240,33 @@ def build_bert_transformer_model_v1(
Optional listener to wrap. Only used when replacing listeners
in downstream components.
"""
- config = BertConfig(
+ config = BERTConfig(
embedding_width=hidden_width,
hidden_width=hidden_width,
intermediate_width=intermediate_width,
- num_attention_heads=num_attention_heads,
- num_hidden_layers=num_hidden_layers,
+ n_attention_heads=num_attention_heads,
+ n_hidden_layers=num_hidden_layers,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
- hidden_act=hidden_act,
- vocab_size=vocab_size,
- type_vocab_size=type_vocab_size,
- max_position_embeddings=max_position_embeddings,
+ activation=Activation(hidden_act),
+ n_pieces=vocab_size,
+ n_types=type_vocab_size,
+ n_positions=max_position_embeddings,
model_max_length=model_max_length,
layer_norm_eps=layer_norm_eps,
- padding_idx=padding_idx,
)

if torchscript:
transformer = _torchscript_encoder(
model_max_length=model_max_length, padding_idx=padding_idx
)
else:
- encoder = BertEncoder(config)
+ encoder = BERTEncoder(config)
transformer = _pytorch_encoder(
encoder,
hidden_width=hidden_width,
+ model_max_length=model_max_length,
+ padding_idx=padding_idx,
mixed_precision=mixed_precision,
grad_scaler_config=grad_scaler_config,
)
@@ -340,32 +348,33 @@ def build_camembert_transformer_model_v1(
Optional listener to wrap. Only used when replacing listeners
in downstream components.
"""
- config = RobertaConfig(
+ config = RoBERTaConfig(
embedding_width=hidden_width,
hidden_width=hidden_width,
intermediate_width=intermediate_width,
- num_attention_heads=num_attention_heads,
- num_hidden_layers=num_hidden_layers,
+ n_attention_heads=num_attention_heads,
+ n_hidden_layers=num_hidden_layers,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
- hidden_act=hidden_act,
- vocab_size=vocab_size,
- type_vocab_size=type_vocab_size,
- max_position_embeddings=max_position_embeddings,
+ activation=Activation(hidden_act),
+ n_pieces=vocab_size,
+ n_types=type_vocab_size,
+ n_positions=max_position_embeddings,
model_max_length=model_max_length,
layer_norm_eps=layer_norm_eps,
- padding_idx=padding_idx,
)

if torchscript:
transformer = _torchscript_encoder(
model_max_length=model_max_length, padding_idx=padding_idx
)
else:
- encoder = RobertaEncoder(config)
+ encoder = CamemBERTEncoder(config)
transformer = _pytorch_encoder(
encoder,
hidden_width=hidden_width,
+ model_max_length=model_max_length,
+ padding_idx=padding_idx,
mixed_precision=mixed_precision,
grad_scaler_config=grad_scaler_config,
)
@@ -447,32 +456,33 @@ def build_roberta_transformer_model_v1(
Optional listener to wrap. Only used when replacing listeners
in downstream components.
"""
- config = RobertaConfig(
+ config = RoBERTaConfig(
embedding_width=hidden_width,
hidden_width=hidden_width,
intermediate_width=intermediate_width,
- num_attention_heads=num_attention_heads,
- num_hidden_layers=num_hidden_layers,
+ n_attention_heads=num_attention_heads,
+ n_hidden_layers=num_hidden_layers,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
- hidden_act=hidden_act,
- vocab_size=vocab_size,
- type_vocab_size=type_vocab_size,
- max_position_embeddings=max_position_embeddings,
+ activation=Activation(hidden_act),
+ n_pieces=vocab_size,
+ n_types=type_vocab_size,
+ n_positions=max_position_embeddings,
model_max_length=model_max_length,
layer_norm_eps=layer_norm_eps,
- padding_idx=padding_idx,
)

if torchscript:
transformer = _torchscript_encoder(
model_max_length=model_max_length, padding_idx=padding_idx
)
else:
- encoder = RobertaEncoder(config)
+ encoder = RoBERTaEncoder(config)
transformer = _pytorch_encoder(
encoder,
hidden_width=hidden_width,
+ model_max_length=model_max_length,
+ padding_idx=padding_idx,
mixed_precision=mixed_precision,
grad_scaler_config=grad_scaler_config,
)
@@ -554,32 +564,33 @@ def build_xlmr_transformer_model_v1(
Optional listener to wrap. Only used when replacing listeners
in downstream components.
"""
- config = RobertaConfig(
+ config = RoBERTaConfig(
embedding_width=hidden_width,
hidden_width=hidden_width,
intermediate_width=intermediate_width,
- num_attention_heads=num_attention_heads,
- num_hidden_layers=num_hidden_layers,
+ n_attention_heads=num_attention_heads,
+ n_hidden_layers=num_hidden_layers,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
- hidden_act=hidden_act,
- vocab_size=vocab_size,
- type_vocab_size=type_vocab_size,
- max_position_embeddings=max_position_embeddings,
+ activation=Activation(hidden_act),
+ n_pieces=vocab_size,
+ n_types=type_vocab_size,
+ n_positions=max_position_embeddings,
model_max_length=model_max_length,
layer_norm_eps=layer_norm_eps,
- padding_idx=padding_idx,
)

if torchscript:
transformer = _torchscript_encoder(
model_max_length=model_max_length, padding_idx=padding_idx
)
else:
- encoder = RobertaEncoder(config)
+ encoder = XLMREncoder(config)
transformer = _pytorch_encoder(
encoder,
hidden_width=hidden_width,
+ model_max_length=model_max_length,
+ padding_idx=padding_idx,
mixed_precision=mixed_precision,
grad_scaler_config=grad_scaler_config,
)
@@ -620,8 +631,6 @@ def build_transformer_model_v1(
transformer: TorchTransformerModelT,
piece_encoder: Tok2PiecesModelT,
) -> TransformerModelT:
- # FIXME: do we want to make `remove_bos_eos` configurable as well or
- # is it always the same post-processing?
layers = [
with_non_ws_tokens(
chain(piece_encoder, with_spans(transformer), remove_bos_eos())
@@ -668,8 +677,10 @@ def transformer_model_init(


def _pytorch_encoder(
- encoder: CuratedEncoderT,
+ encoder: EncoderModule,
hidden_width: int,
+ padding_idx: int,
+ model_max_length: int,
*,
mixed_precision: bool = False,
grad_scaler_config: dict = SimpleFrozenDict(),
@@ -682,11 +693,11 @@ def _pytorch_encoder(
grad_scaler_config["enabled"] = mixed_precision

model = PyTorchWrapper_v2(
- CuratedTransformer(encoder),
+ encoder,
convert_inputs=partial(
_convert_inputs,
- max_model_seq_len=encoder.max_seq_len,
- padding_idx=encoder.padding_idx,
+ max_model_seq_len=model_max_length,
+ padding_idx=padding_idx,
),
convert_outputs=_convert_outputs,
mixed_precision=mixed_precision,
@@ -735,18 +746,19 @@ def _convert_inputs(
span_len = span.shape[0]
Xt[i, :span_len] = span
Xt = xp2torch(Xt)
+ mask = AttentionMask(Xt.ne(padding_idx))

def convert_from_torch_backward(d_inputs: Any):
# No gradients for the inputs.
return [ops.alloc1f(x.shape[0]) for x in X]

- output = ArgsKwargs(args=(Xt,), kwargs={})
+ output = ArgsKwargs(args=(Xt, mask), kwargs={})
return output, convert_from_torch_backward


def _convert_outputs(
model: Model,
- inputs_outputs: Tuple[TorchTransformerInT, PyTorchTransformerOutput],
+ inputs_outputs: Tuple[TorchTransformerInT, ModelOutput],
is_train: bool,
) -> Tuple[TorchTransformerOutT, Callable[[List[List[Floats2d]]], ArgsKwargs]]:
model_inputs, model_outputs = inputs_outputs
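With curated-transformers 2.x the wrapped encoder receives an explicit attention mask instead of deriving padding from a padding_idx attribute on the model. A hedged sketch of the new calling convention; the tensor values are illustrative:

import torch
from curated_transformers.layers import AttentionMask

padding_idx = 1
piece_ids = torch.tensor([[0, 10, 11, 12, 2, padding_idx, padding_idx]])
# True for real pieces, False for padding, mirroring _convert_inputs above.
mask = AttentionMask(piece_ids.ne(padding_idx))
# The Thinc wrapper then calls encoder(piece_ids, mask); the resulting
# curated-transformers ModelOutput is converted back into Thinc arrays by
# _convert_outputs.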
@@ -790,19 +802,40 @@ def build_pytorch_checkpoint_loader_v1(*, path: Path) -> Callable[
TorchTransformerModelT,
]:
"""Construct a callback that initializes a supported transformer
- model with weights from a PyTorch checkpoint.
+ model with weights from a PyTorch or SafeTensors checkpoint.

path (Path):
Path to the PyTorch checkpoint.
"""
+ return build_pytorch_checkpoint_loader_v2(path=path.parent)


+ def build_pytorch_checkpoint_loader_v2(*, path: Path) -> Callable[
+ [TorchTransformerModelT, Optional[List[Doc]], Optional[List[Doc]]],
+ TorchTransformerModelT,
+ ]:
+ """Construct a callback that initializes a supported transformer
+ model with weights from a PyTorch or SafeTensors checkpoint.
+
+ path (Path):
+ Path to the directory containing the checkpoint.
+ """

def load(model, X=None, Y=None):
- encoder = model.shims[0]._model
device = get_torch_default_device()
- params = torch.load(path, map_location=device)
- params = convert_pretrained_model_for_encoder(encoder, params)
- encoder.load_state_dict(params)
- encoder.to(device)
+ encoder = model.shims[0]._model
+ assert isinstance(encoder, FromHFHub)
+ from_fsspec = type(encoder).from_fsspec
+
+ # We can discard the previously initialized model entirely
+ # and use the Curated Transformers API to load it from the
+ # hub.
+ model.shims[0]._model = None
+ del encoder
+
+ fs = LocalFileSystem()
+ encoder = from_fsspec(fs=fs, model_path=path, device=device)
+ model.shims[0]._model = encoder
return model

return load
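The v2 loader hands checkpoint reading entirely to Curated Transformers through from_fsspec. A rough, hedged equivalent of what the callback does, against a hypothetical local directory checkpoint/ containing a Hugging Face-style config and weight file (directory name and contents are assumptions, not taken from this PR):

from pathlib import Path

import torch
from curated_transformers.models import BERTEncoder
from fsspec.implementations.local import LocalFileSystem

device = torch.device("cpu")  # the callback uses the default Torch device instead
encoder = BERTEncoder.from_fsspec(
    fs=LocalFileSystem(),
    model_path=Path("checkpoint"),
    device=device,
)

The v1 loader keeps its original signature and simply forwards path.parent to v2, so existing configurations that point at a checkpoint file continue to work.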