CUDA: fix KQ max calculation (#18487)

Johannes Gäßler, 4 weeks ago
parent commit ecc343de63
1 changed file with 2 additions and 2 deletions:
    ggml/src/ggml-cuda/fattn-mma-f16.cuh (+2 -2)

ggml/src/ggml-cuda/fattn-mma-f16.cuh (+2 -2)

@@ -531,7 +531,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
-                if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
+                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
                     KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
@@ -583,7 +583,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
-                if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
+                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
                     // Turing + Volta:
                     KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
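
What the fix does: with np warps striding together over an nbatch_fa-row KQ tile, warp w = threadIdx.y % np holds rows k0 + w*T_C_KQ::I + i within each stride, but the old bounds check compared only k0 + T_C_KQ::get_i(l) against k_VKQ_sup. Every warp therefore applied warp 0's bound, so warps past the boundary fed out-of-range KQ values into the running max. The second hunk is the same fix for the j-indexed fragment layout used on Turing and Volta.

Below is a minimal standalone sketch of the indexing, not the actual kernel: FRAG_I stands in for T_C_KQ::I, kq_max_rows and the constants are hypothetical, and counting in-bounds rows stands in for the fmaxf reduction.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int np        = 4;  // warps cooperating on one tile (assumed)
constexpr int FRAG_I    = 8;  // rows per MMA fragment (stand-in for T_C_KQ::I)
constexpr int nbatch_fa = 64; // rows in the KQ tile (assumed)

__global__ void kq_max_rows(int k_VKQ_sup, int *counted) {
    const int warp = threadIdx.y % np;
    int n = 0;
    for (int k0 = 0; k0 < nbatch_fa; k0 += np*FRAG_I) {
        for (int i = 0; i < FRAG_I; ++i) {
            // Buggy check: k0 + i < k_VKQ_sup -- the same for every warp,
            // although warp w really holds row k0 + warp*FRAG_I + i.
            // Fixed check, as in the commit, includes the per-warp offset:
            if (k0 + warp*FRAG_I + i < k_VKQ_sup) {
                ++n; // this row may participate in the running KQ max
            }
        }
    }
    if (threadIdx.x == 0) counted[warp] = n;
}

int main() {
    int *counted;
    cudaMallocManaged(&counted, np*sizeof(int));
    // 44 valid rows out of 64: with the fix the per-warp counts sum to 44;
    // with the buggy check every warp would count 16 (sum 64), i.e. rows
    // past the bound would leak into the KQ max.
    kq_max_rows<<<1, dim3(32, np)>>>(44, counted);
    cudaDeviceSynchronize();
    for (int w = 0; w < np; ++w)
        printf("warp %d: %d in-bounds rows\n", w, counted[w]);
    cudaFree(counted);
    return 0;
}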