ソースを参照

vulkan: vec dot matrix multiplication fix (#16151)

* vulkan: fix matrix multiplication index calculation for odd m/n and odd k in combination with batching

* add odd m/n + odd k test with batching
Ruben Ortlam 3 ヶ月 前
コミット
9073a73d82

+ 20 - 8
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

@@ -31,10 +31,22 @@
 #include "types.comp"
 #include "types.comp"
 
 
 #ifndef LOAD_VEC_A
 #ifndef LOAD_VEC_A
-#define LOAD_VEC_A 2
+#define LOAD_VEC_A 1
 #endif
 #endif
 #ifndef LOAD_VEC_B
 #ifndef LOAD_VEC_B
-#define LOAD_VEC_B 2
+#define LOAD_VEC_B 1
+#endif
+
+// Load 2 values at once without affecting index calculations through LOAD_VEC
+#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
+#define LOAD_VEC_BATCH_A 2
+#else
+#define LOAD_VEC_BATCH_A 1
+#endif
+#if !defined(ALIGNED)
+#define LOAD_VEC_BATCH_B 2
+#else
+#define LOAD_VEC_BATCH_B 1
 #endif
 #endif
 
 
 #if !defined(TO_FLOAT_TYPE)
 #if !defined(TO_FLOAT_TYPE)
@@ -236,13 +248,13 @@ void main() {
     const uint warp_r = warp_i % (BM / WM);
     const uint warp_r = warp_i % (BM / WM);
     const uint warp_c = warp_i / (BM / WM);
     const uint warp_c = warp_i / (BM / WM);
 
 
-    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
-    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
-    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
-    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
+    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
+    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
+    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
+    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
 
 
-    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
-    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
+    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
+    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
 
 
 #ifdef MUL_MAT_ID
 #ifdef MUL_MAT_ID
 #ifdef MUL_MAT_ID_USE_SUBGROUPS
 #ifdef MUL_MAT_ID_USE_SUBGROUPS

+ 9 - 9
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp

@@ -14,8 +14,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
             FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
             buf_a[buf_idx    ] = aa.xy;
             buf_a[buf_idx    ] = aa.xy;
             buf_a[buf_idx + 1] = aa.zw;
             buf_a[buf_idx + 1] = aa.zw;
-#else // LOAD_VEC_A == 2
-            const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
+#else // LOAD_VEC_BATCH_A == 2
+            const uint idx = pos_a + col * p.stride_a + row * 2;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             if (idx_m < p.M && block + row * 2 + 1 < end_k) {
             if (idx_m < p.M && block + row * 2 + 1 < end_k) {
                 buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
                 buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
@@ -33,8 +33,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
             FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
             buf_a[buf_idx    ] = aa.xy;
             buf_a[buf_idx    ] = aa.xy;
             buf_a[buf_idx + 1] = aa.zw;
             buf_a[buf_idx + 1] = aa.zw;
-#else // LOAD_VEC_A == 2
-            const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
+#else // LOAD_VEC_BATCH_A == 2
+            const uint idx = pos_a + col * p.stride_a + row * 2;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             if (idx_m < p.M && block + row * 2 + 1 < end_k) {
             if (idx_m < p.M && block + row * 2 + 1 < end_k) {
                 buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
                 buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
@@ -500,8 +500,8 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
 #endif
 #endif
             buf_b[buf_idx + 0] = bb.xy;
             buf_b[buf_idx + 0] = bb.xy;
             buf_b[buf_idx + 1] = bb.zw;
             buf_b[buf_idx + 1] = bb.zw;
-#else // LOAD_VEC_B == 2
-            const uint idx = pos_b * 2 + col * p.stride_b + row * 2;
+#else // LOAD_VEC_BATCH_B == 2
+            const uint idx = pos_b + col * p.stride_b + row * 2;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             if (idx_n < p.N && block + row * 2 + 1 < end_k) {
             if (idx_n < p.N && block + row * 2 + 1 < end_k) {
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
@@ -536,17 +536,17 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
 #endif
 #endif
             buf_b[buf_idx + 0] = bb.xy;
             buf_b[buf_idx + 0] = bb.xy;
             buf_b[buf_idx + 1] = bb.zw;
             buf_b[buf_idx + 1] = bb.zw;
-#else // LOAD_VEC_B == 2
+#else // LOAD_VEC_BATCH_B == 2
             const uint row_i = ic * BN + col;
             const uint row_i = ic * BN + col;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
             if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
                 const u16vec2 row_idx = row_ids[col];
                 const u16vec2 row_idx = row_ids[col];
-                const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
+                const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
                                                  TO_FLOAT_TYPE(data_b[idx + 1]));
                                                  TO_FLOAT_TYPE(data_b[idx + 1]));
             } else if (row_i < _ne1 && block + row * 2 < end_k) {
             } else if (row_i < _ne1 && block + row * 2 < end_k) {
                 const u16vec2 row_idx = row_ids[col];
                 const u16vec2 row_idx = row_ids[col];
-                const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
+                const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
             } else {
             } else {
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);

+ 1 - 1
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -454,7 +454,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
 
 
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         // For unaligned, load one at a time for f32/f16, or two at a time for quants
         // For unaligned, load one at a time for f32/f16, or two at a time for quants
-        std::string load_vec_a_unaligned = coopmat2 ? "1" : (tname == "f32" || tname == "f16" || tname == "bf16") ? "2" : load_vec_quant;
+        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
         // For aligned matmul loads
         // For aligned matmul loads
         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
 
 

+ 1 - 0
tests/test-backend-ops.cpp

@@ -6231,6 +6231,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67,  {1,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67,  {1,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1,  1}, {1, 1}, {0, 1, 2, 3}, true, 3));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1,  1}, {1, 1}, {0, 1, 2, 3}, true, 3));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
 
 
     for (auto bs2 : {1,3}) {
     for (auto bs2 : {1,3}) {
         for (auto bs : {1,2,4,8}) {
         for (auto bs : {1,2,4,8}) {