Add the transformer_discriminative schedule (#27)
* Add the `transformer_discriminative` schedule

This schedule wraps two schedules:

1. `transformer_schedule` is used for transformer parameters.
2. `default_schedule` is used for all other parameters.

This differentiation allows you to implement good transformer fine-tuning
practices (see the usage sketch after this list), such as:

- Initially freezing the transformer while the freshly initialized
  classifier parameters settle, to avoid catastrophic forgetting in the
  transformer due to the initially large gradients.
- Using smaller learning rates for the transformer than for the
  classification layers.
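
A minimal usage sketch (mirroring the new test in this commit): both wrapped
schedules are plain Thinc `constant` schedules here, and the parameter name
`linear.W` is just an illustrative non-transformer key.

```python
from thinc.api import constant

from spacy_curated_transformers.schedules import transformer_discriminative

# Larger learning rate for the freshly initialized task layers, smaller
# learning rate for the pretrained transformer.
default_schedule = constant(1e-3)
transformer_schedule = constant(1e-5)
schedule = transformer_discriminative(default_schedule, transformer_schedule)

# Parameter keys are (int, str) tuples; names containing
# "curated_encoder." are routed to the transformer schedule.
assert schedule(0, key=(0, "linear.W")) == 1e-3
assert schedule(0, key=(1, "curated_encoder.embeddings")) == 1e-5
```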

* CI: Update Python upper and lower bounds

spaCy 4 requires Python 3.8, and 4.0.0.dev1 has a binary wheel for Python 3.11.

* Update Thinc and spaCy dependencies

* black

* Bump mypy dependency to work with Torch type annotations

* Fix docstring typo

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Test with `curated_encoder.` as infix

---------

Co-authored-by: Sofie Van Landeghem <[email protected]>
danieldk and svlandeg authored Feb 4, 2024
1 parent e5baf57 commit 8bda5c6
Showing 10 changed files with 85 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -39,7 +39,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.7", "3.10"]
python-version: ["3.8", "3.11"]
include:
- os: ubuntu-latest
hf-path: ~/.cache/huggingface/hub
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,10 +1,10 @@
curated-transformers>=0.1.0,<0.2.0
curated-tokenizers>=0.0.8,<0.1.0
spacy>=3.7.0.dev0,<4.0.0
thinc>=8.1.6,<9.1.0
spacy>=4.0.0.dev2,<5.0.0
thinc>=9.0.0.dev4,<9.1.0
srsly
torch>=1.12.0
mypy>=0.990,<0.1000; platform_machine != "aarch64" and python_version >= "3.7"
mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.7"

# Development dependencies
pytest
6 changes: 5 additions & 1 deletion setup.cfg
@@ -16,7 +16,8 @@ python_requires = >=3.6
install_requires =
    curated-transformers>=0.1.0,<0.2.0
    curated-tokenizers>=0.0.7,<0.1.0
    spacy>=3.7.0.dev0,<4.0.0
    spacy>=4.0.0.dev2,<5.0.0
    thinc>=9.0.0.dev4,<9.1.0
    torch>=1.12.0

[options.entry_points]
@@ -58,6 +59,9 @@ thinc_model_loaders =
    spacy-curated-transformers.SentencepieceLoader.v1 = spacy_curated_transformers.tokenization:build_sentencepiece_encoder_loader_v1
    spacy-curated-transformers.WordpieceLoader.v1 = spacy_curated_transformers.tokenization:build_wordpiece_encoder_loader_v1

thinc_schedules =
    spacy-curated-transformers.transformer_discriminative.v1 = spacy_curated_transformers.schedules:transformer_discriminative

[bdist_wheel]
universal = true

4 changes: 1 addition & 3 deletions spacy_curated_transformers/models/architectures.py
@@ -785,9 +785,7 @@ def convert_for_torch_backward(dY: List[List[Floats2d]]):
    return output, convert_for_torch_backward


def build_pytorch_checkpoint_loader_v1(
    *, path: Path
) -> Callable[
def build_pytorch_checkpoint_loader_v1(*, path: Path) -> Callable[
    [TorchTransformerModelT, Optional[List[Doc]], Optional[List[Doc]]],
    TorchTransformerModelT,
]:
6 changes: 3 additions & 3 deletions spacy_curated_transformers/models/listeners.py
@@ -536,9 +536,9 @@ def __init__(
        )

        # Ensure that the transformer returns the required outputs.
        transformer.attrs[
            "_all_layer_outputs"
        ] = ListenerStateUtils.requires_all_layer_outputs(listener)
        transformer.attrs["_all_layer_outputs"] = (
            ListenerStateUtils.requires_all_layer_outputs(listener)
        )

        # Freeze the embedded transformer if the source pipe was frozen.
        transformer.attrs["_frozen"] = frozen
47 changes: 47 additions & 0 deletions spacy_curated_transformers/schedules.py
@@ -0,0 +1,47 @@
from typing import Optional, Tuple

from thinc.api import Schedule

# This is the parameter prefix for curated encoders.
_CURATED_ENCODER_PREFIX = "curated_encoder."


def transformer_discriminative(
    default_schedule: Schedule,
    transformer_schedule: Schedule,
) -> Schedule:
    """Discriminative learning rate schedule for transformer encoders.
    This schedule uses `transformer_schedule` for all transformer encoder
    parameters and `default_schedule` for other parameters.

    default_schedule (Schedule): default schedule.
    transformer_schedule (Schedule): schedule for transformer parameters.
    """
    return Schedule(
        "transformer",
        _transformer_discriminative_schedule,
        attrs={
            "default_schedule": default_schedule,
            "transformer_schedule": transformer_schedule,
        },
    )


def _transformer_discriminative_schedule(
    schedule: Schedule, step: int, *, key: Optional[Tuple[int, str]] = None, **kwargs
) -> float:
    default_schedule: Schedule = schedule.attrs["default_schedule"]
    transformer_schedule: Schedule = schedule.attrs["transformer_schedule"]

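    # No parameter key was provided, so fall back to the default schedule.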
    if key is None:
        return default_schedule(step=step, key=key, **kwargs)

    key_str = key[1]
    # We don't do a strict prefix check, since we want to support
    # an encoder wrapped into another model as well. In the latter
    # case, the prefix becomes an infix.
    if _CURATED_ENCODER_PREFIX in key_str:
        return transformer_schedule(step=step, key=key, **kwargs)

    return default_schedule(step=step, key=key, **kwargs)
11 changes: 11 additions & 0 deletions spacy_curated_transformers/tests/test_registry.py
@@ -57,6 +57,17 @@ def test_model_loaders_from_registry(loader_name):
    registry.model_loaders.get(loader_name)


@pytest.mark.parametrize(
    "schedule_name",
    [
        "spacy-curated-transformers.transformer_discriminative.v1",
    ],
)
def test_schedule_from_registry(schedule_name):
    # Can't be constructed, since most schedules have mandatory arguments.
    registry.schedules.get(schedule_name)


@pytest.mark.parametrize(
    "callback_name",
    [
12 changes: 12 additions & 0 deletions spacy_curated_transformers/tests/test_schedules.py
@@ -0,0 +1,12 @@
from thinc.api import constant

from spacy_curated_transformers.schedules import transformer_discriminative


def test_schedules():
    default_schedule = constant(1e-3)
    transformer_schedule = constant(1e-5)
    schedule = transformer_discriminative(default_schedule, transformer_schedule)
    assert schedule(0, key=(0, "some_key")) == 1e-3
    assert schedule(0, key=(1, "curated_encoder.embeddings")) == 1e-5
    assert schedule(0, key=(2, "wrapping_model.curated_encoder.embeddings")) == 1e-5
4 changes: 1 addition & 3 deletions spacy_curated_transformers/tokenization/hf_loader.py
@@ -19,9 +19,7 @@
    SUPPORTED_TOKENIZERS = ()  # type: ignore


def build_hf_piece_encoder_loader_v1(
    *, name: str, revision: str = "main"
) -> Callable[
def build_hf_piece_encoder_loader_v1(*, name: str, revision: str = "main") -> Callable[
    [Tok2PiecesModelT, Optional[Tok2PiecesInT], Optional[Tok2PiecesInT]],
    Tok2PiecesModelT,
]:
@@ -97,9 +97,7 @@ def sentencepiece_encoder_forward(
    return pieces, lambda dY: []


def build_sentencepiece_encoder_loader_v1(
    *, path: Path
) -> Callable[
def build_sentencepiece_encoder_loader_v1(*, path: Path) -> Callable[
    [Tok2PiecesModelT, Optional[Tok2PiecesInT], Optional[Tok2PiecesInT]],
    Tok2PiecesModelT,
]:
