瀏覽代碼

CUDA: fix bug in topk-moe softmax (#16711)

Aman Gupta 3 月之前
父節點
當前提交
9285325ce0
共有 1 個文件被更改,包括 1 次插入1 次删除
  1. 1 1
      ggml/src/ggml-cuda/topk-moe.cu

+ 1 - 1
ggml/src/ggml-cuda/topk-moe.cu

@@ -141,7 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         wt_sum              = warp_reduce_sum(wt_sum);
         const float inv_sum = 1.0f / wt_sum;
 
-        for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
+        for (int i = 0; i < experts_per_thread; i++) {
             output_weights[i] *= inv_sum;
         }
     }