Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF: add test for PushToHubCallback #20231

Merged
merged 8 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions src/transformers/keras_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from packaging.version import parse
from tensorflow.keras.callbacks import Callback

from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo

from . import IntervalStrategy, PreTrainedTokenizerBase
from .modelcard import TrainingSummary
Expand Down Expand Up @@ -339,11 +339,13 @@ def __init__(

self.output_dir = output_dir
self.hub_model_id = hub_model_id
create_repo(self.hub_model_id, exist_ok=True)
self.repo = Repository(
str(self.output_dir),
clone_from=self.hub_model_id,
use_auth_token=hub_token if hub_token else True,
)

self.tokenizer = tokenizer
self.last_job = None
self.checkpoint = checkpoint
Expand Down Expand Up @@ -394,17 +396,22 @@ def on_epoch_end(self, epoch, logs=None):
)

def on_train_end(self, logs=None):
    """Called by Keras at the end of training; guarantees the final model state is on the Hub.

    Two cases:
    - A background push (``self.last_job``) from the last epoch is still running: it already
      contains the final weights, so we simply block until it finishes instead of starting a
      duplicate upload.
    - No push is in flight: save the model (and tokenizer, if any), generate a model card from
      the recorded training history, and do one blocking push.

    Args:
        logs (`dict`, *optional*): Metrics passed by Keras; unused here.
    """
    # Makes sure the latest version of the model is uploaded
    if self.last_job is not None and not self.last_job.is_done:
        logging.info("Pushing the last epoch to the Hub, this may take a while...")
        # Poll until the in-flight push completes; it already holds the final checkpoint.
        while not self.last_job.is_done:
            sleep(1)
    else:
        self.model.save_pretrained(self.output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(self.output_dir)
        # Build a README/model card from the Keras training history before pushing.
        train_summary = TrainingSummary.from_keras(
            model=self.model,
            model_name=self.hub_model_id,
            keras_history=self.training_history,
            **self.model_card_args,
        )
        model_card = train_summary.to_model_card()
        with (self.output_dir / "README.md").open("w") as f:
            f.write(model_card)
        # blocking=True: training is over, so it is safe (and required) to wait for the upload.
        self.repo.push_to_hub(commit_message="End of training", blocking=True)
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1:
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A Blenderbot sequence has the following format:
- single sequence: ` X </s>`

Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1:
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A Blenderbot sequence has the following format:
- single sequence: ` X </s>`

Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/markuplm/tokenization_markuplm.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def build_inputs_with_special_tokens(
adding special tokens. A RoBERTa sequence has the following format:
- single sequence: `<s> X </s>`
- pair of sequences: `<s> A </s></s> B </s>`

Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,7 @@ def build_inputs_with_special_tokens(
adding special tokens. A RoBERTa sequence has the following format:
- single sequence: `<s> X </s>`
- pair of sequences: `<s> A </s></s> B </s>`

Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/tapex/tokenization_tapex.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def build_inputs_with_special_tokens(
adding special tokens. A TAPEX sequence has the following format:
- single sequence: `<s> X </s>`
- pair of sequences: `<s> A </s></s> B </s>`

Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
Expand Down
44 changes: 39 additions & 5 deletions tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig,
PushToHubCallback,
RagRetriever,
TFAutoModel,
TFAutoModelForSequenceClassification,
TFBertForMaskedLM,
TFBertModel,
TFRagModel,
TFSharedEmbeddings,
Expand Down Expand Up @@ -2344,6 +2346,11 @@ def tearDownClass(cls):
except HTTPError:
pass

try:
delete_repo(token=cls._token, repo_id="test-model-tf-callback")
except HTTPError:
pass

try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org")
except HTTPError:
Expand All @@ -2363,13 +2370,14 @@ def test_push_to_hub(self):
model.push_to_hub("test-model-tf", use_auth_token=self._token)
logging.set_verbosity_warning()
# Check the model card was created and uploaded.
self.assertIn("Uploading README.md to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)

new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)

# Reset repo
Expand All @@ -2382,8 +2390,32 @@ def test_push_to_hub(self):
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)

def test_push_to_hub_callback(self):
    """Same round-trip check as `test_push_to_hub`, but pushes via `PushToHubCallback`
    during `model.fit` instead of calling `push_to_hub` directly."""
    # Tiny BERT config so the test stays fast.
    config = BertConfig(
        vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
    )
    model = TFBertForMaskedLM(config)
    model.compile()

    with tempfile.TemporaryDirectory() as tmp_dir:
        push_to_hub_callback = PushToHubCallback(
            output_dir=tmp_dir,
            hub_model_id="test-model-tf-callback",
            hub_token=self._token,
        )
        # One epoch on dummy inputs; the callback pushes the final model on_train_end.
        model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])

    # Reload from the Hub and verify every weight matches exactly.
    new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback")
    models_equal = True
    for p1, p2 in zip(model.weights, new_model.weights):
        if not tf.math.reduce_all(p1 == p2):
            models_equal = False
            break
    self.assertTrue(models_equal)

def test_push_to_hub_in_organization(self):
Expand All @@ -2399,8 +2431,9 @@ def test_push_to_hub_in_organization(self):
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)

# Reset repo
Expand All @@ -2415,6 +2448,7 @@ def test_push_to_hub_in_organization(self):
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)