Browse Source

ggml : use WARP_SIZE/2 for argmax reduction offset (#18092)

Aadeshveer Singh 1 tháng trước cách đây
mục cha
commit
58062860af
1 tập tin đã thay đổi với 2 bổ sung2 xóa
  1. 2 2
      ggml/src/ggml-cuda/argmax.cu

+ 2 - 2
ggml/src/ggml-cuda/argmax.cu

@@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
     }
 
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
+    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
         const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
         const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
         if (val > maxval) {
@@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
                 argmax = shared_argmax[lane_id];
             }
 #pragma unroll
-            for (int offset = 16; offset > 0; offset >>= 1) {
+            for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
                 const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
                 const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
                 if (val > maxval) {