Sfoglia il codice sorgente

vulkan: use scalar FA rather than coopmat2 when N==1 (#13554)

Jeff Bolz 8 mesi fa
parent
commit
4f41ee11d6
1 ha cambiato i file con 7 aggiunte e 0 eliminazioni
  1. 7 0
      ggml/src/ggml-vulkan/ggml-vulkan.cpp

+ 7 - 0
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -5872,10 +5872,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_pipeline *pipelines;
     vk_pipeline *pipelines;
     bool small_rows = N <= get_fa_num_small_rows(path);
     bool small_rows = N <= get_fa_num_small_rows(path);
 
 
+    // coopmat1 does not actually support "small rows" (it needs 16 rows).
+    // So use scalar instead.
     if (small_rows && path == FA_COOPMAT1) {
     if (small_rows && path == FA_COOPMAT1) {
         path = FA_SCALAR;
         path = FA_SCALAR;
     }
     }
 
 
+    // scalar is faster than coopmat2 when N==1
+    if (N == 1 && path == FA_COOPMAT2) {
+        path = FA_SCALAR;
+    }
+
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
 
 
     switch (path) {
     switch (path) {