Просмотр исходного кода

Vulkan: Fix mmq int dot float cache size (#12722)

0cc4m 9 месяцев назад
Родитель
Сommit
92e3006bb6

+ 2 - 4
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

@@ -234,9 +234,9 @@ void main() {
 #endif
 
 #if QUANT_AUXF == 1
-    FLOAT_TYPE cache_a_dm[TM];
+    FLOAT_TYPE cache_a_dm[WMITER * TM];
 #else
-    FLOAT_TYPE_VEC2 cache_a_dm[TM];
+    FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
 #endif
 
     FLOAT_TYPE_VEC2 cache_b_ds[TN];
@@ -247,7 +247,6 @@ void main() {
             const uint iqs = loadr_a;
             const uint buf_ib = loadc_a + l;
 
-            // Should ds be gated to a single thread?
             if (iqs == 0) {
 #if QUANT_AUXF == 1
                 buf_a_dm[buf_ib] = get_d(ib);
@@ -276,7 +275,6 @@ void main() {
 
             const uint buf_ib = loadc_b + l;
 
-            // Should ds be gated to a single thread?
             if (iqs == 0) {
                 buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
             }

+ 2 - 2
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp

@@ -17,7 +17,7 @@ i32vec2 repack(uint ib, uint iqs) {
 }
 
 ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
-    return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0 * dsb.y));
+    return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y));
 }
 #endif
 
@@ -51,7 +51,7 @@ i32vec2 repack(uint ib, uint iqs) {
 }
 
 ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
-    return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0 * dsb.y));
+    return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y));
 }
 #endif