|
|
@@ -1823,10 +1823,16 @@ kernel void kernel_ssm_scan_f32(
|
|
|
device const void * src5,
|
|
|
device const void * src6,
|
|
|
device float * dst,
|
|
|
+ threadgroup float * shared [[threadgroup(0)]],
|
|
|
constant ggml_metal_kargs_ssm_scan & args,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
- uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
|
+ ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
+ ushort sgptg[[simdgroups_per_threadgroup]],
|
|
|
+ uint3 tgpg[[threadgroups_per_grid]]) {
|
|
|
+
|
|
|
+ const int64_t i0 = tpitg.x;
|
|
|
const int64_t i1 = 0;
|
|
|
const int64_t ir = tgpig.x; // current head
|
|
|
const int64_t i3 = tgpig.y; // current seq
|
|
|
@@ -1841,41 +1847,88 @@ kernel void kernel_ssm_scan_f32(
|
|
|
const int64_t ng = args.n_group;
|
|
|
const int64_t n_t = args.n_seq_tokens;
|
|
|
|
|
|
- const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
|
|
|
+ const int64_t s_off = args.s_off;
|
|
|
|
|
|
device const int32_t * ids = (device const int32_t *) src6;
|
|
|
|
|
|
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
|
- device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
|
+ device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
|
+ device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
|
+ const int64_t i = i0 + i1*nc;
|
|
|
+ float s0 = s0_buff[i];
|
|
|
+ float s = s_buff[i];
|
|
|
+
|
|
|
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
|
|
|
+ device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
|
|
|
+ device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
|
|
|
+ device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
|
|
|
+ device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
|
|
|
+ device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
|
|
|
|
|
|
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
|
|
- device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
|
|
|
- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
|
|
|
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
|
|
|
- device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
|
|
|
- device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
|
|
|
- device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
|
|
+ device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
|
|
|
+ device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
|
|
|
+ device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
|
|
|
+ device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
|
|
|
+ device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
|
|
|
|
|
|
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
|
const float x_dt = x[0] * dt_soft_plus;
|
|
|
- float sumf = 0.0f;
|
|
|
|
|
|
- for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
|
- const int64_t i = i0 + i1*nc;
|
|
|
- const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
|
|
|
- sumf += state * C[i0];
|
|
|
- s[i] = state;
|
|
|
- }
|
|
|
+ const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
|
|
|
+ s = state;
|
|
|
+
|
|
|
+ // Parallel sum: This relies on the fact that this kernel will be
|
|
|
+ // dispatched with each threadgroup having (d_state, 1, 1) threads which
|
|
|
+ // are subdivided into `sgptg` SIMD groups. The goal is to
|
|
|
+ // compute y = sum({state * C[i] for i in range(d_state)}).
|
|
|
+ // To parallelize this effectively, we first use simd_sum over each SIMD
|
|
|
+ // group to compute the sum of each SIMD group, then place the result in
|
|
|
+ // the SIMD group's indexed bucket in the shared memory. We then sum
|
|
|
+ // over the individual group sums to compute the final sum.
|
|
|
+
|
|
|
+ // Computed for each thread
|
|
|
+ float sumf = state * C[i0];
|
|
|
|
|
|
- y[0] = sumf;
|
|
|
+ // Sum the threads in the simd group => simd sum
|
|
|
+ sumf = simd_sum(sumf);
|
|
|
+
|
|
|
+ if (sgptg > 1) {
|
|
|
+
|
|
|
+ // Once per simd group, place the group sum into the shared buffer
|
|
|
+ if (tiisg == 0) {
|
|
|
+ shared[sgitg] = sumf;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Wait for all threads in the threadgroup to reach this point. This
|
|
|
+ // ensures that all elements of the shared buffer are populated with the
|
|
|
+ // sum of the individual simd groups.
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+
|
|
|
+ // For simd group 0 at indices < num simd groups, extract the shared
|
|
|
+ // simd sum
|
|
|
+ sumf = 0.0f;
|
|
|
+ if (sgitg == 0) {
|
|
|
+ if (tiisg < sgptg) {
|
|
|
+ sumf = shared[tiisg];
|
|
|
+ }
|
|
|
+ sumf = simd_sum(sumf);
|
|
|
+ if (tiisg == 0) {
|
|
|
+ y[0] = sumf;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else if (tiisg == 0) {
|
|
|
+ y[0] = sumf;
|
|
|
+ }
|
|
|
|
|
|
// recurse
|
|
|
s0 = s;
|
|
|
}
|
|
|
+
|
|
|
+ // Assign the final state to the output buffer
|
|
|
+ s_buff[i] = s;
|
|
|
}
|
|
|
|
|
|
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
|
|
-// TODO: optimize (e.g. by parallelizing over d_state)
|
|
|
kernel void kernel_ssm_scan_f32_group(
|
|
|
device const void * src0,
|
|
|
device const void * src1,
|
|
|
@@ -1885,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group(
|
|
|
device const void * src5,
|
|
|
device const void * src6,
|
|
|
device float * dst,
|
|
|
+ threadgroup float * shared [[threadgroup(0)]],
|
|
|
constant ggml_metal_kargs_ssm_scan & args,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
- uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
|
+ ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
+ ushort sgptg[[simdgroups_per_threadgroup]],
|
|
|
+ uint3 tgpg[[threadgroups_per_grid]]) {
|
|
|
+
|
|
|
+ const int64_t i0 = tpitg.x;
|
|
|
const int64_t i1 = tgpig.x;
|
|
|
const int64_t ir = tgpig.y; // current head
|
|
|
const int64_t i3 = tgpig.z; // current seq
|
|
|
@@ -1903,38 +1962,81 @@ kernel void kernel_ssm_scan_f32_group(
|
|
|
const int64_t ng = args.n_group;
|
|
|
const int64_t n_t = args.n_seq_tokens;
|
|
|
|
|
|
- const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
|
|
|
+ const int64_t s_off = args.s_off;
|
|
|
|
|
|
device const int32_t * ids = (device const int32_t *) src6;
|
|
|
|
|
|
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
|
- device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
|
+ device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
|
+ device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
|
+ const int64_t i = i0 + i1*nc;
|
|
|
+ float s0 = s0_buff[i];
|
|
|
+ float s = s_buff[i];
|
|
|
+
|
|
|
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
|
|
+ device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
|
|
|
+ device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
|
|
|
+ device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
|
|
|
+ device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
|
|
|
+ device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
|
|
|
|
|
|
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
|
|
- device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
|
|
|
- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
|
|
|
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
|
|
- device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
|
|
|
- device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
|
|
|
- device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
|
|
+ device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
|
|
|
+ device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
|
|
|
+ device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
|
|
|
+ device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
|
|
|
+ device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
|
|
|
|
|
|
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
|
const float x_dt = x[0] * dt_soft_plus;
|
|
|
const float dA = exp(dt_soft_plus * A[0]);
|
|
|
- float sumf = 0.0f;
|
|
|
|
|
|
- for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
|
- const int64_t i = i0 + i1*nc;
|
|
|
- const float state = (s0[i] * dA) + (B[i0] * x_dt);
|
|
|
- sumf += state * C[i0];
|
|
|
- s[i] = state;
|
|
|
+ const float state = (s0 * dA) + (B[i0] * x_dt);
|
|
|
+ s = state;
|
|
|
+
|
|
|
+ // Parallel sum: This relies on the fact that this kernel will be
|
|
|
+ // dispatched with each threadgroup having (d_state, 1, 1) threads which
|
|
|
+ // are subdivided into `sgptg` SIMD groups. The goal is to
|
|
|
+ // compute y = sum({state * C[i] for i in range(d_state)}).
|
|
|
+ // To parallelize this effectively, we first use simd_sum over each SIMD
|
|
|
+ // group to compute the sum of each SIMD group, then place the result in
|
|
|
+ // the SIMD group's indexed bucket in the shared memory. We then sum
|
|
|
+ // over the individual group sums to compute the final sum.
|
|
|
+
|
|
|
+ // Computed for each thread
|
|
|
+ float sumf = state * C[i0];
|
|
|
+
|
|
|
+ // Sum the threads in the simd group => simd sum
|
|
|
+ sumf = simd_sum(sumf);
|
|
|
+
|
|
|
+ // Once per simd group, place the group sum into the shared buffer
|
|
|
+ if (tiisg == 0) {
|
|
|
+ shared[sgitg] = sumf;
|
|
|
}
|
|
|
|
|
|
- y[0] = sumf;
|
|
|
+ // Wait for all threads in the threadgroup to reach this point. This
|
|
|
+ // ensures that all elements of the shared buffer are populated with the
|
|
|
+ // sum of the individual simd groups.
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+
|
|
|
+ // For simd group 0 at indices < num simd groups, extract the shared
|
|
|
+ // simd sum
|
|
|
+ sumf = 0.0f;
|
|
|
+ if (sgitg == 0) {
|
|
|
+ if (tiisg < sgptg) {
|
|
|
+ sumf = shared[tiisg];
|
|
|
+ }
|
|
|
+ sumf = simd_sum(sumf);
|
|
|
+ if (tiisg == 0) {
|
|
|
+ y[0] = sumf;
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
// recurse
|
|
|
s0 = s;
|
|
|
}
|
|
|
+
|
|
|
+ // Assign the final state to the output buffer
|
|
|
+ s_buff[i] = s;
|
|
|
}
|
|
|
|
|
|
kernel void kernel_rwkv_wkv6_f32(
|