|
|
@@ -7633,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
|
const int ir = ir1 - ir0;
|
|
|
|
|
|
- for (int i3 = 0; i3 < n_s; ++i3) {
|
|
|
- for (int i2 = 0; i2 < n_t; ++i2) {
|
|
|
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
|
|
|
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
|
|
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
|
|
|
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
|
|
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
|
|
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
|
|
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
|
|
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
|
|
|
-
|
|
|
- // use the output as the source for the next token-wise iterations
|
|
|
- if (i2 > 0) { s0 = s; }
|
|
|
-
|
|
|
- // d_inner
|
|
|
- for (int i1 = 0; i1 < ir; ++i1) {
|
|
|
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
|
|
|
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
|
|
|
- float x_dt = x[i1] * dt_soft_plus;
|
|
|
- float sumf = 0.0f;
|
|
|
- // d_state
|
|
|
- for (int i0 = 0; i0 < nc; ++i0) {
|
|
|
- int i = i0 + i1*nc;
|
|
|
- // state = prev_state * dA + dB * x
|
|
|
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
|
|
- // y = rowwise_dotprod(state, C)
|
|
|
- sumf += state * C[i0];
|
|
|
- s[i] = state;
|
|
|
+ #ifdef __ARM_FEATURE_SVE
|
|
|
+ for (int i3 = 0; i3 < n_s; ++i3) {
|
|
|
+ for (int i2 = 0; i2 < n_t; ++i2) {
|
|
|
+ const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
|
|
|
+ const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
|
|
+ const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
|
|
|
+ const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
|
|
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
|
|
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
|
|
+ float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
|
|
+ float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
|
|
|
+
|
|
|
+ // use the output as the source for the next token-wise iterations
|
|
|
+ if (i2 > 0) { s0 = s; }
|
|
|
+
|
|
|
+ // d_inner
|
|
|
+ for (int i1 = 0; i1 < ir; ++i1) {
|
|
|
+ float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
|
|
|
+ float x_dt = x[i1] * dt_soft_plus;
|
|
|
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
|
|
|
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
|
|
|
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
|
|
|
+
|
|
|
+ for (int64_t k = 0; k < nc; k += svcntw()) {
|
|
|
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
|
|
|
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
|
|
|
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
|
|
|
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
|
|
|
+
|
|
|
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
|
|
|
+ t1 = exp_ps_sve(svptrue_b32(), t1);
|
|
|
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
|
|
|
+
|
|
|
+ vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
|
|
|
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
|
|
|
+
|
|
|
+ GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
|
|
|
+ }
|
|
|
+ y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
|
|
|
}
|
|
|
- y[i1] = sumf;
|
|
|
}
|
|
|
}
|
|
|
- }
|
|
|
+ #else
|
|
|
+ for (int i3 = 0; i3 < n_s; ++i3) {
|
|
|
+ for (int i2 = 0; i2 < n_t; ++i2) {
|
|
|
+ const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
|
|
|
+ const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
|
|
+ const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
|
|
|
+ const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
|
|
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
|
|
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
|
|
+ float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
|
|
+ float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
|
|
|
+
|
|
|
+ // use the output as the source for the next token-wise iterations
|
|
|
+ if (i2 > 0) { s0 = s; }
|
|
|
+
|
|
|
+ // d_inner
|
|
|
+ for (int i1 = 0; i1 < ir; ++i1) {
|
|
|
+ // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
|
|
|
+ float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
|
|
|
+ float x_dt = x[i1] * dt_soft_plus;
|
|
|
+ float sumf = 0.0f;
|
|
|
+ // d_state
|
|
|
+ for (int i0 = 0; i0 < nc; ++i0) {
|
|
|
+ int i = i0 + i1*nc;
|
|
|
+ // state = prev_state * dA + dB * x
|
|
|
+ float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
|
|
+ // y = rowwise_dotprod(state, C)
|
|
|
+ sumf += state * C[i0];
|
|
|
+ s[i] = state;
|
|
|
+ }
|
|
|
+ y[i1] = sumf;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ #endif
|
|
|
}
|
|
|
|
|
|
void ggml_compute_forward_ssm_scan(
|