
Add doc tests for Albert and Bigbird #16774

Merged
merged 23 commits into huggingface:main on Apr 22, 2022

Conversation

vumichien
Contributor

What does this PR do?

Add doc tests for Albert and Bigbird, a part of issue #16292

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patrickvonplaten, @ydshieh
Documentation: @sgugger

@vumichien vumichien closed this Apr 14, 2022
@vumichien vumichien reopened this Apr 14, 2022
@vumichien
Contributor Author

vumichien commented Apr 14, 2022

@ydshieh Could you please take a look at it? I think we still have a problem with AlbertTokenizer: as we discussed on the Discord channel, the AlbertTokenizer adds an extra "▁" token right after "[MASK]", which leads to different shapes for input_text and target_text. Here is the code snippet for checking the output.

from transformers import AlbertTokenizer, AlbertForMaskedLM
import torch

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = AlbertForMaskedLM.from_pretrained("albert-base-v2")

input_text = "The capital of France is [MASK]."
target_text = "The capital of France is Paris."

tokenizer.tokenize(input_text)
# ['▁the', '▁capital', '▁of', '▁france', '▁is', ' [MASK]', '▁', '.']
tokenizer.tokenize(target_text)
# ['▁the', '▁capital', '▁of', '▁france', '▁is', '▁paris', '.']
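
Continuing the snippet, the encoded tensors also end up with different lengths (10 vs 9 once the special tokens are added):

inputs = tokenizer(input_text, return_tensors="pt")
labels = tokenizer(target_text, return_tensors="pt")["input_ids"]
inputs.input_ids.shape, labels.shape
# (torch.Size([1, 10]), torch.Size([1, 9]))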

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 14, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh
Collaborator

ydshieh commented Apr 14, 2022

Hi, @vumichien

I won't be available for the next few days. Will check when I am back, or my colleague could check this PR :-)

Regarding the Albert tokenizer, do you encounter any runtime error due to the shape issue? I understand that the shapes are different, and I had a short discussion with the team, but we thought it should still work. Sorry for not responding to this part earlier; if you do see errors caused by these shapes, could you post them here, please?

@vumichien
Contributor Author

vumichien commented Apr 14, 2022

@ydshieh When I run the doc tests for modeling_albert.py locally, I get the following error (sorry for the very long error log):

======================================================================================= FAILURES =======================================================================================
____________________________________________________ [doctest] transformers.models.albert.modeling_albert.AlbertForMaskedLM.forward ____________________________________________________
1034     >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
1035     >>> tokenizer.decode(predicted_token_id)
1036     'reims'
1037 
1038     ```
1039 
1040     ```python
1041     >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
1042     >>> # mask labels of non-[MASK] tokens
1043     >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
UNEXPECTED EXCEPTION: RuntimeError('The size of tensor a (10) must match the size of tensor b (9) at non-singleton dimension 1')
Traceback (most recent call last):
  File "/usr/lib/python3.8/doctest.py", line 1336, in __run
    exec(compile(example.source, filename, "single",
  File "<doctest transformers.models.albert.modeling_albert.AlbertForMaskedLM.forward[10]>", line 1, in <module>
RuntimeError: The size of tensor a (10) must match the size of tensor b (9) at non-singleton dimension 1
/home/vumichien/Detomo/transformers/src/transformers/models/albert/modeling_albert.py:1043: UnexpectedException
1036     'reims'
1037 
1038     ```
1039 
1040     ```python
1041     >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
1042     >>> # mask labels of non-[MASK] tokens
1043     >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
1044 
1045     >>> outputs = model(**inputs, labels=labels)
UNEXPECTED EXCEPTION: ValueError('Expected input batch_size (10) to match target batch_size (9).')
Traceback (most recent call last):
  File "/usr/lib/python3.8/doctest.py", line 1336, in __run
    exec(compile(example.source, filename, "single",
  File "<doctest transformers.models.albert.modeling_albert.AlbertForMaskedLM.forward[11]>", line 1, in <module>
  File "/home/vumichien/Detomo/transformers/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/vumichien/Detomo/transformers/src/transformers/models/albert/modeling_albert.py", line 964, in forward
    masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  File "/home/vumichien/Detomo/transformers/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/vumichien/Detomo/transformers/venv/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1163, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/vumichien/Detomo/transformers/venv/lib/python3.8/site-packages/torch/nn/functional.py", line 2996, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (10) to match target batch_size (9).
/home/vumichien/Detomo/transformers/src/transformers/models/albert/modeling_albert.py:1045: UnexpectedException
1037 
1038     ```
1039 
1040     ```python
1041     >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
1042     >>> # mask labels of non-[MASK] tokens
1043     >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
1044 
1045     >>> outputs = model(**inputs, labels=labels)
1046     >>> round(outputs.loss.item(), 2)
UNEXPECTED EXCEPTION: NameError("name 'outputs' is not defined")
Traceback (most recent call last):
  File "/usr/lib/python3.8/doctest.py", line 1336, in __run
    exec(compile(example.source, filename, "single",
  File "<doctest transformers.models.albert.modeling_albert.AlbertForMaskedLM.forward[12]>", line 1, in <module>
NameError: name 'outputs' is not defined
/home/vumichien/Detomo/transformers/src/transformers/models/albert/modeling_albert.py:1046: UnexpectedException
=================================================================================== warnings summary ===================================================================================
venv/lib/python3.8/site-packages/flatbuffers/compat.py:19
  /home/vumichien/Detomo/transformers/venv/lib/python3.8/site-packages/flatbuffers/compat.py:19: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
    import imp

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================================================== short test summary info ================================================================================
FAILED src/transformers/models/albert/modeling_albert.py::transformers.models.albert.modeling_albert.AlbertForMaskedLM.forward
======================================================================= 1 failed, 6 passed, 1 warning in 52.18s ========================================================================

The error log is the same when I run the doc tests for modeling_tf_albert.py.

@ydshieh
Collaborator

ydshieh commented Apr 14, 2022

Maybe a quick and easy way is just to overwrite the examples for AlbertForMaskedLM in the model files, something similar to #16565 (comment).

But that case is reversed: there the masked input has fewer tokens, so you will need slightly different operations here.
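
For example, something roughly like this might work to build labels that line up with the longer masked input (just a sketch, not necessarily what we want in the docstring):

from transformers import AlbertTokenizer, AlbertForMaskedLM
import torch

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = AlbertForMaskedLM.from_pretrained("albert-base-v2")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")

# build labels with the same shape as the (longer) masked input:
# ignore every position (-100) except the [MASK] position,
# which gets the id of the expected token "▁paris"
labels = torch.full_like(inputs.input_ids, -100)
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
labels[0, mask_token_index] = tokenizer.convert_tokens_to_ids("▁paris")

outputs = model(**inputs, labels=labels)
round(outputs.loss.item(), 2)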

Let's wait for @patrickvonplaten to see if he has a better suggestion.

Collaborator

@sgugger sgugger left a comment

Good for me once the Albert example for masked language modeling is fixed. Thanks!

@@ -2397,6 +2397,8 @@ def set_output_embeddings(self, new_embeddings):
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
expected_output="'here'",
Contributor

Sorry to comment so late here. Could we maybe overwrite the BigBird example as well? https://huggingface.co/google/bigbird-roberta-base has quite a significant number of downloads and it's known to be a long-range model. Could we maybe provide a long input to be masked here and in all the other examples below as well?

It would be great if we could overwrite the example docstring here, @ydshieh.

Contributor Author

I think it's a great idea. Let me prepare better examples for Bigbird.

@patrickvonplaten
Contributor

@vumichien @ydshieh, I'd be in favor of overwriting both Albert (so that MLM is correct) as well as BigBird (to show that it's long-range). What do you think?

@vumichien
Contributor Author

@ydshieh @patrickvonplaten I have overwritten the doctest examples for both Albert and Bigbird. What do you think of them?

>>> answer_end_index = outputs.end_logits.argmax()
>>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
>>> tokenizer.decode(predict_answer_tokens)
'Old College'
Contributor

Very cool example!

@patrickvonplaten
Contributor

@ydshieh @patrickvonplaten I have overwritten the doctest examples for both Albert and Bigbird. What do you think of them?

That's great! The classification and QA examples could be made even longer for BigBird :-) The examples already look great though. Happy to merge as is as well :-)

@vumichien
Contributor Author

I have switched to longer examples for the doctest. The examples are quite long, but in my opinion they are good for showing that Bigbird is a long-range model.

@sgugger
Collaborator

sgugger commented Apr 20, 2022

Can we put that text in some dataset instead? The documentation will become a bit unreadable with such a long text, whereas we could just load a dataset in one line and take the first sample.
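
For example, something along these lines (illustrative only; any long sample would do):

from datasets import load_dataset

# one long context from SQuAD instead of a hard-coded wall of text
squad_ds = load_dataset("squad_v2", split="train")
long_article = squad_ds[0]["context"]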

@vumichien
Contributor Author

@sgugger Thank you for your suggestion. I have changed the examples to use samples from the SQuAD dataset. What do you think about that?

@sgugger
Collaborator

sgugger commented Apr 20, 2022

Way better, and great that you're showing the shape! Good for me if @patrickvonplaten is okay.

@ydshieh ydshieh self-requested a review April 20, 2022 20:02
Collaborator

@ydshieh ydshieh left a comment

As with your previous PR, very high quality!
Thank you @vumichien for the effort to overwrite the doctest code 💯

I left 2 tiny suggestions, but no need to feel obliged to apply them.

Ran locally -> all tests passed!


>>> LONG_ARTICLE_TARGET = squad_ds[81514]["context"]
>>> # add mask_token
>>> LONG_ARTICLE_TO_MASK = LONG_ARTICLE_TARGET.replace("maximum", "[MASK]")
Collaborator

just a tiny nit: could we show a few words around the target word "maximum"? Just so readers can see the context and verify that the output indeed makes sense :)
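
Something like this, for example (just to illustrate what I mean):

# show a bit of context around the word that gets masked
idx = LONG_ARTICLE_TARGET.find("maximum")
print(LONG_ARTICLE_TARGET[max(0, idx - 60) : idx + 60])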

Contributor Author

I will revise it as you suggest.

@@ -2858,9 +2910,12 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
checkpoint="vumichien/token-classification-bigbird-roberta-base",
Collaborator

Wow, you have trained a token classification model ..? 💯

Contributor Author

Sorry, it's just random weights... (but maybe I will try to train Bigbird for token classification in the near future 😅). The reason I didn't use the model from hf-internal-testing (hf-internal-testing/tiny-random-bigbird_pegasus) is that it also has random weights, but its output is too long. If you think this is not a good approach, I will revise it to use the hf-internal-testing/tiny-random-bigbird_pegasus checkpoint.

Collaborator

It's OK, it is completely fine.

However, if vumichien/token-classification-bigbird-roberta-base has random weights, it would be a good idea to give it a name like

vumichien/token-classification-bigbird-roberta-base-random

This way, doc readers and Hub users won't be confused 😄

Contributor Author

Ah, I see. I will change the checkpoint name.

>>> squad_ds = load_dataset("squad_v2", split="train") # doctest: +IGNORE_RESULT

>>> LONG_ARTICLE = squad_ds[81514]["context"]
>>> QUESTION = squad_ds[81514]["question"]
Collaborator

nit: it would be good to show the question text as a comment.

Contributor Author

I will revise it as you suggest.

@vumichien
Contributor Author

@ydshieh I have revised it as you suggested. Please let me know if I need to change anything else.

@ydshieh
Collaborator

ydshieh commented Apr 21, 2022

@ydshieh I have revised it as you suggested. Please let me know if I need to change anything else.

Love it! Thank you.
I will let @patrickvonplaten have a final look (if any) & click the merge button 💯


>>> tokenizer = BigBirdTokenizer.from_pretrained("l-yohai/bigbird-roberta-base-mnli")
>>> model = BigBirdForSequenceClassification.from_pretrained("l-yohai/bigbird-roberta-base-mnli")
>>> squad_ds = load_dataset("squad_v2", split="train") # doctest: +IGNORE_RESULT
Contributor

That's great!

Contributor

@patrickvonplaten patrickvonplaten left a comment

Amazing job @vumichien - thanks a mille for making the example so nice :-)

@ydshieh
Collaborator

ydshieh commented Apr 22, 2022

Running the tests one last time -> will merge if they all pass.

Merged 🚀 Thanks again!

@ydshieh ydshieh merged commit 0d1cff1 into huggingface:main Apr 22, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* Add doctest BERT

* make fixup

* fix typo

* change checkpoints

* make fixup

* define doctest output value, update doctest for mobilebert

* solve fix-copies

* update QA target start index and end index

* change checkpoint for docs and reuse defined variable

* Update src/transformers/models/bert/modeling_tf_bert.py

Co-authored-by: Yih-Dar <[email protected]>

* Apply suggestions from code review

Co-authored-by: Yih-Dar <[email protected]>

* Apply suggestions from code review

Co-authored-by: Yih-Dar <[email protected]>

* make fixup

* Add Doctest for Albert and Bigbird

* make fixup

* overwrite examples for Albert and Bigbird

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* update longer examples for Bigbird

* using examples from squad_v2

* print out example text

* change name token-classification-big-bird checkpoint to random

Co-authored-by: Yih-Dar <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>