|
|
@@ -955,22 +955,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
(K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
|
|
|
}
|
|
|
|
|
|
- for (; kb0 < kb0_stop-1; ++kb0) {
|
|
|
- constexpr bool last_iter = false;
|
|
|
- constexpr bool oob_check = false;
|
|
|
- constexpr int k_VKQ_sup = nbatch_fa;
|
|
|
- flash_attn_ext_f16_iter
|
|
|
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
|
|
- T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
|
- (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
|
- ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
|
- KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
|
|
|
- }
|
|
|
// kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
|
|
if constexpr (ncols2 == 1) {
|
|
|
- if (ne11 % nbatch_fa == 0) {
|
|
|
- constexpr bool last_iter = true;
|
|
|
- constexpr bool oob_check = false;
|
|
|
+ constexpr bool oob_check = true;
|
|
|
+ for (; kb0 < kb0_stop-1; ++kb0) {
|
|
|
+ constexpr bool last_iter = false;
|
|
|
constexpr int k_VKQ_sup = nbatch_fa;
|
|
|
flash_attn_ext_f16_iter
|
|
|
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
|
|
@@ -978,10 +967,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
|
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
|
KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
|
|
|
- } else {
|
|
|
- constexpr bool last_iter = true;
|
|
|
- constexpr bool oob_check = true;
|
|
|
- const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
|
|
|
+ }
|
|
|
+ constexpr bool last_iter = true;
|
|
|
+ const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
|
|
|
+ flash_attn_ext_f16_iter
|
|
|
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
|
|
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
|
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
|
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
|
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
|
|
|
+ } else {
|
|
|
+ constexpr bool oob_check = false;
|
|
|
+ for (; kb0 < kb0_stop-1; ++kb0) {
|
|
|
+ constexpr bool last_iter = false;
|
|
|
+ constexpr int k_VKQ_sup = nbatch_fa;
|
|
|
flash_attn_ext_f16_iter
|
|
|
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
|
|
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
|
@@ -989,9 +988,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
|
KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
|
|
|
}
|
|
|
- } else {
|
|
|
constexpr bool last_iter = true;
|
|
|
- constexpr bool oob_check = false;
|
|
|
constexpr int k_VKQ_sup = nbatch_fa;
|
|
|
flash_attn_ext_f16_iter
|
|
|
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|