@@ -2343,7 +2343,102 @@ kernel void kernel_ssm_conv_f32_f32_4(
     x[0] = sumf;
 }
 
+constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
+
+// Batched version: each threadgroup processes multiple tokens for better efficiency
+// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
+kernel void kernel_ssm_conv_f32_f32_batched(
+        constant ggml_metal_kargs_ssm_conv & args,
+        device const void * src0,
+        device const void * src1,
+        device       float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    // tgpig.x = row index (ir)
+    // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
+    // tgpig.z = sequence index (i3)
+    // tpitg.x = thread within batch (0..BATCH_SIZE-1)
+    const short BATCH_SIZE = FC_ssm_conv_bs;
+
+    const int64_t ir      = tgpig.x;
+    const int64_t i2_base = tgpig.y * BATCH_SIZE;
+    const int64_t i3      = tgpig.z;
+    const int64_t i2_off  = tpitg.x;
+    const int64_t i2      = i2_base + i2_off;
+
+    const int64_t nc  = args.ne10; // conv kernel size (typically 4)
+    const int64_t n_t = args.ne1;  // number of tokens
+
+    // Bounds check for partial batches at the end
+    if (i2 >= n_t) {
+        return;
+    }
+
+    // Load conv weights (shared across all tokens for this row)
+    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
+
+    // Load source for this specific token
+    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+
+    // Output location for this token
+    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
+
+    float sumf = 0.0f;
+    for (int64_t i0 = 0; i0 < nc; ++i0) {
+        sumf += s[i0] * c[i0];
+    }
+
+    x[0] = sumf;
+}
+
+kernel void kernel_ssm_conv_f32_f32_batched_4(
+        constant ggml_metal_kargs_ssm_conv & args,
+        device const void * src0,
+        device const void * src1,
+        device       float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    // tgpig.x = row index (ir)
+    // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
+    // tgpig.z = sequence index (i3)
+    // tpitg.x = thread within batch (0..BATCH_SIZE-1)
+    const short BATCH_SIZE = FC_ssm_conv_bs;
+
+    const int64_t ir      = tgpig.x;
+    const int64_t i2_base = tgpig.y * BATCH_SIZE;
+    const int64_t i3      = tgpig.z;
+    const int64_t i2_off  = tpitg.x;
+    const int64_t i2      = i2_base + i2_off;
+
+    const int64_t nc  = args.ne10; // conv kernel size (typically 4)
+    const int64_t n_t = args.ne1;  // number of tokens
+
+    // Bounds check for partial batches at the end
+    if (i2 >= n_t) {
+        return;
+    }
+
+    // Load conv weights (shared across all tokens for this row)
+    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
+
+    // Load source for this specific token
+    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+
+    // Output location for this token
+    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
+
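+    // the conv window is processed as nc/4 float4 chunks, so nc must be a multiple of 4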
+    float sumf = 0.0f;
+    for (int64_t i0 = 0; i0 < nc/4; ++i0) {
+        sumf += dot(s[i0], c[i0]);
+    }
+
+    x[0] = sumf;
+}
+
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// Optimized version: reduces redundant memory loads by having one thread per token load the shared values
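+// Each outer iteration stages x*dt and softplus(dt) for up to sgptg tokens in threadgroup memory,
+// so each per-token value is read from device memory once instead of once per thread.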
 kernel void kernel_ssm_scan_f32(
         constant ggml_metal_kargs_ssm_scan & args,
         device const void * src0,
@@ -2363,7 +2458,15 @@ kernel void kernel_ssm_scan_f32(
         uint3 tgpg[[threadgroups_per_grid]]) {
     constexpr short NW = N_SIMDWIDTH;
 
-    shared[tpitg.x] = 0.0f;
+    // Shared memory layout:
+    // [0..sgptg*NW-1]: partial sums for reduction (existing)
+    // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
+    // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dtsp values for each token in batch (dA = exp(dtsp*A0) is computed per-thread)
+    threadgroup float * shared_sums = shared;
+    threadgroup float * shared_x_dt = shared + sgptg * NW;
+    threadgroup float * shared_dA   = shared + sgptg * NW + sgptg;
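+    // this layout requires the threadgroup buffer to hold at least sgptg*(NW + 2) floats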
+
+    shared_sums[tpitg.x] = 0.0f;
 
     const int32_t i0 = tpitg.x;
     const int32_t i1 = tgpig.x;
@@ -2403,32 +2506,47 @@ kernel void kernel_ssm_scan_f32(
     for (int i2 = 0; i2 < n_t; i2 += sgptg) {
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
-            const float dt0 = dt[0];
+        // Pre-compute x_dt and dtsp for this batch of tokens
+        // Only the first sgptg threads do the loads and the expensive math
+        if (i0 < sgptg && i2 + i0 < n_t) {
+            // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
+            device const float * x_t  = x  + i0 * args.ns12;
+            device const float * dt_t = dt + i0 * args.ns21;
+
+            const float dt0 = dt_t[0];
             const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
-            const float x_dt = x[0] * dtsp;
-            const float dA = exp(dtsp * A0);
+            shared_x_dt[i0] = x_t[0] * dtsp;
+            shared_dA[i0]   = dtsp; // Store dtsp; exp(dtsp * A0) is computed per-thread since A0 varies
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
+            const float x_dt = shared_x_dt[t];
+            const float dA   = exp(shared_dA[t] * A0);
 
             s = (s0 * dA) + (B[i0] * x_dt);
 
             const float sumf = simd_sum(s * C[i0]);
 
             if (tiisg == 0) {
-                shared[t*NW + sgitg] = sumf;
+                shared_sums[t*NW + sgitg] = sumf;
             }
 
             // recurse
             s0 = s;
 
-            x += args.ns12;
-            dt += args.ns21;
             B += args.ns42;
             C += args.ns52;
         }
 
+        // Advance pointers for the next batch of tokens
+        x  += sgptg * args.ns12;
+        dt += sgptg * args.ns21;
+
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
+        const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
 
         if (tiisg == 0 && i2 + sgitg < n_t) {
             y[sgitg*nh*nr] = sumf;