TF generate refactor - Beam Search #16374
Conversation
The documentation is not available anymore as the PR was closed or merged.
```diff
@@ -259,8 +255,8 @@ def _create_score_penalties(self, input_ids, logits):
             np.put(token_penalties[i], prev_input_id, logit_penalties)
         return tf.convert_to_tensor(token_penalties, dtype=tf.float32)

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
-        score_penalties = self._create_score_penalties(input_ids, scores)
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
```
XLA greedy search was probably missing this as well in the logits processors, since it has the same padded input_ids
Yes, thanks for adding it! Just note that I don't really think we can make this processor XLA-compilable anyway, as it's very complex and numpy can't be used in XLA. `cur_len` is mostly added in Flax/JAX to make the processors XLA-compilable, but it doesn't hurt to add it here!
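To illustrate the point about `cur_len`, here is a hedged sketch (in numpy, with illustrative names not taken from this PR) of the Flax-style pattern: the processor branches on an explicit scalar `cur_len` rather than on data-dependent tensor shapes, which is what lets XLA trace it with static shapes.

```python
import numpy as np

def min_length_logits_processor(scores, cur_len, min_length, eos_token_id):
    """Illustrative sketch (not code from this PR): passing `cur_len`
    explicitly lets the processor condition on a scalar instead of on
    data-dependent shapes -- the pattern Flax/JAX relies on to stay
    XLA-compilable."""
    # Mask the EOS column while the sequence is still too short.
    eos_scores = np.where(cur_len < min_length, -np.inf, scores[:, eos_token_id])
    scores = scores.copy()
    scores[:, eos_token_id] = eos_scores
    return scores
```

In real Flax code the `np.where` would be `jnp.where` so the branch stays inside the traced graph.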
`tf.unique` is not compatible with XLA because the output shape depends on the specific input data, and so cannot be inferred at compile time. However, there should be a way to make this logits processor XLA-compilable - there's probably some solution where you store counts in a sparse matrix and then use `triu()` or `tril()` followed by a matmul to see if a token has been preceded by the same token. Let me know if you want me to try that (here or in a separate PR)
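As a rough sketch of the idea (my own variant, not the reviewer's `triu`/matmul proposal and not the repo's implementation): a one-hot presence mask over the vocabulary has the static shape `(batch, vocab)`, so it avoids `tf.unique` entirely. Note the caveat that with padded `input_ids` the padding token would also be marked as "seen" and would need separate masking.

```python
import tensorflow as tf

def xla_friendly_repetition_penalty(input_ids, scores, penalty=2.0):
    """Sketch of one XLA-compatible alternative to `tf.unique` (not this
    repo's implementation): reduce a one-hot encoding over the sequence
    axis to get a static-shape (batch, vocab) mask of seen tokens.
    Caveat: padding tokens in `input_ids` are also marked as seen."""
    vocab_size = scores.shape[-1]
    one_hot = tf.one_hot(input_ids, depth=vocab_size)   # (batch, seq, vocab)
    token_seen = tf.reduce_max(one_hot, axis=1) > 0     # (batch, vocab)
    # Standard repetition penalty: shrink positive scores, grow negative ones.
    penalized = tf.where(scores > 0, scores / penalty, scores * penalty)
    return tf.where(token_seen, penalized, scores)
```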
🧠
I'd leave it to a subsequent PR; XLA-readiness is not the main priority here and this PR is already very long.
This looks very nice! Great job so far - it's really not an easy PR.
Hope we can hunt down those final differences for the other models and have identical results between this and the previous version.
IMO the most important tests that need to pass here are the slow, long batched generation tests in TFT5 and TFBART. Once those tests pass I think we can be confident that it works.
Some final minor differences in TFBart might come from things like the `length_penalty` being applied slightly differently in the old version, e.g. if the hypothesis length is a bit different here:

`score = sum_logprobs / len(hyp) ** self.length_penalty`

Note that I don't think we have exactly the same output in JAX as in TF here either, but it'd be important to match the new TF version exactly to the old TF version.
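To make the reviewer's point concrete, here is a toy check (with made-up numbers) of the scoring rule quoted above: a one-token difference in the hypothesis length shifts the normalized score, which is enough to reorder beams.

```python
# The old TF beam scoring rule quoted above:
#     score = sum_logprobs / len(hyp) ** self.length_penalty
def hyp_score(sum_logprobs, hyp_len, length_penalty):
    return sum_logprobs / hyp_len ** length_penalty

# Hypothetical numbers: same cumulative log-prob, lengths off by one.
shorter = hyp_score(-10.0, 8, length_penalty=1.0)  # -1.25
longer = hyp_score(-10.0, 9, length_penalty=1.0)   # about -1.111, now ranked higher
```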
Thanks for working on this!
Great job!
The PR more or less looks good to merge to me!
I'd advocate changing two final things:

- Apply all logits processors the same way (as we do in PyTorch), at the expense of a negligible super-edge case where the `repetition_penalty` could lead to different results. No one uses (or should use) repetition penalty in beam search, and top-k yields the same result in 99% of cases.
- Add logits processors for `forced_bos_token_id` and `forced_eos_token_id`, and maybe adapt Marian's `adapt_logits_processor` function slightly to fit the PyTorch one.
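For the second point, a minimal numpy sketch of what a forced-BOS processor does (the PR's actual TF version may differ; the class name here mirrors the PyTorch naming but is illustrative): at the first generated position, every token except `bos_token_id` gets a score of `-inf`.

```python
import numpy as np

class ForcedBOSTokenLogitsProcessor:
    """Illustrative numpy sketch, not the PR's TF implementation:
    force `bos_token_id` at the first generation step."""
    def __init__(self, bos_token_id):
        self.bos_token_id = bos_token_id

    def __call__(self, input_ids, scores, cur_len):
        if cur_len == 1:
            # All tokens except bos_token_id become impossible.
            forced = np.full_like(scores, -np.inf)
            forced[:, self.bos_token_id] = 0.0
            return forced
        return scores
```

A forced-EOS processor is symmetric, triggering at `cur_len == max_length - 1` instead.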
```python
# sets the score to 0 in the eos_token_id column
scores = tf.zeros((batch_size, 1))
# sets the score to -inf everywhere else
if self.eos_token_id > 0:
```
(nit) I think it'd be cleaner to raise a ValueError if `eos_token_id <= 0` in `__init__()`. This should never really be the case. But maybe let's leave it for a follow-up PR
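A sketch of that suggestion (class and attribute names are illustrative, not the repo's): validate the token id once at construction time instead of branching on it at every call.

```python
class ForcedEOSTokenLogitsProcessor:
    """Hypothetical sketch of the nit above: fail fast in __init__
    rather than silently skipping the forcing logic at call time."""
    def __init__(self, max_length, eos_token_id):
        if eos_token_id <= 0:
            raise ValueError(
                f"`eos_token_id` has to be a positive integer, got {eos_token_id}"
            )
        self.max_length = max_length
        self.eos_token_id = eos_token_id
```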
```python
# sets the score to 0 in the bos_token_id column
scores = tf.zeros((batch_size, 1))
# sets the score to -inf everywhere else
if self.bos_token_id > 0:
```
same comment as for EOS
```python
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys())
```
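The same introspection pattern, shown self-contained with a hypothetical stand-in model (the real code inspects the model's own `call`): `inspect.signature` reads the parameter names, so no attempt-and-catch is needed.

```python
import inspect

class DummyModel:
    # Hypothetical stand-in for a TF model whose `call` takes an attention mask.
    def call(self, input_ids, attention_mask=None):
        return input_ids

class DummyModelNoMask:
    def call(self, input_ids):
        return input_ids

def accepts_attention_mask(model):
    # Check the call signature to decide whether a mask can be passed.
    return "attention_mask" in set(inspect.signature(model.call).parameters.keys())
```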
nice!
Great work, good to merge for me!
What does this PR do?
As discussed in the original TF generate refactor plan (#15562), this PR adds `beam_search`.

This Beam Search implementation was inspired by our FLAX implementation, which is XLA-friendly. However, this PR is not yet XLA-ready (😭). To pass the existing tests, a few tweaks were added on top of the FLAX adaptation -- I added some comments in the PR to explain the differences (and why they were needed), hopefully making the review process easier.
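For readers unfamiliar with the Flax-style design the PR follows, here is a heavily simplified numpy sketch of beam search with fixed-shape state (my own toy version: no EOS handling, caching, or length penalty). Keeping every tensor at a static shape across steps is the property that makes this style XLA-friendly.

```python
import numpy as np

def toy_beam_search(log_probs_fn, bos_id, num_beams, max_len, vocab_size):
    """Minimal fixed-shape beam search sketch (NOT this PR's implementation).
    Each step expands every beam by vocab_size candidates and keeps the
    top num_beams overall; all arrays keep static shapes throughout."""
    sequences = np.full((num_beams, max_len), bos_id, dtype=np.int64)
    # Only beam 0 is "live" at the start, so identical beams don't tie.
    scores = np.array([0.0] + [-np.inf] * (num_beams - 1))
    for step in range(1, max_len):
        log_probs = log_probs_fn(sequences[:, :step])    # (num_beams, vocab)
        candidates = scores[:, None] + log_probs         # (num_beams, vocab)
        flat = candidates.reshape(-1)
        top = np.argsort(flat)[::-1][:num_beams]         # best num_beams overall
        beam_idx, token_idx = top // vocab_size, top % vocab_size
        sequences = sequences[beam_idx]                  # reorder surviving beams
        sequences[:, step] = token_idx
        scores = flat[top]
    return sequences, scores
```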
Tests run (and passing):