1
0
Эх сурвалжийг харах

metal: SSM kernel improvements (#17876)

* feat: Add a batched version of ssm_conv

This was done using Claude Code. It found a number of optimizations around
how the threads were organized, resulting in a huge performance boost!

Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Optimized SSM_SCAN kernel for metal

This used Claude Code and resulted in a modest performance improvement
while maintaining correctness.

Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* test: Add test-backend-ops perf tests for SSM_CONV

Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* test: Real representitive tests for SSM_CONV

Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Use function constant for ssm_conv batch size

Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* test: backend op tests for ssm_scan from granite4 1b-h

Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* style: remove commented out templates

Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: float4 version of ssm_conv_batched

Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Add missing ggml_metal_cv_free

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Gabe Goodhart 1 сар өмнө
parent
commit
086a63e3a5

+ 38 - 1
ggml/src/ggml-metal/ggml-metal-device.cpp

@@ -411,6 +411,38 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_me
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+
+    char base[256];
+    char name[256];
+
+    const char * suffix = "";
+    if (op->src[1]->ne[0] % 4 == 0) {
+        suffix = "_4";
+    }
+
+    snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
+    snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op)  {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
 
@@ -427,7 +459,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me
         res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res.smem = 32*sizeof(float)*nsg;
+    // Shared memory layout:
+    // - sgptg * NW floats for partial sums (nsg * 32)
+    // - sgptg floats for shared_x_dt (nsg)
+    // - sgptg floats for shared_dA (nsg)
+    // Total: nsg * (32 + 2) floats
+    res.smem = (32 + 2)*sizeof(float)*nsg;
 
     return res;
 }

+ 1 - 0
ggml/src/ggml-metal/ggml-metal-device.h

@@ -117,6 +117,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri               (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv          (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched  (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);

+ 1 - 0
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -77,6 +77,7 @@
 #define FC_MUL_MV                      600
 #define FC_MUL_MM                      700
 #define FC_ROPE                        800
+#define FC_SSM_CONV                    900
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPTG 8

+ 35 - 7
ggml/src/ggml-metal/ggml-metal-ops.cpp

@@ -1365,15 +1365,43 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
         /*.nb2  =*/ nb2,
     };
 
-    auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
+    // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
+    const bool use_batched = (ne1 > 1);
+
+    if (use_batched) {
+        // Determine the smallest power of 2 that's >= ne1, but <= 256
+        int BATCH_SIZE;
+        if      (ne1 > 128) BATCH_SIZE = 256;
+        else if (ne1 > 64 ) BATCH_SIZE = 128;
+        else if (ne1 > 32 ) BATCH_SIZE = 64;
+        else if (ne1 > 16 ) BATCH_SIZE = 32;
+        else if (ne1 > 8  ) BATCH_SIZE = 16;
+        else if (ne1 > 4  ) BATCH_SIZE = 8;
+        else                BATCH_SIZE = 2;
+
+        auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
 
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         3);
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         3);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
+        // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
+        // Each threadgroup has BATCH_SIZE threads, each handling one token
+        const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
+        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
+    } else {
+        auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         3);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
+    }
 
     return 1;
 }

+ 127 - 9
ggml/src/ggml-metal/ggml-metal.metal

@@ -2343,7 +2343,102 @@ kernel void kernel_ssm_conv_f32_f32_4(
     x[0] = sumf;
 }
 
