From 46e1f4a073b1c1eb9a22e712ad4535bde923e36f Mon Sep 17 00:00:00 2001 From: bhsueh Date: Mon, 5 Dec 2022 22:29:45 -0800 Subject: [PATCH] fix: fix fmha kernel assert bug --- .../fused_multihead_attention_v2.h | 4 ++-- 3rdparty/trt_fused_multihead_attention/qkvToContext.cu | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/3rdparty/trt_fused_multihead_attention/fused_multihead_attention_v2.h b/3rdparty/trt_fused_multihead_attention/fused_multihead_attention_v2.h index f04e74ade..dbe76e5d3 100644 --- a/3rdparty/trt_fused_multihead_attention/fused_multihead_attention_v2.h +++ b/3rdparty/trt_fused_multihead_attention/fused_multihead_attention_v2.h @@ -5390,8 +5390,8 @@ class FusedMultiHeadAttentionXMMAKernelV2: virtual uint64_t hashID(const KernelMeta& kernelMeta) const { - assert(kernelMeta.mD == 64 || kernelMeta.mD == 32 || kernel.mD == 40 || kernel.mD == 80 || kernel.mD == 128 - || kernel.mD == 160 || kernel.mD == 256); + assert(kernelMeta.mD == 64 || kernelMeta.mD == 32 || kernelMeta.mD == 40 || kernelMeta.mD == 80 || kernelMeta.mD == 128 + || kernelMeta.mD == 160 || kernelMeta.mD == 256); return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, diff --git a/3rdparty/trt_fused_multihead_attention/qkvToContext.cu b/3rdparty/trt_fused_multihead_attention/qkvToContext.cu index bce5af3e3..54a07c022 100644 --- a/3rdparty/trt_fused_multihead_attention/qkvToContext.cu +++ b/3rdparty/trt_fused_multihead_attention/qkvToContext.cu @@ -94,7 +94,9 @@ public: warps_n = 8; } else { - assert(false && "Unsupporte seqlen"); + // S >= 512, flash attention + warps_m = 4; + warps_n = 1; } } else { @@ -111,7 +113,9 @@ public: warps_n = 8; } else { - assert(false && "Unsupporte seqlen"); + // S >= 512, flash attention + warps_m = 4; + warps_n = 1; } } // The number of threads per CTA.