Confusion about past_key_values and attention_mask in GPT2Attention #16811
Comments
Great question @wiio12! You're exactly right.
Thank you for your response @patrickvonplaten, very clear! Now I am sure how it works. I wonder whether this constraint is mentioned anywhere in the documentation; otherwise, a user may hit a dimension-mismatch error without knowing why it happens.
Think it'd be a good idea to document this somewhere! Would you like to add a sentence to the documentation?
Not sure I did it correctly, but I made the change. Correct me if I did it wrong :)
Looks great! |
Environment info

`transformers` version: 4.12.5
Information
When I read through the code in `modeling_gpt2`, I got confused about how the `attention_mask` is used. In `modeling_gpt2.GPT2Attention`'s `forward` method, the code concatenates the cached past key and value onto the current hidden state's key and value. However, later in the `self._attn` method, when an `attention_mask` is passed, the code simply adds the `attention_mask` to the attention weights. The
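The code snippets referred to here did not survive extraction. As a rough, runnable stand-in (using NumPy in place of PyTorch; the names and shapes are my assumptions, not the verbatim library source), the two operations behave like this:

```python
import numpy as np

batch, n_head, head_dim = 2, 4, 8
past_len, query_len = 5, 3

# forward(): cached key/value tensors from previous steps are concatenated
# with the current step's key/value along the sequence axis.
past_key = np.zeros((batch, n_head, past_len, head_dim))
key = np.zeros((batch, n_head, query_len, head_dim))
key = np.concatenate((past_key, key), axis=-2)
print(key.shape)  # (2, 4, 8, 8): key_length = past_len + query_len

# _attn(): the additive attention mask (0.0 to keep a position, a large
# negative value to mask it out) is added directly to the raw attention
# scores, relying on broadcasting over the head and query dimensions.
key_len = past_len + query_len
attn_weights = np.zeros((batch, n_head, query_len, key_len))
attention_mask = np.zeros((batch, 1, 1, key_len))  # last dim must be key_len
attn_weights = attn_weights + attention_mask
print(attn_weights.shape)  # (2, 4, 3, 8)
```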
`attn_weights` tensor has the shape `[batch, n_head, query_length, key_length]`, while the `attention_mask` here has the shape `[batch, 1, 1, seq_length]`. Does this imply that the input attention mask's `seq_length` must match the full context length `key_length` rather than `query_length`? In other words, when we use `past_key_values`, must the `attention_mask` cover the tokens from both `past_key_values` and `input_ids`, rather than only the tokens from `input_ids`?