Răsfoiți Sursa

CUDA: fix negative KV_max values in FA (#15321)

Johannes Gäßler 5 luni în urmă
părinte
comite
4227c9be42
1 a modificat fișierele cu 5 adăugiri și 1 ștergeri
  1. 5 1
      ggml/src/ggml-cuda/fattn-common.cuh

+ 5 - 1
ggml/src/ggml-cuda/fattn-common.cuh

@@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
         all_inf = warp_reduce_all(all_inf);
         all_inf = warp_reduce_all(all_inf);
 
 
         if (!all_inf) {
         if (!all_inf) {
-            KV_max_sj += FATTN_KQ_STRIDE;
             break;
             break;
         }
         }
     }
     }
 
 
+    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
+    // If the break was triggered it's the lower edge of the tile with the first non-masked values.
+    // In either case, walk back the decrementation by FATTN_KQ_STRIDE.
+    KV_max_sj += FATTN_KQ_STRIDE;
+
     if (threadIdx.x != 0) {
     if (threadIdx.x != 0) {
         return;
         return;
     }
     }