Просмотр исходного кода

metal : use F32 prec for K*Q in vec FA (#9595)

ggml-ci
Georgi Gerganov 1 год назад
Родитель
Сommit
bf9c1013ac
1 измененных файлов с 7 добавлено и 7 удалено
  1. 7 7
      ggml/src/ggml-metal.metal

+ 7 - 7
ggml/src/ggml-metal.metal

@@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
         const short iv3 = iq3 / rv3;
 
         // load the queries from shared memory into local memory
-        half4 mq[D4];
+        float4 mq[D4];
 
         for (short ii = 0; ii < D4; ii += NW) {
             short i = ii + tiisg;
-            mq[i] = sq4[i];
+            mq[i] = (float4) sq4[i];
         }
 
         // pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
                     for (short ii = 0; ii < D4; ii += NW) {
                         const short i = ii + tiisg;
 
-                        half4x4 mk;
-                        mk[0] = pk4[i + 0*(nb11/8)];
-                        mk[1] = pk4[i + 1*(nb11/8)];
-                        mk[2] = pk4[i + 2*(nb11/8)];
-                        mk[3] = pk4[i + 3*(nb11/8)];
+                        float4x4 mk;
+                        mk[0] = (float4) pk4[i + 0*(nb11/8)];
+                        mk[1] = (float4) pk4[i + 1*(nb11/8)];
+                        mk[2] = (float4) pk4[i + 2*(nb11/8)];
+                        mk[3] = (float4) pk4[i + 3*(nb11/8)];
 
                         mqk += (float4) (mq[i] * mk);
                     }