Преглед изворни кода

cuda : fix bounds check for src0 rows in MMVQ kernel (whisper/2231)

* cuda : fix bounds check for src0 rows in MMVQ kernel

* Update ggml-cuda/mmvq.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Georgi Gerganov пре 1 година
родитељ
комит
19b7a836f6
1 измењених фајлова са 1 додато и 1 уклоњено
  1. 1 1
      ggml-cuda/mmvq.cu

+ 1 - 1
ggml-cuda/mmvq.cu

@@ -117,7 +117,7 @@ static __global__ void mul_mat_vec_q(
             tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }
 
-        if (threadIdx.x < rows_per_cuda_block) {
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
             dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }