diff --git a/spacy_curated_transformers/models/architectures.py b/spacy_curated_transformers/models/architectures.py
index 649e7e9..2345991 100644
--- a/spacy_curated_transformers/models/architectures.py
+++ b/spacy_curated_transformers/models/architectures.py
@@ -3,6 +3,7 @@
 from typing import Any, Callable, List, Optional, Tuple, Union, cast
 
 import torch
+import torch.nn as nn
 from curated_transformers.layers import Activation, AttentionMask
 from curated_transformers.models import (
     ALBERTConfig,
@@ -1181,7 +1182,7 @@ def _pytorch_encoder(
         grad_scaler_config["enabled"] = mixed_precision
 
     model = PyTorchWrapper_v2(
-        encoder,
+        _EncoderWrapper(encoder),
         convert_inputs=partial(
             _convert_inputs,
             max_model_seq_len=model_max_length,
@@ -1311,7 +1312,7 @@ def build_pytorch_checkpoint_loader_v2(*, path: Path) -> Callable[
 
     def load(model, X=None, Y=None):
         device = get_torch_default_device()
-        encoder = model.shims[0]._model
+        encoder = model.shims[0]._model.curated_encoder
         assert isinstance(encoder, FromHFHub)
         fs = LocalFileSystem()
         encoder.from_fsspec_(fs=fs, model_path=path, device=device)
@@ -1325,3 +1326,16 @@ def _torch_dtype_from_str(dtype_as_str: str):
     if not isinstance(dtype, torch.dtype):
         raise ValueError(f"Invalid torch dtype `{dtype_as_str}`")
     return dtype
+
+
+class _EncoderWrapper(nn.Module):
+    """Small wrapper to add a prefix that can be used by e.g. learning rate
+    schedules.
+    """
+
+    def __init__(self, encoder: nn.Module):
+        super().__init__()
+        self.curated_encoder = encoder
+
+    def forward(self, *args, **kwargs):
+        return self.curated_encoder.forward(*args, **kwargs)
diff --git a/spacy_curated_transformers/models/hf_loader.py b/spacy_curated_transformers/models/hf_loader.py
index 3f58d2e..0748c15 100644
--- a/spacy_curated_transformers/models/hf_loader.py
+++ b/spacy_curated_transformers/models/hf_loader.py
@@ -24,7 +24,7 @@ def build_hf_transformer_encoder_loader_v1(
     """
 
     def load(model, X=None, Y=None):
-        encoder = model.shims[0]._model
+        encoder = model.shims[0]._model.curated_encoder
         assert isinstance(encoder, FromHFHub)
         device = model.shims[0].device
         encoder.from_hf_hub_(name=name, revision=revision, device=device)
diff --git a/spacy_curated_transformers/tests/conftest.py b/spacy_curated_transformers/tests/conftest.py
index 0e698c1..8084b41 100644
--- a/spacy_curated_transformers/tests/conftest.py
+++ b/spacy_curated_transformers/tests/conftest.py
@@ -40,7 +40,6 @@ def getopt(opt):
 
 @pytest.fixture
 def test_dir(request):
-    print(request.fspath)
     return Path(request.fspath).parent
 
 
diff --git a/spacy_curated_transformers/tests/models/test_transformer_model.py b/spacy_curated_transformers/tests/models/test_transformer_model.py
index 9d504b5..d77043a 100644
--- a/spacy_curated_transformers/tests/models/test_transformer_model.py
+++ b/spacy_curated_transformers/tests/models/test_transformer_model.py
@@ -178,3 +178,44 @@ def test_pytorch_checkpoint_loader(test_config):
         "spacy-curated-transformers.PyTorchCheckpointLoader.v1"
     )(path=Path(checkpoint_path))
     model.initialize()
+
+
+@pytest.mark.parametrize(
+    "test_config",
+    [
+        (
+            build_albert_transformer_model_v1,
+            build_sentencepiece_encoder_v1(),
+            1000,
+        ),
+        (
+            build_bert_transformer_model_v1,
+            build_bert_wordpiece_encoder_v1(),
+            1000,
+        ),
+        (
+            build_roberta_transformer_model_v1,
+            build_byte_bpe_encoder_v1(),
+            1000,
+        ),
+    ],
+)
+def test_encoder_prefix(test_config):
+    model_factory, piece_encoder, vocab_size = test_config
+
+    # Curated Transformers needs the config to get the model hyperparameters.
+    with_spans = build_with_strided_spans_v1(stride=96, window=128)
+    model = model_factory(
+        hidden_width=32,
+        intermediate_width=37,
+        num_hidden_layers=4,
+        num_attention_heads=4,
+        piece_encoder=piece_encoder,
+        vocab_size=vocab_size,
+        with_spans=with_spans,
+    )
+
+    for name, _ in model.get_ref("transformer").shims[0]._model.named_parameters():
+        assert name.startswith(
+            "curated_encoder."
+        ), f"Parameter name '{name}' does not start with 'curated_encoder'"
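
Note on the curated_encoder. prefix (illustrative, not part of the changeset above): giving every encoder parameter a stable name prefix lets downstream code select the transformer body purely by parameter name, for example to apply a smaller learning rate to the pretrained encoder than to freshly initialized layers. The sketch below shows that idea with a toy stand-in module and a plain torch.optim optimizer; the module, parameter names, and learning rates are assumptions for illustration only.

# Illustrative sketch only: a toy module whose submodule is named
# "curated_encoder", mirroring the prefix that _EncoderWrapper introduces,
# plus prefix-based optimizer parameter groups. All names and values here
# are stand-ins, not part of spacy-curated-transformers.
import torch
import torch.nn as nn


class ToyPipeline(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Stand-in for the wrapped transformer encoder; all of its parameters
        # get the "curated_encoder." name prefix.
        self.curated_encoder = nn.Sequential(
            nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32)
        )
        # Stand-in for task-specific layers outside the encoder.
        self.head = nn.Linear(32, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.curated_encoder(x))


model = ToyPipeline()

# Select parameters by prefix so the encoder can get its own learning rate.
encoder_params = [
    p for n, p in model.named_parameters() if n.startswith("curated_encoder.")
]
other_params = [
    p for n, p in model.named_parameters() if not n.startswith("curated_encoder.")
]

optimizer = torch.optim.AdamW(
    [
        {"params": encoder_params, "lr": 2e-5},
        {"params": other_params, "lr": 1e-3},
    ]
)

In a spaCy pipeline the prefix would be consumed by whatever schedule or optimizer inspects parameter names; the torch.optim grouping above is just the simplest self-contained way to show the name-based selection.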