|
|
@@ -532,7 +532,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
#pragma unroll
|
|
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
|
if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
|
|
- KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]);
|
|
|
+ KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -585,7 +585,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
|
if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
|
|
// Turing + Volta:
|
|
|
- KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]);
|
|
|
+ KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
|
|
|
}
|
|
|
}
|
|
|
}
|