Quellcode durchsuchen

ggml : use WARP_SIZE/2 for argmax reduction offset (#18092)

Aadeshveer Singh vor 1 Monat
Ursprung
Commit
58062860af
1 geänderte Dateien mit 2 neuen und 2 gelöschten Zeilen
  1. 2 2
      ggml/src/ggml-cuda/argmax.cu

+ 2 - 2
ggml/src/ggml-cuda/argmax.cu

@@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
     }
 
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
+    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
         const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
         const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
         if (val > maxval) {
@@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
                 argmax = shared_argmax[lane_id];
             }
 #pragma unroll
-            for (int offset = 16; offset > 0; offset >>= 1) {
+            for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
                 const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
                 const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
                 if (val > maxval) {