Просмотр исходного кода

metal : refactor + optimize (#15857)

* metal : refactor

ggml-ci

* cont : refactor FA-vec kernel

* cont : print metal library load time

* minor : warn to debug + bettern kernel names

ggml-ci

* metal : optimize mul_mv q8_0

ggml-ci

* metal : simplify FA pipeline creation functions

ggml-ci

* metal : improve naming consistency

* metal : safer function constants offsets

ggml-ci

* metal : comments

ggml-ci
Georgi Gerganov 4 месяцев назад
Родитель
Сommit
f28d4f4ac9

+ 44 - 4
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -20,8 +20,8 @@
 #define N_R0_Q5_1 4
 #define N_SG_Q5_1 2
 
-#define N_R0_Q8_0 4
-#define N_SG_Q8_0 2
+#define N_R0_Q8_0 2
+#define N_SG_Q8_0 4
 
 #define N_R0_MXFP4 2
 #define N_SG_MXFP4 2
@@ -68,6 +68,11 @@
 #define N_R0_IQ4_XS 2
 #define N_SG_IQ4_XS 2
 
+// function constants offsets
+#define FC_FLASH_ATTN_EXT              100
+#define FC_FLASH_ATTN_EXT_VEC          200
+#define FC_FLASH_ATTN_EXT_VEC_REDUCE   300
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -236,9 +241,11 @@ typedef struct {
     int32_t  ne11;
     int32_t  ne_12_2; // assume K and V are same shape
     int32_t  ne_12_3;
+    int32_t  ns10;
     uint64_t nb11;
     uint64_t nb12;
     uint64_t nb13;
+    int32_t  ns20;
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
@@ -258,10 +265,43 @@ typedef struct {
     float    logit_softcap;
 } ggml_metal_kargs_flash_attn_ext;
 
+typedef struct {
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne_12_2; // assume K and V are same shape
+    int32_t  ne_12_3;
+    int32_t  ns10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ns20;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
+    int32_t  ne32;
+    int32_t  ne33;
+    uint64_t nb31;
+    uint64_t nb32;
+    uint64_t nb33;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    int32_t  n_head_log2;
+    float    logit_softcap;
+} ggml_metal_kargs_flash_attn_ext_vec;
+
 typedef struct {
     int32_t  nrows;
-    int32_t  ne20;
-} ggml_metal_kargs_flash_attn_ext_reduce;
+} ggml_metal_kargs_flash_attn_ext_vec_reduce;
 
 typedef struct {
     int32_t  ne00;

Разница между файлами не показана из-за своего большого размера
+ 260 - 625
ggml/src/ggml-metal/ggml-metal.m


Разница между файлами не показана из-за своего большого размера
+ 622 - 417
ggml/src/ggml-metal/ggml-metal.metal


+ 2 - 2
tests/test-backend-ops.cpp

@@ -6501,8 +6501,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
     }
 
-    for (int hsk : { 40, 64, 80, 128, 192, 256, 576 }) {
-        for (int hsv : { 40, 64, 80, 128, 192, 256, 512 }) {
+    for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
+        for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
             if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
             if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
             if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA

Некоторые файлы не были показаны из-за большого количества измененных файлов