
TF: XLA repetition penalty #16879

Merged: 3 commits merged into huggingface:main on Apr 22, 2022
Conversation

@gante (Member) commented Apr 21, 2022

What does this PR do?

This PR adds our first XLA-compatible TF logits processor, as well as the corresponding tests. Since this is the first of a series of small (but similar) PRs, I'd like to request a more thorough review here, so that the remaining ones can be quick.

More specifically, this PR makes three changes:

  1. Rewrites the TF repetition penalty processor so that it is XLA-compatible (see the sketch after this list);
  2. Adds XLA tests for the processor;
  3. Since the test mentioned in 2. was a near copy/paste of the non-XLA test, I've split it into three parts to improve code reuse and reduce errors from ad hoc edits (the first and last part can be reused in both the XLA and non-XLA versions of the test):
    • get inputs
    • run the processor
    • check the output
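
For context on point 1., below is a minimal sketch of what an XLA-friendly multiplicative repetition penalty can look like in TensorFlow: it avoids data-dependent boolean-mask indexing and sticks to fixed-shape gather/scatter ops that XLA can trace. The function name and signature are illustrative only and not taken from this PR.

```python
import tensorflow as tf

def apply_repetition_penalty(input_ids, logits, penalty):
    # Illustrative sketch (not the exact PR implementation): apply a
    # multiplicative repetition penalty using only XLA-traceable TF ops.
    # input_ids: int tensor of shape (batch, seq_len); logits: (batch, vocab).

    # Gather the logits of tokens already present in input_ids -> (batch, seq_len).
    token_logits = tf.gather(logits, input_ids, batch_dims=1)

    # Divide positive logits by the penalty and multiply negative logits by it,
    # so repeated tokens always become less likely.
    penalized = tf.where(token_logits > 0, token_logits / penalty, token_logits * penalty)

    # Scatter the penalized values back into the full logits tensor.
    batch_size = tf.shape(input_ids)[0]
    seq_len = tf.shape(input_ids)[1]
    batch_idx = tf.repeat(tf.range(batch_size), seq_len)
    indices = tf.stack([batch_idx, tf.reshape(input_ids, [-1])], axis=1)
    return tf.tensor_scatter_nd_update(logits, indices, tf.reshape(penalized, [-1]))
```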

@HuggingFaceDocBuilderDev commented Apr 21, 2022

The documentation is not available anymore as the PR was closed or merged.

@Rocketknight1 (Member) left a comment

This looks like a good port of the original numpy code!

(But I still think logit penalties should be additive rather than multiplicative)

@Rocketknight1 (Member) commented
Thinking about it more, a multiplicative logit penalty really doesn't work, right? Even if we use the reciprocal when the logit is negative, the scale of the penalty depends on the logit's distance from 0. For example, a logit in the range -0.1 to +0.1 will barely be moved by the penalty term, but such logits usually have quite a high probability of being chosen, because most logits are large and negative.
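
A quick numeric sketch of that effect (penalty value of 1.2 picked arbitrarily for illustration):

```python
import tensorflow as tf

penalty = 1.2
logits = tf.constant([0.05, -0.05, 5.0, -5.0])
# Multiplicative rule: divide positive logits by the penalty, multiply negative ones by it.
penalized = tf.where(logits > 0, logits / penalty, logits * penalty)
print(penalized.numpy())  # approx. [ 0.042, -0.06, 4.17, -6.0 ]
# The near-zero logits move by only about 0.01, while the large-magnitude ones
# move by roughly 0.8 to 1.0, even though a logit near 0 can already correspond
# to a relatively likely token when most other logits are large and negative.
```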

@gante (Member, Author) commented Apr 22, 2022

(merging as the main goal was to port to XLA but, by all means, continue the discussion :) )

@gante merged commit 99c8226 into huggingface:main on Apr 22, 2022
@gante deleted the xla_repetition_penalty branch on April 22, 2022 at 17:29
elusenji pushed a commit to elusenji/transformers that referenced this pull request on Jun 12, 2022