
TF generate refactor - Beam Search #16374

Merged: 27 commits into huggingface:main on Apr 6, 2022
Conversation

@gante (Member) commented on Mar 23, 2022

What does this PR do?

As discussed in the original TF generate refactor plan (#15562), this PR adds beam_search.

This Beam Search implementation was inspired by our FLAX implementation, which is XLA-friendly. However, this PR is not yet XLA-ready (😭). To pass existing tests, a few tweaks were added on top of the FLAX adaptation -- I added some comments in the PR to explain the differences (and why they were needed), hopefully making the review process easier.

Tests run (and passing):

  • GPT-2
  • T5
  • BART
  • Vision Encoder Decoder
  • Encoder Decoder
  • Speech to Text
  • RAG
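
For context, beam search is reached through the regular generate() API whenever num_beams > 1. A minimal usage sketch (the checkpoint and generation parameters here are illustrative, not taken from the PR):

from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello, world!", return_tensors="tf")
# num_beams > 1 routes generate() into the beam search path added by this PR
outputs = model.generate(inputs.input_ids, num_beams=4, max_length=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))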

@HuggingFaceDocBuilderDev commented on Mar 23, 2022

The documentation is not available anymore as the PR was closed or merged.

@gante marked this pull request as ready for review on Mar 25, 2022
@gante changed the title from "TF -- Beam Search refactor" to "TF generate refactor - Beam Search" on Mar 25, 2022
@@ -259,8 +255,8 @@ def _create_score_penalties(self, input_ids, logits):
            np.put(token_penalties[i], prev_input_id, logit_penalties)
        return tf.convert_to_tensor(token_penalties, dtype=tf.float32)

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         score_penalties = self._create_score_penalties(input_ids, scores)
gante (Member Author):

XLA greedy search was probably missing this as well in the logits processors, since it has the same padded input_ids

Contributor:

Yes, thanks for adding it! Just note that I don't really think we can make this processor XLA-compilable anyway, as it's very complex and numpy can't be used in XLA. cur_len is mostly added in Flax/JAX to make the processors XLA-compilable. But it doesn't hurt to add it here!
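
To illustrate the point about cur_len with padded inputs, a hypothetical length-aware processor might look like this (a sketch only, with made-up names; not the actual transformers implementation):

import tensorflow as tf

class LengthAwareProcessor:
    """Hypothetical processor: input_ids arrives padded to a fixed max
    length, so cur_len says how many positions are real history."""

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # only the first cur_len tokens of each row are valid generated history
        valid_history = input_ids[:, :cur_len]
        # ... compute penalties from valid_history instead of the padded tensor ...
        del valid_history  # placeholder: a real processor would use this
        return scores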

@Rocketknight1 (Member) commented on Apr 4, 2022:

tf.unique is not compatible with XLA because the output shape depends on the specific input data and so cannot be inferred at compile time. However, there should be a way to make this logits processor XLA-compilable: there's probably a solution where you store counts in a sparse matrix and then use triu() or tril() followed by a matmul to see whether a token has been preceded by the same token. Let me know if you want me to try that (here or in a separate PR).
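
To make the triangular-mask idea concrete, here is a hypothetical static-shape sketch (the function name is made up, and it uses a mask-and-reduce instead of an explicit matmul, but the principle is the same):

import tensorflow as tf

def token_already_seen(input_ids: tf.Tensor) -> tf.Tensor:
    # equality[b, j, i] is 1.0 when the tokens at positions j and i match
    equality = tf.cast(tf.equal(input_ids[:, :, None], input_ids[:, None, :]), tf.float32)
    seq_len = tf.shape(input_ids)[1]
    ones = tf.ones((seq_len, seq_len))
    # strictly lower-triangular mask: entry [j, i] is 1.0 when i < j, i.e. i is earlier
    earlier = tf.linalg.band_part(ones, -1, 0) - tf.linalg.band_part(ones, 0, 0)
    # a position holds a repeat if its token matches any earlier position;
    # all shapes are static, so this stays XLA-compilable (unlike tf.unique)
    return tf.reduce_max(equality * earlier, axis=-1) > 0

For example, token_already_seen(tf.constant([[5, 7, 5]])) returns [[False, False, True]].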

gante (Member Author):

🧠
I'd leave it to a subsequent PR; XLA-readiness is not the main priority here, and this PR is already very long.

@patrickvonplaten (Contributor) left a comment:

This looks very nice! Great job so far - it's really not an easy PR.

Hope we can hunt down those final differences for the other models and have identical results between this and the previous version.

IMO the most important tests that need to pass here are the slow, long batched generation tests in TFT5 and TFBART. Once those tests pass I think we can be confident that it works.

Some final minor differences in TFBart might come from things like the length_penalty being applied slightly differently in the old version, e.g. if the hypothesis length is a bit different here:

score = sum_logprobs / len(hyp) ** self.length_penalty
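
As a quick illustration of how sensitive this formula is to the hypothesis length (numbers made up for the example):

# illustrative numbers only: how length_penalty changes the ranking
def hyp_score(sum_logprobs, hyp_len, length_penalty):
    return sum_logprobs / hyp_len ** length_penalty

print(hyp_score(-4.0, 5, 1.0))  # -0.80: shorter hypothesis
print(hyp_score(-6.0, 8, 1.0))  # -0.75: longer hypothesis wins at penalty 1.0
print(hyp_score(-4.0, 5, 0.0))  # -4.0: with no penalty, raw sums are compared
print(hyp_score(-6.0, 8, 0.0))  # -6.0: and the shorter hypothesis wins instead

An off-by-one in len(hyp) between the old and new implementations is enough to flip such a ranking.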

Note that I don't think we have 1-to-1 the same output in JAX as we do in TF here either, but it'd be important to match the new TF version exactly to the old TF version.

@sgugger (Collaborator) left a comment:

Thanks for working on this!

@gante (Member Author) commented on Apr 1, 2022:

[image]

@patrickvonplaten (Contributor) left a comment:

Great job!

The PR looks more or less good to merge to me!

I'd advocate changing two final things:

1. Apply all logits processors the same way (as we do in PyTorch), at the expense of a negligible super-edge case where the repetition_penalty could lead to different results. No one uses (or should use) repetition penalty in beam search, and top-k yields the same result in 99% of cases.
2. Add logits processors for forced_bos_token_id and forced_eos_token_id, and maybe adapt Marian's adapt_logits_processor function slightly to fit the PyTorch one.

# sets the score to 0 in the eos_token_id column
scores = tf.zeros((batch_size, 1))
# sets the score to -inf everywhere else
if self.eos_token_id > 0:
Contributor:

(nit) I think it'd be cleaner to raise a ValueError in __init__() if eos_token_id <= 0; this should really never be the case. But maybe let's leave it for a follow-up PR.
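
A minimal sketch of that suggestion, assuming the constructor takes the same (max_length, eos_token_id) arguments as the PyTorch ForcedEOSTokenLogitsProcessor:

class TFForcedEOSTokenLogitsProcessor:
    def __init__(self, max_length: int, eos_token_id: int):
        if eos_token_id <= 0:
            # fail fast at construction time instead of silently skipping the masking later
            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
        self.max_length = max_length
        self.eos_token_id = eos_token_id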

# sets the score to 0 in the bos_token_id column
scores = tf.zeros((batch_size, 1))
# sets the score to -inf everywhere else
if self.bos_token_id > 0:
Contributor:

Same comment as for EOS.


accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys())
Contributor:
nice!
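
For readers unfamiliar with the pattern: inspect.signature exposes a callable's parameters as a mapping, so membership can be tested directly (a standalone illustration with made-up function names; the set(...keys()) wrapper in the line above works too, it is just redundant):

import inspect

def call_with_mask(input_ids, attention_mask=None):
    ...

def call_without_mask(input_ids):
    ...

print("attention_mask" in inspect.signature(call_with_mask).parameters)     # True
print("attention_mask" in inspect.signature(call_without_mask).parameters)  # False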

@patrickvonplaten (Contributor) left a comment:

Great work, good to merge for me!

@gante merged commit 3f43d82 into huggingface:main on Apr 6, 2022
@gante deleted the beam_search_tf branch on April 11, 2022