Parcourir la source

vulkan: Implement GGML_OP_CUMSUM (#17479)

Jeff Bolz il y a 1 mois
Parent
commit
b3b03a7baf

+ 29 - 0
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -705,6 +705,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
+    vk_pipeline pipeline_cumsum_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -3968,6 +3969,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
+
     ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
 
 #define IM2COL(bda) \
@@ -8457,6 +8460,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_sum_rows_f32;
         }
         return nullptr;
+    case GGML_OP_CUMSUM:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_cumsum_f32;
+        }
+        return nullptr;
     case GGML_OP_ARGMAX:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
             return ctx->device->pipeline_argmax_f32;
@@ -8821,6 +8829,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_CUMSUM:
     case GGML_OP_MEAN:
     case GGML_OP_ARGMAX:
         {
@@ -10150,6 +10159,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
 }
 
+static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
+}
+
 static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
 }
@@ -11749,6 +11763,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SUM_ROWS:
         ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
 
+        break;
+    case GGML_OP_CUMSUM:
+        ggml_vk_cumsum(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_MEAN:
         ggml_vk_mean(ctx, compute_ctx, src0, node);
@@ -13786,6 +13804,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
+        case GGML_OP_CUMSUM:
+            {
+                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+                auto device = ggml_vk_get_device(ctx->device);
+                if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
+                    return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
+                }
+                return false;
+            }
         case GGML_OP_ARGMAX:
         case GGML_OP_COUNT_EQUAL:
         case GGML_OP_IM2COL:
@@ -14436,6 +14463,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_SUM_ROWS) {
             tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
+        } else if (tensor->op == GGML_OP_CUMSUM) {
+            tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_MEAN) {
             tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_ARGMAX) {

+ 69 - 0
ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp

@@ -0,0 +1,69 @@
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
+shared FLOAT_TYPE last_sum;
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+    const uint i03_offset = i03 * p.ne01*p.ne02;
+    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+    const uint i01 = row - i03_offset - i02*p.ne01;
+
+    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+    uint subgroup_id = tid / SUBGROUP_SIZE;
+
+    if (tid == 0) {
+        last_sum = 0;
+    }
+
+    uint col = tid;
+    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
+    for (int i = 0; i < num_iter; ++i) {
+        FLOAT_TYPE v = 0;
+        if (col < p.n_cols) {
+            v = FLOAT_TYPE(data_a[src_idx + col]);
+        }
+        v = subgroupInclusiveAdd(v);
+
+        // Store the largest partial sum for each subgroup, then add the partials for all
+        // lower subgroups and the final partial sum from the previous iteration.
+        if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
+            partial[subgroup_id] = v;
+        }
+        barrier();
+        for (int j = 0; j < subgroup_id; ++j) {
+            v += partial[j];
+        }
+        v += last_sum;
+        barrier();
+        if (tid == BLOCK_SIZE - 1) {
+            last_sum = v;
+        }
+        if (col < p.n_cols) {
+            data_d[dst_idx + col] = D_TYPE(v);
+        }
+        col += BLOCK_SIZE;
+    }
+}

+ 1 - 24
ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp

@@ -1,6 +1,7 @@
 #version 450
 
 #include "types.glsl"
+#include "sum_rows.glsl"
 
 #extension GL_EXT_control_flow_attributes : enable
 
@@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
 
-layout (push_constant) uniform parameter
-{
-    uint n_cols;
-    uint ne01, ne02;
-    uint nb01, nb02, nb03;
-    uint nb11, nb12, nb13;
-    float weight;
-    uint misalign_offsets;
-    uint ne0_12mp, ne0_12L;
-    uint ne0_1mp, ne0_1L;
-} p;
-
-uint get_aoffset() { return p.misalign_offsets >> 16; }
-uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
-
-// see init_fastdiv_values in ggml-vulkan.cpp
-uint fastdiv(uint n, uint mp, uint L) {
-    uint msbs, lsbs;
-    // msbs = mulhi(n, mp)
-    umulExtended(n, mp, msbs, lsbs);
-    return (msbs + n) >> L;
-}
-
-
 shared FLOAT_TYPE tmp[BLOCK_SIZE];
 
 void main() {

+ 25 - 0
ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl

@@ -0,0 +1,25 @@
+
+// vk_op_sum_rows_push_constants
+layout (push_constant) uniform parameter
+{
+    uint n_cols;
+    uint ne01, ne02;
+    uint nb01, nb02, nb03;
+    uint nb11, nb12, nb13;
+    float weight;
+    uint misalign_offsets;
+    uint ne0_12mp, ne0_12L;
+    uint ne0_1mp, ne0_1L;
+} p;
+
+uint get_aoffset() { return p.misalign_offsets >> 16; }
+uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
+
+// see init_fastdiv_values in ggml-vulkan.cpp
+uint fastdiv(uint n, uint mp, uint L) {
+    uint msbs, lsbs;
+    // msbs = mulhi(n, mp)
+    umulExtended(n, mp, msbs, lsbs);
+    return (msbs + n) >> L;
+}
+

+ 1 - 0
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -916,6 +916,7 @@ void process_shaders() {
     string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
+    string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     for (std::string dim_str : {"", "_3d"}) {
         for (bool bda : {false, true}) {