Load models in-place #32

Merged 2 commits on Apr 10, 2024
2 changes: 1 addition & 1 deletion project/configs/layer-weighting.cfg
@@ -119,7 +119,7 @@ factory = "curated_transformer"
 all_layer_outputs = True
 
 [components.transformer.model]
-@architectures = "spacy-curated-transformers.XlmrTransformer.v1"
+@architectures = "spacy-curated-transformers.XlmrTransformer.v2"
 vocab_size = 250002
 num_hidden_layers = 12
 hidden_width = 768
2 changes: 1 addition & 1 deletion project/configs/no-layer-weighting.cfg
@@ -98,7 +98,7 @@ upstream = "*"
 factory = "curated_transformer"
 
 [components.transformer.model]
-@architectures = "spacy-curated-transformers.XlmrTransformer.v1"
+@architectures = "spacy-curated-transformers.XlmrTransformer.v2"
 vocab_size = 250002
 piece_encoder = {"@architectures": "spacy-curated-transformers.XlmrSentencepieceEncoder.v1"}
 
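Both pipeline configs now point at the v2 XLM-R architecture. Not part of the diff, but as a sanity check: a minimal sketch parsing the updated model block with thinc's Config API. The section contents are copied from the hunks above; everything else is illustrative.

# Sketch (not from this PR): confirm the v2 architecture name survives a
# round-trip through thinc's config parser.
from thinc.api import Config

MODEL_BLOCK = """
[components.transformer.model]
@architectures = "spacy-curated-transformers.XlmrTransformer.v2"
vocab_size = 250002
num_hidden_layers = 12
hidden_width = 768
"""

config = Config().from_str(MODEL_BLOCK)
print(config["components"]["transformer"]["model"]["@architectures"])
# -> spacy-curated-transformers.XlmrTransformer.v2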
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-curated-transformers>=2.0.0.dev1,<3.0.0
+curated-transformers>=2.0.0.dev2,<3.0.0
 curated-tokenizers>=0.9.2,<1.0.0
 fsspec>=2023.5.0
 spacy>=4.0.0.dev2,<5.0.0
7 changes: 6 additions & 1 deletion setup.cfg
@@ -14,7 +14,7 @@ zip_safe = true
 include_package_data = true
 python_requires = >=3.8
 install_requires =
-    curated-transformers>=2.0.0.dev1,<3.0.0
+    curated-transformers>=2.0.0.dev2,<3.0.0
     curated-tokenizers>=0.9.2,<1.0.0
     fsspec>=2023.5.0
     spacy>=4.0.0.dev2,<5.0.0
@@ -27,10 +27,15 @@ spacy_factories =
 
 spacy_architectures =
     spacy-curated-transformers.AlbertTransformer.v1 = spacy_curated_transformers.models:build_albert_transformer_model_v1
+    spacy-curated-transformers.AlbertTransformer.v2 = spacy_curated_transformers.models:build_albert_transformer_model_v2
     spacy-curated-transformers.BertTransformer.v1 = spacy_curated_transformers.models:build_bert_transformer_model_v1
+    spacy-curated-transformers.BertTransformer.v2 = spacy_curated_transformers.models:build_bert_transformer_model_v2
     spacy-curated-transformers.CamembertTransformer.v1 = spacy_curated_transformers.models:build_camembert_transformer_model_v1
+    spacy-curated-transformers.CamembertTransformer.v2 = spacy_curated_transformers.models:build_camembert_transformer_model_v2
     spacy-curated-transformers.RobertaTransformer.v1 = spacy_curated_transformers.models:build_roberta_transformer_model_v1
+    spacy-curated-transformers.RobertaTransformer.v2 = spacy_curated_transformers.models:build_roberta_transformer_model_v2
     spacy-curated-transformers.XlmrTransformer.v1 = spacy_curated_transformers.models:build_xlmr_transformer_model_v1
+    spacy-curated-transformers.XlmrTransformer.v2 = spacy_curated_transformers.models:build_xlmr_transformer_model_v2
     spacy-curated-transformers.WithStridedSpans.v1 = spacy_curated_transformers.models:build_with_strided_spans_v1
     spacy-curated-transformers.ScalarWeight.v1 = spacy_curated_transformers.models:build_scalar_weight_v1
     spacy-curated-transformers.TransformerLayersListener.v1 = spacy_curated_transformers.models.listeners:build_transformer_layers_listener_v1
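Not part of the diff: the v2 entries registered above are ordinary spaCy entry points, so they resolve by name through the architectures registry. A minimal sketch, assuming spacy and this package are installed:

# Resolve the new XlmrTransformer.v2 entry point by name.
import spacy

build_xlmr_v2 = spacy.registry.architectures.get(
    "spacy-curated-transformers.XlmrTransformer.v2"
)
print(build_xlmr_v2.__name__)
# -> build_xlmr_transformer_model_v2, per the entry point above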
14 changes: 9 additions & 5 deletions spacy_curated_transformers/cli/fill_config_transformer.py
@@ -67,12 +67,14 @@ def init_fill_curated_transformer_cli(
 
 class HfParamSource(Enum):
     MODEL_CONFIG = 1
-    TOKENIZER_CONFIG = 2
+    MODEL_CONFIG_OPTIONAL = 2
+    TOKENIZER_CONFIG = 3
 
 
 # Entrypoint parameters that are common to all curated transformer models.
 COMMON_ENTRYPOINT_PARAMS: Dict[str, HfParamSource] = {
     "attention_probs_dropout_prob": HfParamSource.MODEL_CONFIG,
+    "dtype": HfParamSource.MODEL_CONFIG_OPTIONAL,
     "hidden_act": HfParamSource.MODEL_CONFIG,
     "hidden_dropout_prob": HfParamSource.MODEL_CONFIG,
     "hidden_width": HfParamSource.MODEL_CONFIG,
@@ -99,6 +101,7 @@ class HfParamSource(Enum):
     "intermediate_width": "intermediate_size",
     "padding_idx": "pad_token_id",
     "embedding_width": "embedding_size",
+    "dtype": "torch_dtype",
 }
 
@@ -328,9 +331,9 @@ def _fill_parameters(
     filled_params = {}
     for param_name, source in params_to_fill.items():
         hf_key = ENTRYPOINT_PARAMS_TO_HF_CONFIG_KEYS.get(param_name, param_name)
-        if source == HfParamSource.MODEL_CONFIG:
+        if source in (HfParamSource.MODEL_CONFIG, HfParamSource.MODEL_CONFIG_OPTIONAL):
             value = hf_config.get(hf_key)
-            if value is None:
+            if value is None and source == HfParamSource.MODEL_CONFIG:
                 msg.fail(
                     f"Hugging Face model config has a missing key '{hf_key}'", exits=1
                 )
@@ -341,8 +344,9 @@
                     f"Hugging Face tokenizer config has a missing key '{hf_key}'",
                     exits=1,
                 )
-        assert value is not None
-        filled_params[param_name] = value
+        assert value is not None or source == HfParamSource.MODEL_CONFIG_OPTIONAL
+        if value is not None:
+            filled_params[param_name] = value
 
     msg.info(title="Filled-in model parameters:")
     msg.table(filled_params)
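To make the new control flow concrete: a self-contained sketch of the fill behavior after this change, simplified from _fill_parameters above (msg.fail(..., exits=1) is swapped for a plain exception so the snippet runs standalone). Required model-config keys still abort when missing, while optional keys such as dtype are skipped silently.

from enum import Enum
from typing import Any, Dict


class HfParamSource(Enum):
    MODEL_CONFIG = 1
    MODEL_CONFIG_OPTIONAL = 2
    TOKENIZER_CONFIG = 3


# Subset of ENTRYPOINT_PARAMS_TO_HF_CONFIG_KEYS from the diff above.
ENTRYPOINT_PARAMS_TO_HF_CONFIG_KEYS = {
    "hidden_width": "hidden_size",
    "dtype": "torch_dtype",
}


def fill_parameters(
    params_to_fill: Dict[str, HfParamSource], hf_config: Dict[str, Any]
) -> Dict[str, Any]:
    filled_params: Dict[str, Any] = {}
    for param_name, source in params_to_fill.items():
        hf_key = ENTRYPOINT_PARAMS_TO_HF_CONFIG_KEYS.get(param_name, param_name)
        value = hf_config.get(hf_key)
        # Required keys still fail hard; optional keys are skipped when absent.
        if value is None and source == HfParamSource.MODEL_CONFIG:
            raise ValueError(f"Hugging Face model config has a missing key '{hf_key}'")
        if value is not None:
            filled_params[param_name] = value
    return filled_params


# 'torch_dtype' is absent from many older HF configs; the optional 'dtype'
# entry is then left out instead of aborting.
print(fill_parameters(
    {"hidden_width": HfParamSource.MODEL_CONFIG,
     "dtype": HfParamSource.MODEL_CONFIG_OPTIONAL},
    {"hidden_size": 768},
))
# -> {'hidden_width': 768}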
5 changes: 5 additions & 0 deletions spacy_curated_transformers/models/__init__.py
@@ -1,10 +1,15 @@
 from .architectures import (
     build_albert_transformer_model_v1,
+    build_albert_transformer_model_v2,
     build_bert_transformer_model_v1,
+    build_bert_transformer_model_v2,
     build_camembert_transformer_model_v1,
+    build_camembert_transformer_model_v2,
     build_pytorch_checkpoint_loader_v1,
     build_roberta_transformer_model_v1,
+    build_roberta_transformer_model_v2,
     build_xlmr_transformer_model_v1,
+    build_xlmr_transformer_model_v2,
 )
 from .hf_loader import build_hf_transformer_encoder_loader_v1
 from .scalar_weight import build_scalar_weight_v1