Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling Tapex in table question answering pipeline. #16663

Merged
merged 3 commits into from
Apr 14, 2022
Merged

Conversation

Narsil
Copy link
Contributor

@Narsil Narsil commented Apr 8, 2022

What does this PR do?

  • Enables newly output Tapex models in the pipeline
  • Use a self.type flag, since Tapas (the previously used models)
    have some non trivial defaults which do not apply to Tapex.(Ideally those type
    are not model specific like tapas but encoder vs encoder-decoder when
    applicable.
  • Added a slow test
  • Ideally we need a fast test, but there doesn't seem to be a small random
    model yet: https://huggingface.co/hf-internal-testing

@NielsRogge
@LysandreJik

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

)
self.type = "tapas" if hasattr(self.model.config, "aggregation_labels") else None
Copy link
Contributor

@NielsRogge NielsRogge Apr 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if we can leverage model.config.is_encoder_decoder here (which is True for TAPEX but False for TAPAS)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually I would agree that reusing more "general" concepts is better.

The thing I want to avoid is treating not encoder-decoder as tapas. currently the code for Tapas is super highly specific, so I'd rather isolate tapas than isolate tapex if you want.

Ultimately both are OK tbh.

Comment on lines +340 to +344
if truncation is None:
if self.type == "tapas":
truncation = "drop_rows_to_fit"
else:
truncation = "do_not_truncate"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TAPEX also supports the "drop_rows_to_fit" truncation strategy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tokenizer said it didn't and gave me the regular list.

Is the tokenizer that gets resolved a BartTokenizer maybe ? (I have seen BartTokenizer in some random error log, I discarded it, but maybe ti's a True bug.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
tokenizer.encode(table, truncation="drop_rows_to_fit")
ValueError: drop_rows_to_fit is not a valid TruncationStrategy, please select one of ['only_first', 'only_second', 'longest_first', 'do_not_truncate']

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reporting, investigating.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: you can continue with this, TAPEX doesn't support drop_rows_to_fit in a similar manner as TAPAS does.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Apr 8, 2022

The documentation is not available anymore as the PR was closed or merged.

@Narsil
Copy link
Contributor Author

Narsil commented Apr 12, 2022

Gentle ping @NielsRogge if we want to go live next week.

@Narsil Narsil requested a review from LysandreJik April 13, 2022 08:36
Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on it, looks good to me!

@Narsil Narsil merged commit 195fbbb into main Apr 14, 2022
@Narsil Narsil deleted the pipeline_tapex branch April 14, 2022 07:06
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…6663)

* Enabling `Tapex` in table question answering pipeline.

* Questions are independant for Tapex, making the test respect that.

* Missing extra space.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants