Procházet zdrojové kódy

CUDA: fix FTZ in FA for Gemma 3 (#13991)

Johannes Gäßler před 7 měsíci
rodič
revize
0b4be4c435
1 změnil soubory, kde provedl 4 přidání a 1 odebrání
  1. 4 1
      ggml/src/ggml-cuda/fattn-mma-f16.cuh

+ 4 - 1
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         float KQ_max_scale[cols_per_thread];
         float KQ_max_scale[cols_per_thread];
 #pragma unroll
 #pragma unroll
         for (int col = 0; col < cols_per_thread; ++col) {
         for (int col = 0; col < cols_per_thread; ++col) {
-            KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
+            KQ_max_scale[col] = expf(KQ_max_diff);
             KQ_max[col] = KQ_max_new[col];
             KQ_max[col] = KQ_max_new[col];
 
 
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
             // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
             // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
             KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
             KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
         }
         }