|
@@ -1841,6 +1841,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
|
"FLASH_ATTN",
|
|
"FLASH_ATTN",
|
|
|
"FLASH_FF",
|
|
"FLASH_FF",
|
|
|
"FLASH_ATTN_BACK",
|
|
"FLASH_ATTN_BACK",
|
|
|
|
|
+ "SSM_CONV",
|
|
|
|
|
+ "SSM_SCAN",
|
|
|
"WIN_PART",
|
|
"WIN_PART",
|
|
|
"WIN_UNPART",
|
|
"WIN_UNPART",
|
|
|
"GET_REL_POS",
|
|
"GET_REL_POS",
|
|
@@ -1863,7 +1865,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
|
"CROSS_ENTROPY_LOSS_BACK",
|
|
"CROSS_ENTROPY_LOSS_BACK",
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
|
|
|
|
|
|
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
|
|
|
|
|
|
|
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
|
"none",
|
|
"none",
|
|
@@ -1929,6 +1931,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
|
"flash_attn(x)",
|
|
"flash_attn(x)",
|
|
|
"flash_ff(x)",
|
|
"flash_ff(x)",
|
|
|
"flash_attn_back(x)",
|
|
"flash_attn_back(x)",
|
|
|
|
|
+ "ssm_conv(x)",
|
|
|
|
|
+ "ssm_scan(x)",
|
|
|
"win_part(x)",
|
|
"win_part(x)",
|
|
|
"win_unpart(x)",
|
|
"win_unpart(x)",
|
|
|
"get_rel_pos(x)",
|
|
"get_rel_pos(x)",
|
|
@@ -1951,7 +1955,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
|
"cross_entropy_loss_back(x,y)",
|
|
"cross_entropy_loss_back(x,y)",
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
|
|
|
|
|
|
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
|
|
|
|
|
|
|
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
|
|
|
|
|
|
@@ -6154,6 +6158,108 @@ struct ggml_tensor * ggml_flash_attn_back(
|
|
|
return result;
|
|
return result;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+// ggml_ssm_conv
|
|
|
|
|
+
|
|
|
|
|
+struct ggml_tensor * ggml_ssm_conv(
|
|
|
|
|
+ struct ggml_context * ctx,
|
|
|
|
|
+ struct ggml_tensor * s,
|
|
|
|
|
+ struct ggml_tensor * x,
|
|
|
|
|
+ struct ggml_tensor * c,
|
|
|
|
|
+ struct ggml_tensor * sq) {
|
|
|
|
|
+ GGML_ASSERT(ggml_is_3d(s));
|
|
|
|
|
+ GGML_ASSERT(ggml_is_matrix(x));
|
|
|
|
|
+ GGML_ASSERT(ggml_is_matrix(c));
|
|
|
|
|
+ GGML_ASSERT(ggml_is_matrix(sq));
|
|
|
|
|
+ GGML_ASSERT(sq->type == GGML_TYPE_I32);
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t d_conv = c->ne[0];
|
|
|
|
|
+ const int64_t d_inner = c->ne[1];
|
|
|
|
|
+ const int64_t n_tokens = x->ne[1];
|
|
|
|
|
+ const int64_t n_kv = s->ne[2];
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT( s->ne[0] == d_conv - 1);
|
|
|
|
|
+ GGML_ASSERT( s->ne[1] == d_inner);
|
|
|
|
|
+ GGML_ASSERT( x->ne[0] == d_inner);
|
|
|
|
|
+ GGML_ASSERT(sq->ne[0] == n_kv);
|
|
|
|
|
+ GGML_ASSERT(sq->ne[1] == n_tokens);
|
|
|
|
|
+
|
|
|
|
|
+ bool is_node = false;
|
|
|
|
|
+
|
|
|
|
|
+ if (s->grad || x->grad || c->grad || sq->grad) {
|
|
|
|
|
+ GGML_ASSERT(false); // TODO: implement
|
|
|
|
|
+ is_node = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
|
|
|
|
|
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
|
|
|
|
|
+
|
|
|
|
|
+ result->op = GGML_OP_SSM_CONV;
|
|
|
|
|
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
|
|
|
|
+ result->src[0] = s;
|
|
|
|
|
+ result->src[1] = x;
|
|
|
|
|
+ result->src[2] = c;
|
|
|
|
|
+ result->src[3] = sq;
|
|
|
|
|
+
|
|
|
|
|
+ return result;
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// ggml_ssm_scan
|
|
|
|
|
+
|
|
|
|
|
+struct ggml_tensor * ggml_ssm_scan(
|
|
|
|
|
+ struct ggml_context * ctx,
|
|
|
|
|
+ struct ggml_tensor * s,
|
|
|
|
|
+ struct ggml_tensor * x,
|
|
|
|
|
+ struct ggml_tensor * dt,
|
|
|
|
|
+ struct ggml_tensor * A,
|
|
|
|
|
+ struct ggml_tensor * B,
|
|
|
|
|
+ struct ggml_tensor * C,
|
|
|
|
|
+ struct ggml_tensor * sq) {
|
|
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(s));
|
|
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(x));
|
|
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(dt));
|
|
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(A));
|
|
|
|
|
+ GGML_ASSERT(sq->type == GGML_TYPE_I32);
|
|
|
|
|
+ GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
|
|
|
|
|
+ GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
|
|
|
|
|
+ GGML_ASSERT(ggml_are_same_shape(x, dt));
|
|
|
|
|
+
|
|
|
|
|
+ {
|
|
|
|
|
+ const int64_t d_state = s->ne[0];
|
|
|
|
|
+ const int64_t d_inner = s->ne[1];
|
|
|
|
|
+ const int64_t n_tokens = x->ne[1];
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(x->ne[0] == d_inner);
|
|
|
|
|
+ GGML_ASSERT(A->ne[0] == d_state);
|
|
|
|
|
+ GGML_ASSERT(A->ne[1] == d_inner);
|
|
|
|
|
+ GGML_ASSERT(B->ne[0] == d_state);
|
|
|
|
|
+ GGML_ASSERT(B->ne[1] == n_tokens);
|
|
|
|
|
+ GGML_ASSERT(C->ne[0] == d_state);
|
|
|
|
|
+ GGML_ASSERT(C->ne[1] == n_tokens);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ bool is_node = false;
|
|
|
|
|
+
|
|
|
|
|
+ if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
|
|
|
|
|
+ GGML_ASSERT(false); // TODO: implement
|
|
|
|
|
+ is_node = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
|
|
|
|
|
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
|
|
|
|
|
+
|
|
|
|
|
+ result->op = GGML_OP_SSM_SCAN;
|
|
|
|
|
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
|
|
|
|
+ result->src[0] = s;
|
|
|
|
|
+ result->src[1] = x;
|
|
|
|
|
+ result->src[2] = dt;
|
|
|
|
|
+ result->src[3] = A;
|
|
|
|
|
+ result->src[4] = B;
|
|
|
|
|
+ result->src[5] = C;
|
|
|
|
|
+ result->src[6] = sq;
|
|
|
|
|
+
|
|
|
|
|
+ return result;
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
// ggml_win_part
|
|
// ggml_win_part
|
|
|
|
|
|
|
|
struct ggml_tensor * ggml_win_part(
|
|
struct ggml_tensor * ggml_win_part(
|
|
@@ -14771,6 +14877,257 @@ static void ggml_compute_forward_flash_attn_back(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+// ggml_compute_forward_ssm_conv
|
|
|
|
|
+
|
|
|
|
|
+static void ggml_compute_forward_ssm_conv_f32(
|
|
|
|
|
+ const struct ggml_compute_params * params,
|
|
|
|
|
+ struct ggml_tensor * dst) {
|
|
|
|
|
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const struct ggml_tensor * src0 = dst->src[0]; // conv_state
|
|
|
|
|
+ const struct ggml_tensor * src1 = dst->src[1]; // x
|
|
|
|
|
+ const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
|
|
|
|
|
+ const struct ggml_tensor * src3 = dst->src[3]; // state_seq
|
|
|
|
|
+
|
|
|
|
|
+ const int ith = params->ith;
|
|
|
|
|
+ const int nth = params->nth;
|
|
|
|
|
+
|
|
|
|
|
+ const int nc = src2->ne[0]; // d_conv
|
|
|
|
|
+ const int nr = src0->ne[1]; // d_inner
|
|
|
|
|
+ const int n_t = src1->ne[1]; // n_tokens
|
|
|
|
|
+ const int n_kv = src0->ne[2]; // max number of sequences in the batch
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
|
|
|
|
|
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
|
|
|
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
|
|
|
|
|
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
|
|
|
|
|
+ GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
|
|
|
|
|
+ GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
|
|
|
|
+ // for use with the destination state offset between sequences
|
|
|
|
|
+ GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
|
|
|
|
|
+
|
|
|
|
|
+ // rows per thread
|
|
|
|
|
+ const int dr = (nr + nth - 1)/nth;
|
|
|
|
|
+
|
|
|
|
|
+ // row range for this thread
|
|
|
|
|
+ const int ir0 = dr*ith;
|
|
|
|
|
+ const int ir1 = MIN(ir0 + dr, nr);
|
|
|
|
|
+ const int ir = ir1 - ir0;
|
|
|
|
|
+
|
|
|
|
|
+ if (n_kv > 1) {
|
|
|
|
|
+ // multiple sequences means it's hard to know when it's the first time a state is read,
|
|
|
|
|
+ // so copy them all over to the destination, just to be sure.
|
|
|
|
|
+ for (int i3 = 0; i3 < n_kv; ++i3) {
|
|
|
|
|
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
|
|
|
|
|
+ float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
|
|
|
|
|
+ // can't use memcpy because of d_conv vs d_conv - 1
|
|
|
|
|
+ for (int i1 = 0; i1 < ir; ++i1) {
|
|
|
|
|
+ for (int i0 = 0; i0 < nc - 1; ++i0) {
|
|
|
|
|
+ // copy s0 to last (d_conv - 1) columns of s
|
|
|
|
|
+ s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for (int i2 = 0; i2 < n_t; ++i2) {
|
|
|
|
|
+ int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
|
|
|
|
|
+ float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
|
|
|
|
|
+ float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
|
|
|
|
|
+ float * s0; // {d_conv - 1, d_inner, n_kv}
|
|
|
|
|
+ float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
|
|
|
|
+ float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
|
|
|
|
|
+ int ne0s0;
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
|
|
|
|
|
+
|
|
|
|
|
+ // avoid needing to copy the state for the first token
|
|
|
|
|
+ if (i2 == 0) {
|
|
|
|
|
+ s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
|
|
|
|
|
+ ne0s0 = src0->ne[0];
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // the source is the last (d_conv - 1) columns of the destination
|
|
|
|
|
+ s0 = s + 1;
|
|
|
|
|
+ ne0s0 = nc;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // d_inner
|
|
|
|
|
+ for (int i1 = 0; i1 < ir; ++i1) {
|
|
|
|
|
+ // shift state left
|
|
|
|
|
+ for (int i0 = 0; i0 < nc - 1; ++i0) {
|
|
|
|
|
+ s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
|
|
|
|
|
+ }
|
|
|
|
|
+ // insert x on the last column
|
|
|
|
|
+ s[(nc - 1) + i1*nc] = x0[i1];
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // handle copies when there are multiple output states
|
|
|
|
|
+ for (int i3 = 1; i3 < n_kv; ++i3) {
|
|
|
|
|
+ int32_t seq = sq[i3];
|
|
|
|
|
+ if (0 <= seq && seq < n_kv) {
|
|
|
|
|
+ float * s1 = s + (seq - sq[0])*nc*nr;
|
|
|
|
|
+ memcpy(s1, s, nc*ir*sizeof(float));
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // stop at negative or too big seq_ids
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // it seems a little faster when this is separate from the state shift
|
|
|
|
|
+ for (int i1 = 0; i1 < ir; ++i1) {
|
|
|
|
|
+ // rowwise dot product
|
|
|
|
|
+ float sumf = 0.0f;
|
|
|
|
|
+ for (int i0 = 0; i0 < nc; ++i0) {
|
|
|
|
|
+ int i = i0 + i1*nc;
|
|
|
|
|
+ sumf += s[i] * c[i];
|
|
|
|
|
+ }
|
|
|
|
|
+ x[i1] = sumf;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+static void ggml_compute_forward_ssm_conv(
|
|
|
|
|
+ const struct ggml_compute_params * params,
|
|
|
|
|
+ struct ggml_tensor * dst) {
|
|
|
|
|
+ switch (dst->src[0]->type) {
|
|
|
|
|
+ case GGML_TYPE_F32:
|
|
|
|
|
+ {
|
|
|
|
|
+ ggml_compute_forward_ssm_conv_f32(params, dst);
|
|
|
|
|
+ } break;
|
|
|
|
|
+ default:
|
|
|
|
|
+ {
|
|
|
|
|
+ GGML_ASSERT(false);
|
|
|
|
|
+ } break;
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// ggml_compute_forward_ssm_scan
|
|
|
|
|
+
|
|
|
|
|
+static void ggml_compute_forward_ssm_scan_f32(
|
|
|
|
|
+ const struct ggml_compute_params * params,
|
|
|
|
|
+ struct ggml_tensor * dst) {
|
|
|
|
|
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const struct ggml_tensor * src0 = dst->src[0]; // s
|
|
|
|
|
+ const struct ggml_tensor * src1 = dst->src[1]; // x
|
|
|
|
|
+ const struct ggml_tensor * src2 = dst->src[2]; // dt
|
|
|
|
|
+ const struct ggml_tensor * src3 = dst->src[3]; // A
|
|
|
|
|
+ const struct ggml_tensor * src4 = dst->src[4]; // B
|
|
|
|
|
+ const struct ggml_tensor * src5 = dst->src[5]; // C
|
|
|
|
|
+ const struct ggml_tensor * src6 = dst->src[6]; // sq
|
|
|
|
|
+
|
|
|
|
|
+ const int ith = params->ith;
|
|
|
|
|
+ const int nth = params->nth;
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t nc = src0->ne[0]; // d_state
|
|
|
|
|
+ const int64_t nr = src0->ne[1]; // d_inner
|
|
|
|
|
+ const int64_t n_t = src1->ne[1]; // number of tokens in the batch
|
|
|
|
|
+ const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
|
|
|
|
|
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
|
|
|
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
|
|
|
|
|
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
|
|
|
|
|
+ GGML_ASSERT(src3->nb[0] == sizeof(float));
|
|
|
|
|
+ GGML_ASSERT(src4->nb[0] == sizeof(float));
|
|
|
|
|
+ GGML_ASSERT(src5->nb[0] == sizeof(float));
|
|
|
|
|
+ // required for the dot product between s and C, and when copying the states
|
|
|
|
|
+ GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
|
|
|
|
+ // required for per-sequence offsets for states
|
|
|
|
|
+ GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
|
|
|
|
|
+ // required to get correct offset for state destination (i.e. src1->nb[2])
|
|
|
|
|
+ GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
|
|
|
|
|
+
|
|
|
|
|
+ // rows per thread
|
|
|
|
|
+ const int dr = (nr + nth - 1)/nth;
|
|
|
|
|
+
|
|
|
|
|
+ // row range for this thread
|
|
|
|
|
+ const int ir0 = dr*ith;
|
|
|
|
|
+ const int ir1 = MIN(ir0 + dr, nr);
|
|
|
|
|
+ const int ir = ir1 - ir0;
|
|
|
|
|
+
|
|
|
|
|
+ if (n_kv > 1) {
|
|
|
|
|
+ // it's hard to know if the source states have already been copied
|
|
|
|
|
+ // when there are multiple, so copy them already.
|
|
|
|
|
+ for (int i3 = 0; i3 < n_kv; ++i3) {
|
|
|
|
|
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
|
|
|
|
|
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
|
|
|
|
|
+ memcpy(s, s0, nc*ir*sizeof(float));
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for (int i2 = 0; i2 < n_t; ++i2) {
|
|
|
|
|
+ int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
|
|
|
|
|
+ float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
|
|
|
|
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
|
|
|
|
|
+ float * s0;
|
|
|
|
|
+ float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
|
|
|
|
+ float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
|
|
|
|
|
+ float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
|
|
|
|
+ float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
|
|
|
|
|
+ float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
|
|
|
|
|
+
|
|
|
|
|
+ // avoid needing to copy the state for the first token
|
|
|
|
|
+ if (i2 == 0) {
|
|
|
|
|
+ s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // otherwise the source is the same as the destination
|
|
|
|
|
+ s0 = s;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // d_inner
|
|
|
|
|
+ for (int i1 = 0; i1 < ir; ++i1) {
|
|
|
|
|
+ // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
|
|
|
|
|
+ float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
|
|
|
|
|
+ float x_dt = x[i1] * dt_soft_plus;
|
|
|
|
|
+ float sumf = 0.0f;
|
|
|
|
|
+ // d_state
|
|
|
|
|
+ for (int i0 = 0; i0 < nc; ++i0) {
|
|
|
|
|
+ int i = i0 + i1*nc;
|
|
|
|
|
+ // state = prev_state * dA + dB * x
|
|
|
|
|
+ float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
|
|
|
|
+ // y = rowwise_dotprod(state, C)
|
|
|
|
|
+ sumf += state * C[i0];
|
|
|
|
|
+ s[i] = state;
|
|
|
|
|
+ }
|
|
|
|
|
+ y[i1] = sumf;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // handle copies when there are multiple output states
|
|
|
|
|
+ for (int i3 = 1; i3 < n_kv; ++i3) {
|
|
|
|
|
+ int32_t seq = sq[i3];
|
|
|
|
|
+ if (0 <= seq && seq < n_kv) {
|
|
|
|
|
+ float * s1 = s + (seq - sq[0])*nc*nr;
|
|
|
|
|
+ memcpy(s1, s, nc*ir*sizeof(float));
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // stop at negative or too big seq_ids
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+static void ggml_compute_forward_ssm_scan(
|
|
|
|
|
+ const struct ggml_compute_params * params,
|
|
|
|
|
+ struct ggml_tensor * dst) {
|
|
|
|
|
+ switch (dst->src[0]->type) {
|
|
|
|
|
+ case GGML_TYPE_F32:
|
|
|
|
|
+ {
|
|
|
|
|
+ ggml_compute_forward_ssm_scan_f32(params, dst);
|
|
|
|
|
+ } break;
|
|
|
|
|
+ default:
|
|
|
|
|
+ {
|
|
|
|
|
+ GGML_ASSERT(false);
|
|
|
|
|
+ } break;
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
// ggml_compute_forward_win_part
|
|
// ggml_compute_forward_win_part
|
|
|
|
|
|
|
|
static void ggml_compute_forward_win_part_f32(
|
|
static void ggml_compute_forward_win_part_f32(
|
|
@@ -15830,6 +16187,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
|
bool masked = t != 0;
|
|
bool masked = t != 0;
|
|
|
ggml_compute_forward_flash_attn_back(params, masked, tensor);
|
|
ggml_compute_forward_flash_attn_back(params, masked, tensor);
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case GGML_OP_SSM_CONV:
|
|
|
|
|
+ {
|
|
|
|
|
+ ggml_compute_forward_ssm_conv(params, tensor);
|
|
|
|
|
+ } break;
|
|
|
|
|
+ case GGML_OP_SSM_SCAN:
|
|
|
|
|
+ {
|
|
|
|
|
+ ggml_compute_forward_ssm_scan(params, tensor);
|
|
|
|
|
+ } break;
|
|
|
case GGML_OP_WIN_PART:
|
|
case GGML_OP_WIN_PART:
|
|
|
{
|
|
{
|
|
|
ggml_compute_forward_win_part(params, tensor);
|
|
ggml_compute_forward_win_part(params, tensor);
|
|
@@ -16884,6 +17249,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
|
{
|
|
{
|
|
|
GGML_ASSERT(false); // not supported
|
|
GGML_ASSERT(false); // not supported
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case GGML_OP_SSM_CONV:
|
|
|
|
|
+ case GGML_OP_SSM_SCAN:
|
|
|
|
|
+ {
|
|
|
|
|
+ GGML_ASSERT(false); // TODO: not implemented
|
|
|
|
|
+ } break;
|
|
|
case GGML_OP_WIN_PART:
|
|
case GGML_OP_WIN_PART:
|
|
|
case GGML_OP_WIN_UNPART:
|
|
case GGML_OP_WIN_UNPART:
|
|
|
case GGML_OP_UNARY:
|
|
case GGML_OP_UNARY:
|
|
@@ -17590,6 +17960,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|
|
{
|
|
{
|
|
|
n_tasks = n_threads;
|
|
n_tasks = n_threads;
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case GGML_OP_SSM_CONV:
|
|
|
|
|
+ case GGML_OP_SSM_SCAN:
|
|
|
|
|
+ {
|
|
|
|
|
+ n_tasks = n_threads;
|
|
|
|
|
+ } break;
|
|
|
case GGML_OP_WIN_PART:
|
|
case GGML_OP_WIN_PART:
|
|
|
case GGML_OP_WIN_UNPART:
|
|
case GGML_OP_WIN_UNPART:
|
|
|
case GGML_OP_GET_REL_POS:
|
|
case GGML_OP_GET_REL_POS:
|