fix: fix fmha kernel assert bug
byshiue committed Dec 6, 2022
1 parent aa2ceb5 commit 46e1f4a
Showing 2 changed files with 8 additions and 4 deletions.
@@ -5390,8 +5390,8 @@ class FusedMultiHeadAttentionXMMAKernelV2:
 
     virtual uint64_t hashID(const KernelMeta& kernelMeta) const
     {
-        assert(kernelMeta.mD == 64 || kernelMeta.mD == 32 || kernel.mD == 40 || kernel.mD == 80 || kernel.mD == 128
-            || kernel.mD == 160 || kernel.mD == 256);
+        assert(kernelMeta.mD == 64 || kernelMeta.mD == 32 || kernelMeta.mD == 40 || kernelMeta.mD == 80 || kernelMeta.mD == 128
+            || kernelMeta.mD == 160 || kernelMeta.mD == 256);
         return hashID(kernelMeta.mS,
             kernelMeta.mD,
             kernelMeta.mInterleaved,
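The removed assert mixed two identifiers: it checked kernelMeta.mD for the first two head sizes but kernel.mD (not the parameter of hashID) for the rest; the fix uses kernelMeta.mD throughout. A minimal sketch of the same supported-head-dimension check, using a hypothetical helper that is not part of this repository:

    #include <cassert>

    // Hypothetical helper (not in the repository): restates the head sizes the
    // corrected assert accepts for kernelMeta.mD.
    static bool isSupportedHeadDim(int d)
    {
        return d == 32 || d == 40 || d == 64 || d == 80 || d == 128 || d == 160 || d == 256;
    }

    int main()
    {
        assert(isSupportedHeadDim(64));   // 64 is a supported head size
        assert(!isSupportedHeadDim(96));  // 96 is not in the list, so hashID would assert
        return 0;
    }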
3rdparty/trt_fused_multihead_attention/qkvToContext.cu (8 changes: 6 additions & 2 deletions)
@@ -94,7 +94,9 @@ public:
             warps_n = 8;
         }
         else {
-            assert(false && "Unsupporte seqlen");
+            // S >= 512, flash attention
+            warps_m = 4;
+            warps_n = 1;
         }
     }
     else {
@@ -111,7 +113,9 @@ public:
             warps_n = 8;
         }
         else {
-            assert(false && "Unsupporte seqlen");
+            // S >= 512, flash attention
+            warps_m = 4;
+            warps_n = 1;
         }
     }
     // The number of threads per CTA.
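Both hunks make the same change: instead of asserting on sequence lengths that fall outside the explicitly handled cases, the code now falls back to a flash-attention warp shape (warps_m = 4, warps_n = 1) for S >= 512. A rough, self-contained sketch of the resulting selection follows; only the fallback values and the warps_n = 8 branch come from the diff, the 512 cutoff comes from the added comment, and everything else (the function name, the warps_m value in the first branch) is an assumption:

    #include <cstdio>

    // Hedged sketch, not the repository code. Only the final else branch reproduces
    // what this commit adds; the earlier branch mirrors the diff context, and the
    // function name pickWarps is made up for illustration.
    static void pickWarps(int S, int& warps_m, int& warps_n)
    {
        if (S < 512) {
            warps_m = 1;  // assumption: warps_m is not visible in the diff context
            warps_n = 8;  // visible in the unchanged context lines above the change
        }
        else {
            // S >= 512, flash attention (replaces assert(false && "Unsupporte seqlen"))
            warps_m = 4;
            warps_n = 1;
        }
    }

    int main()
    {
        int warps_m = 0, warps_n = 0;
        pickWarps(1024, warps_m, warps_n);
        std::printf("warps_m=%d warps_n=%d\n", warps_m, warps_n);  // warps_m=4 warps_n=1
        return 0;
    }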
