Răsfoiți Sursa

vulkan: Multi-pass softmax for large number of cols (#17892)

When the number of cols is large, split each row across multiple workgroups.
There are three phases that communicate partial results through temp buffers:
(1) compute max partials
(2) take max of partials, compute sum(exp(x-max)) partials
(3) sum partials, compute scaled result
Jeff Bolz 1 lună în urmă
părinte
comite
303f8615e9

+ 62 - 2
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -722,6 +722,11 @@ struct vk_device_struct {
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
     vk_pipeline pipeline_soft_max_back_f32;
+
+    vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
+    vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
+    vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;
+
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -3998,6 +4003,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
 
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32,     "soft_max_large1_f32",     soft_max_large1_f32_len,     soft_max_large1_f32_data,     "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32,     "soft_max_large2_f32",     soft_max_large2_f32_len,     soft_max_large2_f32_data,     "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32,     "soft_max_large3_f32",     soft_max_large3_f32_len,     soft_max_large3_f32_data,     "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -10117,7 +10129,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
+    vk_op_soft_max_push_constants pc {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -10128,7 +10140,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
         n_head_log2,
         nrows_x,
         src2 != nullptr
-    });
+    };
+
+    if (ncols <= 16384) {
+        ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
+    } else {
+
+        vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
+        vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
+        vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
+        vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
+
+        uint32_t elems_per_wg = 128 * 4;
+        uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
+        size_t tmp_size = num_wgs * nrows_x * sizeof(float);
+
+        if (ctx->prealloc_size_x < tmp_size) {
+            ctx->prealloc_size_x = tmp_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_size_y < tmp_size) {
+            ctx->prealloc_size_y = tmp_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+
+        vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
+        vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
+
+        std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };
+
+        vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
+        vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
+        vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
+
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
+
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+        ggml_vk_sync_buffers(ctx, subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+        ggml_vk_sync_buffers(ctx, subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+
+        ctx->prealloc_x_need_sync = true;
+        ctx->prealloc_y_need_sync = true;
+    }
 }
 
 static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

+ 62 - 0
ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp

@@ -0,0 +1,62 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint rowx = gl_WorkGroupID.y;
+    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
+
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
+    float slope = get_slope(rowx);
+
+    // Find max
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+
+    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+        const uint col = col0 + tid;
+
+        FLOAT_TYPE a = FLOAT_TYPE(0);
+        if (col < p.KX) {
+            a = data_a[rowx * p.KX + col];
+        }
+
+        FLOAT_TYPE b = FLOAT_TYPE(0);
+        if (p.KY > 0 && col < p.KX) {
+            b = data_b[rowy_start + col];
+        }
+
+        FLOAT_TYPE v = a * p.scale + slope * b;
+
+        if (col < p.KX) {
+            max_val = max(max_val, v);
+        }
+    }
+
+    // reduce across the workgroup
+    vals[tid] = max_val;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] = max(vals[tid], vals[tid + s]);
+        }
+        barrier();
+    }
+
+    if (tid == 0) {
+        max_val = vals[0];
+        data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;
+    }
+}

+ 79 - 0
ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp

@@ -0,0 +1,79 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint rowx = gl_WorkGroupID.y;
+    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
+
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
+    float slope = get_slope(rowx);
+
+    // Find max
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+
+    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
+        if (i + tid < gl_NumWorkGroups.x) {
+            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
+        }
+    }
+
+    // reduce across the workgroup
+    vals[tid] = max_val;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] = max(max_val, vals[tid + s]);
+        }
+        barrier();
+    }
+
+    max_val = vals[0];
+    barrier();
+
+    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
+
+    // Compute sum{exp(x - max)}
+    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            break;
+        }
+
+        // compute exp(a*scale+b*slope), add it to sum
+        const uint i = rowx * p.KX + col;
+        FLOAT_TYPE val;
+        val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
+        sum += val;
+        data_d[i] = D_TYPE(val);
+    }
+
+    // reduce across the workgroup
+    vals[tid] = sum;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] += vals[tid + s];
+        }
+        barrier();
+    }
+
+    if (tid == 0) {
+        sum = vals[0];
+        data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum;
+    }
+}

+ 65 - 0
ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp

@@ -0,0 +1,65 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+shared FLOAT_TYPE sumsh[BLOCK_SIZE];
+
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint rowx = gl_WorkGroupID.y;
+    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
+
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
+
+    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
+        if (i + tid < gl_NumWorkGroups.x) {
+            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
+            sum += data_s[rowx * gl_NumWorkGroups.x + i + tid];
+        }
+    }
+
+    // reduce across the workgroup
+    vals[tid] = max_val;
+    sumsh[tid] = sum;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] = max(max_val, vals[tid + s]);
+            sumsh[tid] += sumsh[tid + s];
+        }
+        barrier();
+    }
+
+    max_val = vals[0];
+    sum = sumsh[0];
+
+    if (p.has_sinks != 0) {
+        sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
+    }
+
+    FLOAT_TYPE rcpdivisor = 1.0/sum;
+
+    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            continue;
+        }
+
+        data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
+    }
+}

+ 53 - 0
ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl

@@ -0,0 +1,53 @@
+#extension GL_EXT_control_flow_attributes : enable
+
+layout (push_constant) uniform parameter
+{
+    uint KX;
+    uint KY;
+    uint ne00;
+    uint ne01;
+    uint ne02;
+    uint ne12;
+    uint ne13;
+    uint nb11;
+    uint nb12;
+    uint nb13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+    uint n_head_log2;
+    uint nrows_x;
+    uint has_sinks;
+} p;
+
+#include "types.glsl"
+
+layout(constant_id = 0) const uint BLOCK_SIZE = 128;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+layout(constant_id = 1) const uint num_iters = 4;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) readonly buffer Z {float data_c[];};
+layout (binding = 3) buffer D {D_TYPE data_d[];};
+layout (binding = 4) buffer M {float data_m[];};
+layout (binding = 5) buffer S {float data_s[];};
+
+shared FLOAT_TYPE vals[BLOCK_SIZE];
+
+float get_slope(uint rowx) {
+    float slope = 1.0f;
+
+    // ALiBi
+    if (p.max_bias > 0.0f) {
+        const uint h = (rowx / p.ne01) % p.ne02; // head index
+
+        const float base = h < p.n_head_log2 ? p.m0 : p.m1;
+        const uint   exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    return slope;
+}

+ 7 - 0
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -899,6 +899,13 @@ void process_shaders() {
     string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
     string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 
+    string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+
     string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
     string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
     string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});

+ 3 - 0
tests/test-backend-ops.cpp

@@ -7652,6 +7652,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
 
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true,  true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
+
     for (float max_bias : {0.0f, 8.0f}) {
         for (float scale : {1.0f, 0.1f}) {
             for (int64_t ne0 : {16, 1024}) {