Просмотр исходного кода

CUDA: Remove unneded bias/gate dims in fused mmvq (#16858)

* CUDA: Remove unneded bias/gate dims in fused mmvq

Pointed out
[here](https://github.com/ggml-org/llama.cpp/pull/16847#discussion_r2476798989)
that only a single value is needed per target col per thread

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Fix "Error 991-D: extra braces are nonstandard" during compilation

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Oliver Simons 2 месяцев назад
Родитель
Сommit
d3dc9dd898
1 измененных файлов с 8 добавлено и 6 удалено
  1. 8 6
      ggml/src/ggml-cuda/mmvq.cu

+ 8 - 6
ggml/src/ggml-cuda/mmvq.cu

@@ -190,8 +190,8 @@ static __global__ void mul_mat_vec_q(
 
     const uint32_t channel_bias = ids ? channel_x : channel_dst;
 
-    float x_biases[ncols_dst][rows_per_cuda_block]    = { { 0.0f } };
-    float gate_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
+    float x_biases[ncols_dst]    = { 0.0f };
+    float gate_biases[ncols_dst] = { 0.0f };
     if constexpr (has_fusion) {
         if (use_bias) {
             x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
@@ -199,8 +199,9 @@ static __global__ void mul_mat_vec_q(
             // 2. load only on threads that won't die after partial sum calculation
             if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                 (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+#pragma unroll
                 for (int j = 0; j < ncols_dst; ++j) {
-                    x_biases[j][threadIdx.x] = x_bias[j * stride_col_dst + threadIdx.x];
+                    x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
                 }
             }
         }
@@ -208,8 +209,9 @@ static __global__ void mul_mat_vec_q(
             gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
             if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                 (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+#pragma unroll
                 for (int j = 0; j < ncols_dst; ++j) {
-                    gate_biases[j][threadIdx.x] = gate_bias[j * stride_col_dst + threadIdx.x];
+                    gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
                 }
             }
         }
@@ -299,12 +301,12 @@ static __global__ void mul_mat_vec_q(
             float result = tmp[j][threadIdx.x];
             if constexpr (has_fusion) {
                 if (use_bias) {
-                    result += x_biases[j][threadIdx.x];
+                    result += x_biases[j];
                 }
                 if (use_gate) {
                     float gate_value = tmp_gate[j][threadIdx.x];
                     if (use_gate_bias) {
-                        gate_value += gate_biases[j][threadIdx.x];
+                        gate_value += gate_biases[j];
                     }
                     switch (active_glu) {
                         case GGML_GLU_OP_SWIGLU: