Clear output of Torch SDPA for masked pieces #360

Merged

danieldk merged 6 commits into explosion:main from bugfix/sdpa-attention-nan on Feb 8, 2024

Conversation

@danieldk (Contributor) commented on Feb 8, 2024

Description

Since Torch 2.1, the Torch memory-efficient SDPA GPU kernel returns NaN for pieces that are completely masked out. This leads to NaN propagation in the next attention layer, because masked pieces get an attention of zero, but zero times NaN is still NaN.
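
For illustration, here is a minimal sketch of the failure mode. It is a hypothetical toy example, not code from this PR; the math backend on CPU produces the same NaN for fully masked rows (softmax over a row of `-inf`), so it reproduces without a GPU:

```python
import torch
import torch.nn.functional as F

# Hypothetical toy shapes: [batch, heads, seq_len, head_dim].
query = torch.rand(1, 1, 4, 8)
key = torch.rand(1, 1, 4, 8)
value = torch.rand(1, 1, 4, 8)

# Boolean mask, [batch, heads, query_len, key_len]; True means "attend".
# The last query row is fully masked, as happens for a padding piece.
mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
mask[..., 3, :] = False

out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
print(out[..., 3, :])  # NaN: softmax over an all-masked row is 0 / 0.

# Zero attention weights in the next layer do not neutralize the NaNs:
print(torch.tensor(0.0) * float("nan"))  # tensor(nan)
```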

This PR fixes the issue by setting the output for masked pieces to zero, clearing out any NaNs.

We currently rely on the query dimension of the mask being singular, but in the future we should probably redesign the `AttentionMask` class to account for the differences between attention masks and causal masks.
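
As a sketch of the approach (a hypothetical helper, not the library's actual code), assuming a boolean mask with a singular query dimension, the SDPA output can be cleared by moving the key dimension of the mask onto the sequence dimension:

```python
import torch

def clear_masked_output(out: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. Assumes:
    #   out:  SDPA output, [batch, heads, seq_len, head_dim].
    #   mask: boolean attention mask, [batch, 1, 1, key_len];
    #         True marks pieces that take part in attention.
    # Transpose the mask to [batch, 1, key_len, 1] so it broadcasts over
    # the sequence dimension of the output, then zero out masked pieces,
    # replacing any NaNs the kernel produced for them.
    return out.masked_fill(~mask.transpose(-1, -2), 0.0)
```

Since a zeroed piece stays finite under any downstream multiplication, the next attention layer no longer propagates NaNs.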

Checklist

  • I confirm that I have the right to submit this contribution under the project's MIT license.

@danieldk added the type/bug (Type: Bug) and feat/layers (Feature: Layers) labels on Feb 8, 2024
danieldk and others added 2 commits on February 8, 2024 at 20:02
@danieldk merged commit f9da3b5 into explosion:main on Feb 8, 2024
9 checks passed
@danieldk deleted the bugfix/sdpa-attention-nan branch on February 8, 2024