@@ -1656,172 +1656,6 @@ static void ggml_compute_forward_mul_mat_id(
     }
 }
 
-// ggml_compute_forward_delta_net
-
-static void ggml_compute_forward_delta_net(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0]; // query
-    const struct ggml_tensor * src1 = dst->src[1]; // key
-    const struct ggml_tensor * src2 = dst->src[2]; // value
-    const struct ggml_tensor * src3 = dst->src[3]; // gate
-    const struct ggml_tensor * src4 = dst->src[4]; // beta
-    const struct ggml_tensor * src5 = dst->src[5]; // state
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(src2->type == GGML_TYPE_F32);
-    GGML_ASSERT(src3->type == GGML_TYPE_F32);
-    GGML_ASSERT(src4->type == GGML_TYPE_F32);
-    GGML_ASSERT(src5->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_TERNARY_OP_LOCALS;
-    GGML_TENSOR_LOCALS(int64_t, ne3, src3, ne);
-    GGML_TENSOR_LOCALS(size_t, nb3, src3, nb);
-    GGML_TENSOR_LOCALS(int64_t, ne4, src4, ne);
-    GGML_TENSOR_LOCALS(size_t, nb4, src4, nb);
-    GGML_TENSOR_LOCALS(int64_t, ne5, src5, ne);
-    GGML_TENSOR_LOCALS(size_t, nb5, src5, nb);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t S = src0->ne[0]; // head dimension
-    const int64_t H = src0->ne[1]; // number of heads
-    const int64_t n_tokens = src0->ne[2];
-    const int64_t n_seqs = src0->ne[3];
-
-    GGML_ASSERT(ne00 == S && ne01 == H && ne02 == n_tokens && ne03 == n_seqs);
-    GGML_ASSERT(ne10 == S && ne11 == H && ne12 == n_tokens && ne13 == n_seqs);
-    GGML_ASSERT(ne20 == S && ne21 == H && ne22 == n_tokens && ne23 == n_seqs);
-    GGML_ASSERT(ne30 == S && ne31 == H && ne32 == n_tokens && ne33 == n_seqs);
-    GGML_ASSERT(ne40 == H && ne41 == n_tokens && ne42 == n_seqs && ne43 == 1);
-    GGML_ASSERT(ne50 == S && ne51 == S && ne52 == H && ne53 == n_seqs);
-
-    // Get operation parameters
-    bool use_qk_l2norm = ggml_get_op_params_i32(dst, 1) != 0;
-    float scale;
-    memcpy(&scale, ((int32_t*)dst->op_params) + 4, sizeof(float));
-
-    GGML_ASSERT(ne0 == S * H);
-    GGML_ASSERT(ne1 == n_tokens + S * n_seqs);
-
-    // Parallelize over sequences and heads
-    const int64_t n_total = n_seqs * H;
-    const int64_t n_per_thread = (n_total + nth - 1) / nth;
-    const int64_t n_start = ith * n_per_thread;
-    const int64_t n_end = MIN(n_start + n_per_thread, n_total);
-
-    for (int64_t n = n_start; n < n_end; ++n) {
-        const int64_t seq_idx = n / H;
-        const int64_t head_idx = n % H;
-
-        // Get pointers to current sequence and head
-        float * q_ptr = (float *)((char *)src0->data + seq_idx * nb03 + head_idx * nb01);
-        float * k_ptr = (float *)((char *)src1->data + seq_idx * nb13 + head_idx * nb11);
-        float * v_ptr = (float *)((char *)src2->data + seq_idx * nb23 + head_idx * nb21);
-        float * g_ptr = (float *)((char *)src3->data + seq_idx * nb33 + head_idx * nb31);
-        float * beta_ptr = (float *)((char *)src4->data + seq_idx * nb43);
-        float * state_ptr = (float *)((char *)src5->data + seq_idx * nb53 + head_idx * nb51);
-
-        float * out_ptr = (float *)((char *)dst->data + n * ne0 * sizeof(float));
-        float * new_state_ptr = out_ptr + n_tokens * S;
-
-        // Apply L2 normalization if requested
-        if (use_qk_l2norm) {
-            // Normalize query and key
-            for (int64_t t = 0; t < n_tokens; ++t) {
-                float q_sum = 0.0f, k_sum = 0.0f;
-                for (int64_t s = 0; s < S; ++s) {
-                    float q_val = q_ptr[t * nb02 / sizeof(float) + s];
-                    float k_val = k_ptr[t * nb12 / sizeof(float) + s];
-                    q_sum += q_val * q_val;
-                    k_sum += k_val * k_val;
-                }
-                float q_norm = sqrtf(q_sum + 1e-6f);
-                float k_norm = sqrtf(k_sum + 1e-6f);
-
-                for (int64_t s = 0; s < S; ++s) {
-                    q_ptr[t * nb02 / sizeof(float) + s] /= q_norm;
-                    k_ptr[t * nb12 / sizeof(float) + s] /= k_norm;
-                }
-            }
-        }
-
-        // Apply scaling to query
-        for (int64_t i = 0; i < n_tokens * S; ++i) {
-            q_ptr[i] *= scale;
-        }
-
-        // Apply sigmoid to beta
-        float * beta_sigmoid = (float *)alloca(n_tokens * sizeof(float));
-        for (int64_t t = 0; t < n_tokens; ++t) {
-            beta_sigmoid[t] = 1.0f / (1.0f + expf(-beta_ptr[t * nb42 / sizeof(float)]));
-        }
-
-        // Complete implementation of gated delta rule
-        // Based on torch_recurrent_gated_delta_rule from the reference implementation
-
-        // Process each token sequentially for recurrent computation
-        for (int64_t t = 0; t < n_tokens; ++t) {
-            // Get pointers to current token data
-            float * q_t = q_ptr + t * (nb02 / sizeof(float));
-            float * k_t = k_ptr + t * (nb12 / sizeof(float));
-            float * v_t = v_ptr + t * (nb22 / sizeof(float));
-            float * g_t = g_ptr + t * (nb32 / sizeof(float));
-
-            // Apply exponential to gate and multiply by beta
-            float g_exp = expf(g_t[0]); // g is per-head, not per-dimension
-            float beta_t = beta_sigmoid[t];
-
-            // Update recurrent state: state = state * g_exp
-            for (int64_t i = 0; i < S * S; ++i) {
-                state_ptr[i] *= g_exp;
-            }
-
-            // Compute kv_mem = (state * k_t^T).sum(dim=-1)
-            // This is a matrix-vector multiplication: state[S×S] @ k_t[S]
-            float kv_mem[S];
-            for (int64_t i = 0; i < S; ++i) {
-                kv_mem[i] = 0.0f;
-                for (int64_t j = 0; j < S; ++j) {
-                    kv_mem[i] += state_ptr[i * S + j] * k_t[j];
-                }
-            }
-
-            // Compute delta = (v_t - kv_mem) * beta_t
-            float delta[S];
-            for (int64_t i = 0; i < S; ++i) {
-                delta[i] = (v_t[i] - kv_mem[i]) * beta_t;
-            }
-
-            // Update state: state = state + k_t * delta^T
-            // This is an outer product: k_t[S] ⊗ delta[S]
-            for (int64_t i = 0; i < S; ++i) {
-                for (int64_t j = 0; j < S; ++j) {
-                    state_ptr[i * S + j] += k_t[i] * delta[j];
-                }
-            }
-
-            // Compute output: out = (state * q_t^T).sum(dim=-1)
-            // This is a matrix-vector multiplication: state[S×S] @ q_t[S]
-            float * out_t = out_ptr + t * S;
-            for (int64_t i = 0; i < S; ++i) {
-                out_t[i] = 0.0f;
-                for (int64_t j = 0; j < S; ++j) {
-                    out_t[i] += state_ptr[i * S + j] * q_t[j];
-                }
-            }
-        }
-
-        // Copy final state to new_state
-        memcpy(new_state_ptr, state_ptr, S * S * sizeof(float));
-    }
-}
-
-
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -2164,10 +1998,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
-        case GGML_OP_DELTA_NET:
-            {
-                ggml_compute_forward_delta_net(params, tensor);
-            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2461,7 +2291,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_RWKV_WKV6:
        case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_RWKV_WKV7:
-        case GGML_OP_DELTA_NET:
            {
                n_tasks = n_threads;
            } break;