vulkan: Use fp16 for the flash attention P*V multiplication (#12783)

This is consistent with the ggml-cuda behavior and the mul_mat fallback.
Jeff Bolz, 9 months ago
Parent commit 7ecd780b1a
1 changed file with 4 additions and 2 deletions

+ 4 - 2
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

@@ -330,9 +330,11 @@ void main() {
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
-        O = eMdiag * O;
+        // multiply with fp16 accumulation, then add to O.
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+        PV = coopMatMulAdd(P_A, V, PV);
 
-        O = coopMatMulAdd(P_A, V, O);
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
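For readers without a cooperative-matrix mental model, here is a minimal scalar sketch of the numerics this change implements: the P*V product is now accumulated entirely in fp16 and only widened to ACC_TYPE when it is folded into the rescaled running output O, matching the ggml-cuda path the commit message cites. Everything below is illustrative, not the shader itself: the tile sizes and values are made up, and _Float16 (a GCC/Clang extension) stands in for the shader's float16_t.

    // Scalar model of the updated flash-attention accumulation step.
    // P is the softmax tile, V the value tile, O the running output,
    // eM the per-row rescale factors from the online softmax.
    #include <cstdio>

    using half = _Float16;   // host-side stand-in for the shader's float16_t
    using ACC_TYPE = float;  // the shader's accumulator type (fp32 here)

    constexpr int Br = 2;  // rows of the P tile (query rows)
    constexpr int Bc = 2;  // columns of P / rows of V (key positions)
    constexpr int D  = 2;  // head dimension

    int main() {
        half P[Br][Bc]    = {{half(0.5f), half(0.25f)}, {half(0.125f), half(1.0f)}};
        half V[Bc][D]     = {{half(1.0f), half(2.0f)}, {half(3.0f), half(4.0f)}};
        ACC_TYPE O[Br][D] = {{1.0f, 1.0f}, {1.0f, 1.0f}};
        ACC_TYPE eM[Br]   = {0.5f, 0.25f};  // per-row rescale factors

        for (int i = 0; i < Br; ++i) {
            for (int j = 0; j < D; ++j) {
                // PV = coopMatMulAdd(P_A, V, PV): the product is accumulated
                // in fp16, as in the new shader path. The cast on every step
                // models fp16 rounding of the intermediate sums.
                half pv = half(0.0f);
                for (int k = 0; k < Bc; ++k)
                    pv = half(pv + P[i][k] * V[k][j]);
                // O = eMdiag * O + coopmat<ACC_TYPE, ...>(PV): rescale the
                // running output and add the fp16 product widened to ACC_TYPE.
                O[i][j] = eM[i] * O[i][j] + ACC_TYPE(pv);
            }
        }

        for (int i = 0; i < Br; ++i)
            printf("O[%d] = { %f, %f }\n", i, (double)O[i][0], (double)O[i][1]);
        return 0;
    }

Note the design point the diff preserves: only the P*V product moves to fp16; the rescale eMdiag * O and the final accumulation into O still happen in ACC_TYPE, so the running output keeps its original precision across iterations.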