Просмотр исходного кода

CUDA: fix negative KV_max values in FA (#15321)

Johannes Gäßler 5 месяцев назад
Родитель
Сommit
4227c9be42
1 измененных файлов с 5 добавлено и 1 удалено
  1. 5 1
      ggml/src/ggml-cuda/fattn-common.cuh

+ 5 - 1
ggml/src/ggml-cuda/fattn-common.cuh

@@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
         all_inf = warp_reduce_all(all_inf);
 
         if (!all_inf) {
-            KV_max_sj += FATTN_KQ_STRIDE;
             break;
         }
     }
 
+    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
+    // If the break was triggered it's the lower edge of the tile with the first non-masked values.
+    // In either case, walk back the decrementation by FATTN_KQ_STRIDE.
+    KV_max_sj += FATTN_KQ_STRIDE;
+
     if (threadIdx.x != 0) {
         return;
     }