Commit

Bring back curated_encoder prefix
Curated Transformers for spaCy 3.x used to have a `curated_encoder`
prefix that we used in e.g. the discriminative learning rate schedule.
Curated Transformers has not used such a prefix since 1.0. Add a small
wrapper to bring back the prefix, so that we can distinguish transformer
parameters from other parameters.
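
As an illustrative aside (not part of this commit), the prefix is what a discriminative learning-rate setup can key on to give transformer parameters a different rate from the rest of the pipeline. A minimal sketch in plain PyTorch follows; the grouping helper and the learning-rate values are hypothetical:

import torch

def param_groups(module: torch.nn.Module, encoder_lr: float = 2e-5, other_lr: float = 1e-3):
    # Parameters under the `curated_encoder.` prefix get the (lower) encoder rate.
    encoder, other = [], []
    for name, param in module.named_parameters():
        (encoder if name.startswith("curated_encoder.") else other).append(param)
    return [{"params": encoder, "lr": encoder_lr}, {"params": other, "lr": other_lr}]

# optimizer = torch.optim.AdamW(param_groups(wrapped_model))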
danieldk committed Apr 12, 2024
1 parent 89e35f6 commit d4709ec
Showing 4 changed files with 51 additions and 4 deletions.
17 changes: 15 additions & 2 deletions spacy_curated_transformers/models/architectures.py
@@ -3,6 +3,7 @@
 from typing import Any, Callable, List, Optional, Tuple, Union, cast

 import torch
+import torch.nn as nn
 from curated_transformers.layers import Activation, AttentionMask
 from curated_transformers.models import (
     ALBERTConfig,
@@ -1181,7 +1182,7 @@ def _pytorch_encoder(
     grad_scaler_config["enabled"] = mixed_precision

     model = PyTorchWrapper_v2(
-        encoder,
+        _EncoderWrapper(encoder),
         convert_inputs=partial(
             _convert_inputs,
             max_model_seq_len=model_max_length,
@@ -1311,7 +1312,7 @@ def build_pytorch_checkpoint_loader_v2(*, path: Path) -> Callable[

     def load(model, X=None, Y=None):
         device = get_torch_default_device()
-        encoder = model.shims[0]._model
+        encoder = model.shims[0]._model.curated_encoder
         assert isinstance(encoder, FromHFHub)
         fs = LocalFileSystem()
         encoder.from_fsspec_(fs=fs, model_path=path, device=device)
@@ -1325,3 +1326,15 @@ def _torch_dtype_from_str(dtype_as_str: str):
     if not isinstance(dtype, torch.dtype):
         raise ValueError(f"Invalid torch dtype `{dtype_as_str}`")
     return dtype
+
+class _EncoderWrapper(nn.Module):
+    """Small wrapper to add a prefix that can be used by e.g. learning rate
+    schedules.
+    """
+
+    def __init__(self, encoder: nn.Module):
+        super().__init__()
+        self.curated_encoder = encoder
+
+    def forward(self, *args, **kwargs):
+        return self.curated_encoder.forward(*args, **kwargs)
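
For illustration only (not part of the diff): wrapping an encoder this way moves every parameter name under the `curated_encoder.` prefix, which is what the loaders and the test below rely on. A minimal sketch with a stand-in module; the toy nn.Linear is hypothetical:

import torch.nn as nn

# Illustrative stand-in for a transformer encoder (not a real Curated Transformers model).
toy_encoder = nn.Linear(4, 4)
wrapped = _EncoderWrapper(toy_encoder)

# All parameters now carry the prefix.
print(sorted(name for name, _ in wrapped.named_parameters()))
# ['curated_encoder.bias', 'curated_encoder.weight']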
2 changes: 1 addition & 1 deletion spacy_curated_transformers/models/hf_loader.py
@@ -24,7 +24,7 @@ def build_hf_transformer_encoder_loader_v1(
     """

     def load(model, X=None, Y=None):
-        encoder = model.shims[0]._model
+        encoder = model.shims[0]._model.curated_encoder
         assert isinstance(encoder, FromHFHub)
         device = model.shims[0].device
         encoder.from_hf_hub_(name=name, revision=revision, device=device)
1 change: 0 additions & 1 deletion spacy_curated_transformers/tests/conftest.py
@@ -40,7 +40,6 @@ def getopt(opt):

 @pytest.fixture
 def test_dir(request):
-    print(request.fspath)
     return Path(request.fspath).parent
35 changes: 35 additions & 0 deletions spacy_curated_transformers/tests/models/test_transformer_model.py
@@ -178,3 +178,38 @@ def test_pytorch_checkpoint_loader(test_config):
         "spacy-curated-transformers.PyTorchCheckpointLoader.v1"
     )(path=Path(checkpoint_path))
     model.initialize()
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(not has_huggingface_hub, reason="requires huggingface hub")
+@pytest.mark.parametrize(
+    "test_config",
+    [
+        (
+            build_albert_transformer_model_v1,
+            build_sentencepiece_encoder_v1(),
+            1000,
+        ),
+        (
+            build_bert_transformer_model_v1,
+            build_bert_wordpiece_encoder_v1(),
+            1000,
+        ),
+        (
+            build_roberta_transformer_model_v1,
+            build_byte_bpe_encoder_v1(),
+            1000,
+        ),
+    ],
+)
+def test_encoder_prefix(test_config):
+    model_factory, piece_encoder, vocab_size = test_config
+
+    # Curated Transformers needs the config to get the model hyperparameters.
+    with_spans = build_with_strided_spans_v1(stride=96, window=128)
+    model = model_factory(
+        piece_encoder=piece_encoder, vocab_size=vocab_size, with_spans=with_spans
+    )
+
+    for name, _ in model.get_ref("transformer").shims[0]._model.named_parameters():
+        assert name.startswith("curated_encoder."), f"Parameter name '{name}' does not start with 'curated_encoder'"
