Load models in-place
Before this change, we would first construct a model using the
arguments passed to the registry function. Then we would construct it
again using `from_hf_hub`. This was not only a performance issue, but
also a correctness issue: the model constructed through `from_hf_hub`
could have different hyperparameters than those specified in the
arguments to the registry function.

This change fixes this by using the new in-place loading support in
Curated Transformers 2.0.
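
A rough before/after sketch of the loading flow. All names here
(`build_encoder`, `Encoder`, `load_params_in_place`) are hypothetical
stand-ins for the registry function internals, not the actual Curated
Transformers API:

# Before: the model was constructed twice.
model = build_encoder(vocab_size=250002, hidden_width=768)  # from registry arguments
model = Encoder.from_hf_hub(name="xlm-roberta-base")  # rebuilt from the HF config;
                                                      # hyperparameters may silently
                                                      # differ from the arguments above

# After: construct once, then copy checkpoint parameters into the existing model.
model = build_encoder(vocab_size=250002, hidden_width=768)
load_params_in_place(model, name="xlm-roberta-base")  # hypothetical in-place loader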

The addition of in-place loading also brings a `dtype` argument to the
model configuration. We now expose this argument in the v2 versions of
the registry functions. Configuration filling has also been updated to
fill the data type from the `torch_dtype` option in the HF model
configuration.
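
For example, a pipeline config can now select a lower-precision data
type (an illustrative snippet mirroring the project configs below;
assuming the data type is given as a string such as "bfloat16"):

[components.transformer.model]
@architectures = "spacy-curated-transformers.XlmrTransformer.v2"
vocab_size = 250002
dtype = "bfloat16"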
danieldk committed Apr 10, 2024
1 parent 4d8ccb4 commit fa27ea6
Showing 13 changed files with 547 additions and 48 deletions.
2 changes: 1 addition & 1 deletion project/configs/layer-weighting.cfg
@@ -119,7 +119,7 @@ factory = "curated_transformer"
 all_layer_outputs = True
 
 [components.transformer.model]
-@architectures = "spacy-curated-transformers.XlmrTransformer.v1"
+@architectures = "spacy-curated-transformers.XlmrTransformer.v2"
 vocab_size = 250002
 num_hidden_layers = 12
 hidden_width = 768
2 changes: 1 addition & 1 deletion project/configs/no-layer-weighting.cfg
@@ -98,7 +98,7 @@ upstream = "*"
 factory = "curated_transformer"
 
 [components.transformer.model]
-@architectures = "spacy-curated-transformers.XlmrTransformer.v1"
+@architectures = "spacy-curated-transformers.XlmrTransformer.v2"
 vocab_size = 250002
 piece_encoder = {"@architectures": "spacy-curated-transformers.XlmrSentencepieceEncoder.v1"}
 
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-curated-transformers>=2.0.0.dev1,<3.0.0
+curated-transformers>=2.0.0.dev2,<3.0.0
 curated-tokenizers>=0.9.2,<1.0.0
 fsspec>=2023.5.0
 spacy>=4.0.0.dev2,<5.0.0
7 changes: 6 additions & 1 deletion setup.cfg
@@ -14,7 +14,7 @@ zip_safe = true
 include_package_data = true
 python_requires = >=3.8
 install_requires =
-    curated-transformers>=2.0.0.dev1,<3.0.0
+    curated-transformers>=2.0.0.dev2,<3.0.0
     curated-tokenizers>=0.9.2,<1.0.0
     fsspec>=2023.5.0
     spacy>=4.0.0.dev2,<5.0.0
@@ -27,10 +27,15 @@ spacy_factories =
 
 spacy_architectures =
     spacy-curated-transformers.AlbertTransformer.v1 = spacy_curated_transformers.models:build_albert_transformer_model_v1
+    spacy-curated-transformers.AlbertTransformer.v2 = spacy_curated_transformers.models:build_albert_transformer_model_v2
     spacy-curated-transformers.BertTransformer.v1 = spacy_curated_transformers.models:build_bert_transformer_model_v1
+    spacy-curated-transformers.BertTransformer.v2 = spacy_curated_transformers.models:build_bert_transformer_model_v2
     spacy-curated-transformers.CamembertTransformer.v1 = spacy_curated_transformers.models:build_camembert_transformer_model_v1
+    spacy-curated-transformers.CamembertTransformer.v2 = spacy_curated_transformers.models:build_camembert_transformer_model_v2
     spacy-curated-transformers.RobertaTransformer.v1 = spacy_curated_transformers.models:build_roberta_transformer_model_v1
+    spacy-curated-transformers.RobertaTransformer.v2 = spacy_curated_transformers.models:build_roberta_transformer_model_v2
     spacy-curated-transformers.XlmrTransformer.v1 = spacy_curated_transformers.models:build_xlmr_transformer_model_v1
+    spacy-curated-transformers.XlmrTransformer.v2 = spacy_curated_transformers.models:build_xlmr_transformer_model_v2
     spacy-curated-transformers.WithStridedSpans.v1 = spacy_curated_transformers.models:build_with_strided_spans_v1
     spacy-curated-transformers.ScalarWeight.v1 = spacy_curated_transformers.models:build_scalar_weight_v1
     spacy-curated-transformers.TransformerLayersListener.v1 = spacy_curated_transformers.models.listeners:build_transformer_layers_listener_v1
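
As a quick sanity check, the v2 entry points registered above should be
resolvable through spaCy's registry once the package is installed (a
minimal sketch, assuming entry points are picked up as usual):

from spacy import registry

# Resolved via the `spacy_architectures` entry-point group declared in setup.cfg.
build_xlmr_v2 = registry.architectures.get(
    "spacy-curated-transformers.XlmrTransformer.v2"
)
print(build_xlmr_v2.__name__)  # build_xlmr_transformer_model_v2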
14 changes: 9 additions & 5 deletions spacy_curated_transformers/cli/fill_config_transformer.py
@@ -67,12 +67,14 @@ def init_fill_curated_transformer_cli(
 
 class HfParamSource(Enum):
     MODEL_CONFIG = 1
-    TOKENIZER_CONFIG = 2
+    MODEL_CONFIG_OPTIONAL = 2
+    TOKENIZER_CONFIG = 3
 
 
 # Entrypoint parameters that are common to all curated transformer models.
 COMMON_ENTRYPOINT_PARAMS: Dict[str, HfParamSource] = {
     "attention_probs_dropout_prob": HfParamSource.MODEL_CONFIG,
+    "dtype": HfParamSource.MODEL_CONFIG_OPTIONAL,
     "hidden_act": HfParamSource.MODEL_CONFIG,
     "hidden_dropout_prob": HfParamSource.MODEL_CONFIG,
     "hidden_width": HfParamSource.MODEL_CONFIG,
@@ -99,6 +101,7 @@ class HfParamSource(Enum):
     "intermediate_width": "intermediate_size",
     "padding_idx": "pad_token_id",
     "embedding_width": "embedding_size",
+    "dtype": "torch_dtype",
 }
 
@@ -328,9 +331,9 @@ def _fill_parameters(
     filled_params = {}
     for param_name, source in params_to_fill.items():
         hf_key = ENTRYPOINT_PARAMS_TO_HF_CONFIG_KEYS.get(param_name, param_name)
-        if source == HfParamSource.MODEL_CONFIG:
+        if source in (HfParamSource.MODEL_CONFIG, HfParamSource.MODEL_CONFIG_OPTIONAL):
             value = hf_config.get(hf_key)
-            if value is None:
+            if value is None and source == HfParamSource.MODEL_CONFIG:
                 msg.fail(
                     f"Hugging Face model config has a missing key '{hf_key}'", exits=1
                 )
@@ -341,8 +344,9 @@ def _fill_parameters(
                     f"Hugging Face tokenizer config has a missing key '{hf_key}'",
                     exits=1,
                 )
-        assert value is not None
-        filled_params[param_name] = value
+        assert value is not None or source == HfParamSource.MODEL_CONFIG_OPTIONAL
+        if value is not None:
+            filled_params[param_name] = value
 
     msg.info(title="Filled-in model parameters:")
     msg.table(filled_params)
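
To make the required-vs-optional behavior above concrete, here is a
minimal standalone sketch of the fill logic (simplified: the
tokenizer-config branch and CLI error reporting are omitted, and only
the key mappings shown in the diff are included):

from enum import Enum
from typing import Any, Dict

class HfParamSource(Enum):
    MODEL_CONFIG = 1
    MODEL_CONFIG_OPTIONAL = 2
    TOKENIZER_CONFIG = 3

# Subset of ENTRYPOINT_PARAMS_TO_HF_CONFIG_KEYS from the diff above.
PARAM_TO_HF_KEY = {"dtype": "torch_dtype", "padding_idx": "pad_token_id"}

def fill_parameters(
    params_to_fill: Dict[str, HfParamSource], hf_config: Dict[str, Any]
) -> Dict[str, Any]:
    filled: Dict[str, Any] = {}
    for param_name, source in params_to_fill.items():
        hf_key = PARAM_TO_HF_KEY.get(param_name, param_name)
        value = hf_config.get(hf_key)
        if value is None:
            if source == HfParamSource.MODEL_CONFIG:
                # Required parameters must be present in the HF model config.
                raise ValueError(f"Hugging Face model config has a missing key '{hf_key}'")
            continue  # Optional parameter: leave unfilled instead of failing.
        filled[param_name] = value
    return filled

# `dtype` is filled when `torch_dtype` is present, and silently skipped otherwise.
assert fill_parameters({"dtype": HfParamSource.MODEL_CONFIG_OPTIONAL}, {"torch_dtype": "float32"}) == {"dtype": "float32"}
assert fill_parameters({"dtype": HfParamSource.MODEL_CONFIG_OPTIONAL}, {}) == {}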
5 changes: 5 additions & 0 deletions spacy_curated_transformers/models/__init__.py
@@ -1,10 +1,15 @@
 from .architectures import (
     build_albert_transformer_model_v1,
+    build_albert_transformer_model_v2,
     build_bert_transformer_model_v1,
+    build_bert_transformer_model_v2,
     build_camembert_transformer_model_v1,
+    build_camembert_transformer_model_v2,
     build_pytorch_checkpoint_loader_v1,
     build_roberta_transformer_model_v1,
+    build_roberta_transformer_model_v2,
     build_xlmr_transformer_model_v1,
+    build_xlmr_transformer_model_v2,
 )
 from .hf_loader import build_hf_transformer_encoder_loader_v1
 from .scalar_weight import build_scalar_weight_v1