Jelajahi Sumber

opencl: fix FA for f32 (#16584)

lhez 3 bulan lalu
induk
melakukan
d93f8439b0
1 mengubah file dengan 4 tambahan dan 3 penghapusan
  1. 4 3
      ggml/src/ggml-opencl/kernels/flash_attn_f32.cl

+ 4 - 3
ggml/src/ggml-opencl/kernels/flash_attn_f32.cl

@@ -4,6 +4,7 @@
 #define ACC_TYPE4 float4
 #define DATA_TYPE float
 #define DATA_TYPE4 float4
+#define MASK_DATA_TYPE half
 #define CONVERT_ACC4(x) (x)
 #define CONVERT_DATA4(x) (x)
 
@@ -148,7 +149,7 @@ __kernel void flash_attn_f32(
             if (k_row1 >= n_kv) score1 = -INFINITY;
 
             if (mask_base != NULL) {
-                const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
+                const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
                 if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
                 if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
             }
@@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1(
         }
         ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
         if (mask_base != NULL) {
-            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
             score += slope * (ACC_TYPE)mask_ptr[k_idx];
         }
         if (logit_softcap > 0.0f) {
@@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1(
         }
         ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
         if (mask_base != NULL) {
-            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
             score += slope * (ACC_TYPE)mask_ptr[k_idx];
         }
         if (logit_softcap > 0.0f) {