Просмотр исходного кода

Vulkan Optimizations and Fixes (#8959)

* Optimize Vulkan REPEAT performance

* Use Vulkan GLSL fused multiply-add instruction where possible

* Add GGML_VULKAN_PERF option to output performance data per operator

* Rework and fix Vulkan descriptor set and descriptor pool handling

* Fix float32 concat f16 shader validation error

* Add Vulkan GROUP_NORM eps parameter

* Fix validation error with transfer queue memory barrier flags

* Remove trailing whitespaces
0cc4m 1 год назад
Родитель
Сommit
5fd89a70ea

+ 4 - 0
Makefile

@@ -763,6 +763,10 @@ ifdef GGML_VULKAN_MEMORY_DEBUG
 	MK_CPPFLAGS  += -DGGML_VULKAN_MEMORY_DEBUG
 endif
 
+ifdef GGML_VULKAN_PERF
+	MK_CPPFLAGS  += -DGGML_VULKAN_PERF
+endif
+
 ifdef GGML_VULKAN_VALIDATE
 	MK_CPPFLAGS  += -DGGML_VULKAN_VALIDATE
 endif

+ 1 - 0
ggml/CMakeLists.txt

@@ -135,6 +135,7 @@ option(GGML_VULKAN                          "ggml: use Vulkan"
 option(GGML_VULKAN_CHECK_RESULTS            "ggml: run Vulkan op checks"                      OFF)
 option(GGML_VULKAN_DEBUG                    "ggml: enable Vulkan debug output"                OFF)
 option(GGML_VULKAN_MEMORY_DEBUG             "ggml: enable Vulkan memory debug output"         OFF)
+option(GGML_VULKAN_PERF                     "ggml: enable Vulkan perf output"                 OFF)
 option(GGML_VULKAN_VALIDATE                 "ggml: enable Vulkan validation"                  OFF)
 option(GGML_VULKAN_RUN_TESTS                "ggml: run Vulkan tests"                          OFF)
 option(GGML_KOMPUTE                         "ggml: use Kompute"                               OFF)

+ 4 - 0
ggml/src/CMakeLists.txt

@@ -602,6 +602,10 @@ if (GGML_VULKAN)
             add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
         endif()
 
+        if (GGML_VULKAN_PERF)
+            add_compile_definitions(GGML_VULKAN_PERF)
+        endif()
+
         if (GGML_VULKAN_VALIDATE)
             add_compile_definitions(GGML_VULKAN_VALIDATE)
         endif()

Разница между файлами не показана из-за своего большого размера
+ 352 - 242
ggml/src/ggml-vulkan.cpp


+ 5 - 1
ggml/src/vulkan-shaders/concat.comp

@@ -30,6 +30,10 @@ void main() {
 #ifndef OPTIMIZATION_ERROR_WORKAROUND
     data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
 #else
-    data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx];
+    if (is_src0) {
+        data_d[p.d_offset + dst_idx] = data_a[src0_idx];
+    } else {
+        data_d[p.d_offset + dst_idx] = data_b[src1_idx];
+    }
 #endif
 }

+ 1 - 2
ggml/src/vulkan-shaders/mul_mat_vec.comp

@@ -39,8 +39,7 @@ void main() {
         vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
 
         // matrix multiplication
-        tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) +
-                    FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
+        tmp[tid] = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp[tid]));
     }
 
     // sum up partial sums and write back result

+ 1 - 1
ggml/src/vulkan-shaders/mul_mat_vec_nc.comp

@@ -53,7 +53,7 @@ void main() {
 
         const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
 
-        tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
+        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
     }
 
     // sum up partial sums and write back result

+ 1 - 1
ggml/src/vulkan-shaders/mul_mat_vec_p021.comp

@@ -52,7 +52,7 @@ void main() {
         // y is not transposed but permuted
         const uint iy = channel*nrows_y + row_y;
 
-        tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
+        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
     }
 
     // dst is not transposed and not permuted

+ 18 - 17
ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp

@@ -39,24 +39,25 @@ void main() {
         FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
         FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
         for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3);
-            sum2 += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF);
+            sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3), sum1))))))));
+            sum2 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF), sum2))))))));
         }
-        tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, sum1, fma(-dmin, sum2, tmp[tmp_idx]));
     }
 
     // sum up partial sums and write back result

+ 10 - 9
ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp

@@ -40,16 +40,17 @@ void main() {
 
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
         for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4));
+            sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
         }
-        tmp[16 * ix + tid] += d * sum;
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(d, sum, tmp[tmp_idx]);
     }
 
     // sum up partial sums and write back result