+constant short FC_ssm_conv_bs   [[function_constant(FC_SSM_CONV + 0)]];
+
+// Batched version: each threadgroup processes multiple tokens for better efficiency
+// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
+kernel void kernel_ssm_conv_f32_f32_batched(
+        constant ggml_metal_kargs_ssm_conv & args,
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    // tgpig.x = row index (ir)
+    // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
+    // tgpig.z = sequence index (i3)
+    // tpitg.x = thread within batch (0..BATCH_SIZE-1)
+    const short BATCH_SIZE = FC_ssm_conv_bs;
+
+    const int64_t ir      = tgpig.x;
+    const int64_t i2_base = tgpig.y * BATCH_SIZE;
+    const int64_t i3      = tgpig.z;
+    const int64_t i2_off  = tpitg.x;
+    const int64_t i2      = i2_base + i2_off;
+
+    const int64_t nc  = args.ne10;  // conv kernel size (typically 4)
+    const int64_t n_t = args.ne1;   // number of tokens
+
+    // Bounds check for partial batches at the end
+    if (i2 >= n_t) {
+        return;
+    }
+
+    // Load conv weights (shared across all tokens for this row)
+    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
+
+    // Load source for this specific token
+    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+
+    // Output location for this token
+    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
+
+    float sumf = 0.0f;
+    for (int64_t i0 = 0; i0 < nc; ++i0) {
+        sumf += s[i0] * c[i0];
+    }
+
+    x[0] = sumf;
+}
+
+kernel void kernel_ssm_conv_f32_f32_batched_4(
+        constant ggml_metal_kargs_ssm_conv & args,
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    // tgpig.x = row index (ir)
+    // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
+    // tgpig.z = sequence index (i3)
+    // tpitg.x = thread within batch (0..BATCH_SIZE-1)
+    const short BATCH_SIZE = FC_ssm_conv_bs;
+
+    const int64_t ir      = tgpig.x;
+    const int64_t i2_base = tgpig.y * BATCH_SIZE;
+    const int64_t i3      = tgpig.z;
+    const int64_t i2_off  = tpitg.x;
+    const int64_t i2      = i2_base + i2_off;
+
+    const int64_t nc  = args.ne10;  // conv kernel size (typically 4)
+    const int64_t n_t = args.ne1;   // number of tokens
+
+    // Bounds check for partial batches at the end
+    if (i2 >= n_t) {
+        return;
+    }
+
+    // Load conv weights (shared across all tokens for this row)
+    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
+
+    // Load source for this specific token
+    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+
+    // Output location for this token
+    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
+
+    float sumf = 0.0f;
+    for (int64_t i0 = 0; i0 < nc/4; ++i0) {
+        sumf += dot(s[i0], c[i0]);
+    }
+
+    x[0] = sumf;
+}
+
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// Optimized version: reduces redundant memory loads by having one thread load shared values
 kernel void kernel_ssm_scan_f32(
         constant ggml_metal_kargs_ssm_scan & args,
         device const void * src0,
@@ -2363,7 +2458,15 @@ kernel void kernel_ssm_scan_f32(
         uint3    tgpg[[threadgroups_per_grid]]) {
     constexpr short NW = N_SIMDWIDTH;
 
-    shared[tpitg.x] = 0.0f;
+    // Shared memory layout:
+    // [0..sgptg*NW-1]: partial sums for reduction (existing)
+    // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
+    // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
+    threadgroup float * shared_sums = shared;
+    threadgroup float * shared_x_dt = shared + sgptg * NW;
+    threadgroup float * shared_dA   = shared + sgptg * NW + sgptg;
+
+    shared_sums[tpitg.x] = 0.0f;
 
     const int32_t i0 = tpitg.x;
     const int32_t i1 = tgpig.x;
@@ -2403,32 +2506,47 @@ kernel void kernel_ssm_scan_f32(
     for (int i2 = 0; i2 < n_t; i2 += sgptg) {
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
-            const float dt0  = dt[0];
+        // Pre-compute x_dt and dA for this batch of tokens
+        // Only first sgptg threads do the loads and expensive math
+        if (i0 < sgptg && i2 + i0 < n_t) {
+            // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
+            device const float * x_t  = x  + i0 * args.ns12;
+            device const float * dt_t = dt + i0 * args.ns21;
+
+            const float dt0  = dt_t[0];
             const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
-            const float x_dt = x[0] * dtsp;
-            const float dA   = exp(dtsp * A0);
+            shared_x_dt[i0] = x_t[0] * dtsp;
+            shared_dA[i0]   = dtsp;  // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
+            const float x_dt = shared_x_dt[t];
+            const float dA   = exp(shared_dA[t] * A0);
 
             s = (s0 * dA) + (B[i0] * x_dt);
 
             const float sumf = simd_sum(s * C[i0]);
 
             if (tiisg == 0) {
-                shared[t*NW + sgitg] = sumf;
+                shared_sums[t*NW + sgitg] = sumf;
             }
 
             // recurse
             s0 = s;
 
-            x  += args.ns12;
-            dt += args.ns21;
             B  += args.ns42;
             C  += args.ns52;
         }
 
+        // Advance pointers for next batch
+        x  += sgptg * args.ns12;
+        dt += sgptg * args.ns21;
+
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
+        const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
 
         if (tiisg == 0 && i2 + sgitg < n_t) {
             y[sgitg*nh*nr] = sumf;

+ 7 - 0
tests/test-backend-ops.cpp

@@ -8193,6 +8193,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }
 
+    // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4,   3328, 1, 1}, {4, 3328, 1, 1})); // generate
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1,   1)); // generate
+
+
     return test_cases;
 }