Bring back curated_encoder prefix #34

Merged · 3 commits · Apr 12, 2024

spacy_curated_transformers/models/architectures.py (16 additions, 2 deletions)

@@ -3,6 +3,7 @@
 from typing import Any, Callable, List, Optional, Tuple, Union, cast

 import torch
+import torch.nn as nn
 from curated_transformers.layers import Activation, AttentionMask
 from curated_transformers.models import (
     ALBERTConfig,

@@ -1181,7 +1182,7 @@ def _pytorch_encoder(
     grad_scaler_config["enabled"] = mixed_precision

     model = PyTorchWrapper_v2(
-        encoder,
+        _EncoderWrapper(encoder),
         convert_inputs=partial(
             _convert_inputs,
             max_model_seq_len=model_max_length,

@@ -1311,7 +1312,7 @@ def build_pytorch_checkpoint_loader_v2(*, path: Path) -> Callable[

     def load(model, X=None, Y=None):
         device = get_torch_default_device()
-        encoder = model.shims[0]._model
+        encoder = model.shims[0]._model.curated_encoder
         assert isinstance(encoder, FromHFHub)
         fs = LocalFileSystem()
         encoder.from_fsspec_(fs=fs, model_path=path, device=device)

@@ -1325,3 +1326,16 @@ def _torch_dtype_from_str(dtype_as_str: str):
     if not isinstance(dtype, torch.dtype):
         raise ValueError(f"Invalid torch dtype `{dtype_as_str}`")
     return dtype
+
+
+class _EncoderWrapper(nn.Module):
+    """Small wrapper to add a prefix that can be used by e.g. learning rate
+    schedules.
+    """
+
+    def __init__(self, encoder: nn.Module):
+        super().__init__()
+        self.curated_encoder = encoder
+
+    def forward(self, *args, **kwargs):
+        return self.curated_encoder.forward(*args, **kwargs)
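
For context, a minimal sketch (not part of this PR) of how a consumer could use the restored curated_encoder. prefix, for example to give the pretrained encoder a smaller learning rate than the rest of a model. The optimizer and the learning rates below are hypothetical; the only assumption carried over from this change is that every encoder parameter name starts with curated_encoder.

import torch


def make_optimizer(module: torch.nn.Module) -> torch.optim.AdamW:
    # Split parameters on the prefix introduced by _EncoderWrapper.
    encoder_params, other_params = [], []
    for name, param in module.named_parameters():
        if name.startswith("curated_encoder."):
            encoder_params.append(param)
        else:
            other_params.append(param)
    # Hypothetical per-group learning rates: smaller for the pretrained encoder.
    return torch.optim.AdamW(
        [
            {"params": encoder_params, "lr": 1e-5},
            {"params": other_params, "lr": 1e-3},
        ]
    )

Because _EncoderWrapper delegates forward unchanged, the wrapper is invisible at runtime; only code that addresses parameters or submodules by name sees the extra curated_encoder. level, which is what the loader changes below account for.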

spacy_curated_transformers/models/hf_loader.py (1 addition, 1 deletion)

@@ -24,7 +24,7 @@ def build_hf_transformer_encoder_loader_v1(
     """

     def load(model, X=None, Y=None):
-        encoder = model.shims[0]._model
+        encoder = model.shims[0]._model.curated_encoder
         assert isinstance(encoder, FromHFHub)
         device = model.shims[0].device
         encoder.from_hf_hub_(name=name, revision=revision, device=device)

spacy_curated_transformers/tests/conftest.py (0 additions, 1 deletion)

@@ -40,7 +40,6 @@ def getopt(opt):

 @pytest.fixture
 def test_dir(request):
-    print(request.fspath)
     return Path(request.fspath).parent



spacy_curated_transformers/tests/models/test_transformer_model.py (41 additions, 0 deletions)

@@ -178,3 +178,44 @@ def test_pytorch_checkpoint_loader(test_config):
         "spacy-curated-transformers.PyTorchCheckpointLoader.v1"
     )(path=Path(checkpoint_path))
     model.initialize()


+@pytest.mark.parametrize(
+    "test_config",
+    [
+        (
+            build_albert_transformer_model_v1,
+            build_sentencepiece_encoder_v1(),
+            1000,
+        ),
+        (
+            build_bert_transformer_model_v1,
+            build_bert_wordpiece_encoder_v1(),
+            1000,
+        ),
+        (
+            build_roberta_transformer_model_v1,
+            build_byte_bpe_encoder_v1(),
+            1000,
+        ),
+    ],
+)
+def test_encoder_prefix(test_config):
+    model_factory, piece_encoder, vocab_size = test_config
+
+    # Curated Transformers needs the config to get the model hyperparameters.
+    with_spans = build_with_strided_spans_v1(stride=96, window=128)
+    model = model_factory(
+        hidden_width=32,
+        intermediate_width=37,
+        num_hidden_layers=4,
+        num_attention_heads=4,
+        piece_encoder=piece_encoder,
+        vocab_size=vocab_size,
+        with_spans=with_spans,
+    )
+
+    for name, _ in model.get_ref("transformer").shims[0]._model.named_parameters():
+        assert name.startswith(
+            "curated_encoder."
+        ), f"Parameter name '{name}' does not start with 'curated_encoder'"