+ 24 - 21
ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp

@@ -67,17 +67,17 @@ void main() {
         const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66]  >> 4);
         const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67]  >> 4);
 
-        const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3);
-        const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7);
-        const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11);
-        const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15);
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx    ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx    ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7
-        );
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+        const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx]),      q4_0,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]),  q4_1,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]),  q4_2,  FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) *  q4_3)));
+        const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_4,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), q4_5,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), q4_6,  FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7)));
+        const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx]),      q4_8,  fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]),  q4_9,  fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]),  q4_10, FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) *  q4_11)));
+        const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_12, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), q4_13, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), q4_14, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15)));
+        const FLOAT_TYPE smin =
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx    ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx    ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), sc7,
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), sc7,
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6,     FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7)))))))))))))));
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
 #else
         const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
         const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
@@ -88,16 +88,19 @@ void main() {
         const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
         const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
 
-        const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) * q4_0  + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * q4_1);
-        const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2  + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
-        const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) * q4_4  + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * q4_5);
-        const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
-        );
-
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
+        const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * q4_1);
+        const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
+        const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * q4_5);
+        const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
+        const FLOAT_TYPE smin =
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx    ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx    ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
+          + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7)))))));
+
+        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) +
+                        sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f),
+                       fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx]));
 #endif
     }
 

+ 27 - 29
ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp

@@ -66,35 +66,33 @@ void main() {
         const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80]  >> 4);
         const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81]  >> 4);
 
-        const FLOAT_TYPE sx = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx     ]) * (q4_0 + (((data_a[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * (q4_1 + (((data_a[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sy = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * (q4_4 + (((data_a[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * (q4_5 + (((data_a[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sz = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y2_idx     ]) * (q4_8  + (((data_a[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * (q4_9  + (((data_a[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sw = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * (q4_12 + (((data_a[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * (q4_13 + (((data_a[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            (FLOAT_TYPE(data_b[b_offset + y1_idx]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17])) * sc2 + (FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49])) * sc3
-          + (FLOAT_TYPE(data_b[b_offset + y2_idx]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17])) * sc6 + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7
-        );
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+        const FLOAT_TYPE sx =
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), (q4_0 + (((data_a[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx +  1]), (q4_1 + (((data_a[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)),
+             FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE sy =
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)),
+             FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE sz =
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), (q4_8  + (((data_a[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx +  1]), (q4_9  + (((data_a[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)),
+             FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE sw =
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)),
+             FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE smin =
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2,
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3,
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6,
+              (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7)));
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
     }
 
     // sum up partial sums and write back result

+ 13 - 13
ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp

@@ -44,22 +44,22 @@ void main() {
         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
 
 #if K_QUANTS_PER_ITERATION == 1
-        FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
-        tmp[16 * ix + tid] += sum;
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx]))))))));
 #else
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
         [[unroll]] for (int l = 0; l < 4; ++l) {
-            sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
+            sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum))));
         }
         tmp[16 * ix + tid] += sum;
 #endif

+ 8 - 7
ggml/src/vulkan-shaders/mul_mm.comp

@@ -326,10 +326,10 @@ void main() {
                 mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4));
             }
             const float d = loadd.x * sc;
-            const float m = loadd.y * mbyte;
+            const float m = -loadd.y * mbyte;
 
-            buf_a[buf_idx    ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) - m);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);
+            buf_a[buf_idx    ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF), m));
+            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
 #elif defined(DATA_A_Q5_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
@@ -357,10 +357,10 @@ void main() {
                 mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4));
             }
             const float d = loadd.x * sc;
-            const float m = loadd.y * mbyte;
+            const float m = -loadd.y * mbyte;
 
-            buf_a[buf_idx    ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0)) - m);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);
+            buf_a[buf_idx    ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0), m));
+            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
 #elif defined(DATA_A_Q6_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
@@ -463,7 +463,8 @@ void main() {
                 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                     [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                         [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
+                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+                            sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
                         }
                     }
                 }

+ 24 - 0
ggml/src/vulkan-shaders/repeat.comp

@@ -0,0 +1,24 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+uint src0_idx_mod(uint idx) {
+    const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
+    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+    const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
+    const uint i12_offset = i12*p.ne11*p.ne10;
+    const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
+    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+    return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00;
+}
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]);
+}

+ 4 - 0
ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -380,6 +380,10 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
         string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
 
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    }));
+
     tasks.push_back(std::async(std::launch::async, [] {
         string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));

Некоторые файлы не были показаны из-за большого количества измененных файлов