فهرست منبع

vulkan: Reduce temporary memory usage for TOP_K (#17623)

- Compute row size for the temp buffer based on the output of the first pass.
- Update shader addressing math to use the output row size
- Pass the output row size as "ncols_output", what used to be "ncols_output" is now "k"

For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer
from about 3.2MB to 500KB.
Jeff Bolz 1 ماه پیش
والد
کامیت
61bde8e21f

+ 28 - 11
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants {
     uint32_t orig_ncols;
     uint32_t ncols_input;
     uint32_t ncols_output;
+    uint32_t k;
     uint32_t nrows;
     uint32_t first_pass;
     uint32_t last_pass;
@@ -1673,6 +1674,14 @@ class vk_perf_logger {
             timings[name.str()].push_back(time);
             return;
         }
+        if (node->op == GGML_OP_TOP_K) {
+            std::stringstream name;
+            name << ggml_op_name(node->op) <<
+                " K=" << node->ne[0] <<
+                " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
+            timings[name.str()].push_back(time);
+            return;
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
   private:
@@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     uint32_t nrows = ggml_nrows(src0);
     uint32_t k = dst->ne[0];
 
-    vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
+    vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };
 
-    // Reserve space for ivec2 per element, double buffered
-    const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
-    const size_t x_sz = dbl_buf_size * 2;
-    uint32_t dbl_buf_index = 0;
-
-    if (ctx->prealloc_size_x < x_sz) {
-        ctx->prealloc_size_x = x_sz;
-        ggml_vk_preallocate_buffers(ctx, subctx);
-    }
     if (ctx->prealloc_x_need_sync) {
         ggml_vk_sync_buffers(ctx, subctx);
     }
@@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     // largest elements. Repeat until we have the top K elements.
     // Need to do at least one iteration to write out the results.
     bool done_one_iter = false;
+    uint32_t dbl_buf_index = 0;
+    size_t dbl_buf_size;
     while (num_elements > k || !done_one_iter) {
-        done_one_iter = true;
 
         // Prefer going as small as num_topk_pipelines - 3 for perf reasons.
         // But if K is larger, then we need a larger workgroup
@@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         // Number of elements remaining after this pass
         uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
 
+        pc2.ncols_output = num_dst_elements;
+
+        if (!done_one_iter) {
+            // Reserve space for ivec2 per element, double buffered
+            // K per workgroup per row
+            dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
+            dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+            const size_t x_sz = dbl_buf_size * 2;
+
+            if (ctx->prealloc_size_x < x_sz) {
+                ctx->prealloc_size_x = x_sz;
+                ggml_vk_preallocate_buffers(ctx, subctx);
+            }
+        }
+
         vk_subbuffer src_buf;
         vk_subbuffer dst_buf;
 
@@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         if (num_elements > k) {
             ggml_vk_sync_buffers(ctx, subctx);
         }
+        done_one_iter = true;
     }
     ctx->prealloc_x_need_sync = true;
 }

+ 12 - 7
ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp

@@ -19,6 +19,7 @@ layout (push_constant) uniform parameter {
     uint orig_ncols;
     uint ncols_input;
     uint ncols_output;
+    uint k;
     uint nrows;
     uint first_pass;
     uint last_pass;
@@ -36,7 +37,7 @@ void topk(bool needs_bounds_check, const uint row) {
             const uint row_offset = row * p.ncols_input;
             dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
         } else {
-            const uint row_offset = row * p.orig_ncols;
+            const uint row_offset = row * p.ncols_input;
             dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
         }
     } else {
@@ -44,7 +45,7 @@ void topk(bool needs_bounds_check, const uint row) {
     }
     barrier();
 
-    if (p.ncols_output == 1) {
+    if (p.k == 1) {
         // Fast path for single output - just do a max reduction
         [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
             if (col < s) {
@@ -84,13 +85,17 @@ void topk(bool needs_bounds_check, const uint row) {
         }
     }
 
-    if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
+    if (col < p.k) {
         if (p.last_pass != 0) {
-            const uint row_offset = row * p.ncols_output;
-            data_d[row_offset + col] = dst_row[col].x;
+            if (gl_GlobalInvocationID.x < p.ncols_input) {
+                const uint row_offset = row * p.k;
+                data_d[row_offset + col] = dst_row[col].x;
+            }
         } else {
-            const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
-            data_t[row_offset + col] = dst_row[col];
+            if (gl_WorkGroupID.x * p.k + col < p.ncols_output) {
+                const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
+                data_t[row_offset + col] = dst_row[col];
+            }
         }
     }
 }

+ 14 - 9
ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp

@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
     uint orig_ncols;
     uint ncols_input;
     uint ncols_output;
+    uint k;
     uint nrows;
     uint first_pass;
     uint last_pass;
@@ -60,7 +61,7 @@ void topk(const uint row) {
             const uint row_offset = row * p.ncols_input;
             dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
         } else {
-            const uint row_offset = row * p.orig_ncols;
+            const uint row_offset = row * p.ncols_input;
             dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
         }
     } else {
@@ -68,7 +69,7 @@ void topk(const uint row) {
     }
     barrier();
 
-    if (p.ncols_output == 1) {
+    if (p.k == 1) {
         // Fast path for single output - just do a max reduction
         [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
             if (tid < s) {
@@ -98,7 +99,7 @@ void topk(const uint row) {
         uint range_max = 0xFF800000;
         // How many are above the current range, and how many we need to find.
         uint total = 0;
-        uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
+        uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
 
         while (mask != 0) {
             barrier();
@@ -139,7 +140,7 @@ void topk(const uint row) {
             range_max = range_min + ((min_idx + 1) << shift);
             range_min = range_min + (min_idx << shift);
 
-            if (total == p.ncols_output) {
+            if (total == p.k) {
                 break;
             }
             total -= counts[min_idx];
@@ -179,13 +180,17 @@ void topk(const uint row) {
         barrier();
     }
 
-    if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
+    if (tid < p.k) {
         if (p.last_pass != 0) {
-            const uint row_offset = row * p.ncols_output;
-            data_d[row_offset + tid] = dst_row[tid].x;
+            if (gl_GlobalInvocationID.x < p.ncols_input) {
+                const uint row_offset = row * p.k;
+                data_d[row_offset + tid] = dst_row[tid].x;
+            }
         } else {
-            const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
-            data_t[row_offset + tid] = dst_row[tid];
+            if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) {
+                const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
+                data_t[row_offset + tid] = dst_row[tid];
+            }
         }
     }
 }