Confusion about past_key_values and attention_mask in GPT2Attention #16811

Closed
wiio12 opened this issue Apr 17, 2022 · 5 comments · Fixed by #16829

wiio12 (Contributor) commented Apr 17, 2022

Environment info

  • transformers version: 4.12.5

Models:

Information

When reading through the code in modeling_gpt2, I got confused about how attention_mask is used. The code concatenates the past key and value with the key and value of the current hidden_states. Here is the relevant part of GPT2Attention's forward method:

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
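
For concreteness, a quick shape check of that concatenation (the dimensions are made up, loosely matching GPT-2 small's 12 heads of size 64):

    import torch

    batch, n_head, head_dim = 2, 12, 64
    past_length, new_length = 5, 1  # cache of 5 tokens, decoding 1 new token

    past_key = torch.randn(batch, n_head, past_length, head_dim)
    key = torch.randn(batch, n_head, new_length, head_dim)

    key = torch.cat((past_key, key), dim=-2)
    print(key.shape)  # torch.Size([2, 12, 6, 64]) -> key_length = past_length + new_length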

However, later on, when an attention_mask is passed, the code adds it directly to the attention weights. Here is the relevant part of the _attn method:

def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)

attn_weights has shape [batch, n_head, query_length, key_length], while attention_mask has shape [batch, 1, 1, seq_length]. Does this imply that the input attention mask's seq_length must match the full context length key_length rather than query_length? In other words, when we use past_key_values, must the attention_mask cover the tokens from past_key_values as well as those from input_ids, instead of only the tokens from input_ids?
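
For example, with made-up numbers, the addition only broadcasts when the mask's last dimension equals key_length:

    import torch

    batch, n_head = 2, 12
    past_length, new_length = 5, 3
    query_length, key_length = new_length, past_length + new_length

    attn_weights = torch.randn(batch, n_head, query_length, key_length)

    # The real extended mask holds 0.0 for visible positions and a large negative
    # value for masked ones; zeros are enough for a shape check.
    full_mask = torch.zeros(batch, 1, 1, key_length)      # spans past + current tokens
    short_mask = torch.zeros(batch, 1, 1, query_length)   # spans only the new tokens

    attn_weights + full_mask    # broadcasts over n_head and query_length
    # attn_weights + short_mask # would raise RuntimeError: sizes 8 and 3 don't match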

patrickvonplaten (Contributor):

Great question @wiio12!

You're exactly right: attention_mask needs to contain the masking strategy that was used for past_key_values. In other words, attention_mask always has to have length len(past_key_values) + len(input_ids).
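
For example, when decoding step by step with the cache, you'd extend the mask along with it (rough, illustrative sketch):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    inputs = tokenizer("Hello, my dog is", return_tensors="pt")
    out = model(**inputs, use_cache=True)

    # Only the newly chosen token is fed as input_ids ...
    next_token = out.logits[:, -1:].argmax(dim=-1)

    # ... but attention_mask is extended so that it covers
    # len(past_key_values) + len(input_ids) positions
    attention_mask = torch.cat(
        [inputs["attention_mask"], torch.ones_like(next_token)], dim=-1
    )

    out = model(
        input_ids=next_token,
        attention_mask=attention_mask,
        past_key_values=out.past_key_values,
        use_cache=True,
    )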

wiio12 (Contributor, Author) commented Apr 19, 2022

Thank you for your response @patrickvonplaten, very clear! Now I understand how past_key_values and attention_mask work together.

I wonder whether this constraint is mentioned anywhere in the documentation; otherwise, users may hit a dimension-mismatch error without knowing why it happens.

patrickvonplaten (Contributor) commented Apr 19, 2022

Think it'd be a good idea to document this somewhere! Would you like to add a sentence to the documentation of the attention_mask parameter in GPT2?

wiio12 (Contributor, Author) commented Apr 19, 2022

Not sure I did it correctly, but I changed the docstring in modeling_gpt2 and modeling_tf_gpt2 and opened PR #16829.

Correct me if I did it wrong :)

patrickvonplaten (Contributor):

Looks great!
