@@ -412,7 +412,6 @@ struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];

- GGML_ASSERT(n_tokens == 1); // Recurrent version only supports sequence_length = 1
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
@@ -459,62 +458,85 @@ struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
cb(g, "g_permute", il);

- ggml_tensor * q_t = ggml_cont_4d(ctx, q, 1, S_k, H_k, n_seqs);
- ggml_tensor * k_t = ggml_cont_4d(ctx, k, 1, S_k, H_k, n_seqs);
- ggml_tensor * v_t = ggml_cont_4d(ctx, v, 1, S_v, H_k, n_seqs);
- ggml_tensor * g_t = ggml_cont_4d(ctx, g, 1, 1, H_k, n_seqs);
- ggml_tensor * beta_t = ggml_cont_4d(ctx, beta, 1, 1, H_k, n_seqs);
+ ggml_tensor * q_tokens = ggml_cont_4d(ctx, q, n_tokens, S_k, H_k, n_seqs);
+ ggml_tensor * k_tokens = ggml_cont_4d(ctx, k, n_tokens, S_k, H_k, n_seqs);
+ ggml_tensor * v_tokens = ggml_cont_4d(ctx, v, n_tokens, S_v, H_k, n_seqs);
+ ggml_tensor * g_tokens = ggml_cont_4d(ctx, g, n_tokens, 1, H_k, n_seqs);
+ ggml_tensor * beta_tokens = ggml_cont_4d(ctx, beta, n_tokens, 1, H_k, n_seqs);
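+ // dim 0 now indexes the token: [n_tokens, S, H, n_seqs], so each token is a dim-0 slice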
+
state = ggml_cont_4d(ctx, state, S_v, S_v, H_k, n_seqs);
+ ggml_tensor * g_tokens_exp = ggml_exp(ctx, g_tokens);
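+ // exp(g) is computed for all tokens at once; the loop below takes per-token views of it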
+
+ ggml_tensor * final_output = nullptr;
+ ggml_tensor * q_t, * k_t, * v_t, * g_t_exp, * beta_t;
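+
+ // The loop applies the gated delta rule once per token (S is the per-head
+ // [S_v x S_v] state; @ is a matrix-vector product; everything broadcasts
+ // over heads and sequences):
+ //   S_g = exp(g_t) * S                  decay the previous state
+ //   d_t = beta_t * (v_t - S_g @ k_t)    delta correction
+ //   S   = S_g + d_t k_t^T               rank-1 (outer-product) update
+ //   o_t = S @ q_t                       readout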
+ for (int i = 0; i < n_tokens; i++) { // process tokens one at a time
+ if (n_tokens == 1) { // don't do unnecessary reshapes / views
+ q_t = q_tokens;
+ k_t = k_tokens;
+ v_t = v_tokens;
+ g_t_exp = g_tokens_exp;
+ beta_t = beta_tokens;
+ } else {
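+ // view of token i: dim 0 is the token index and its stride is one element,
+ // so the view starts at byte offset i * element_size with strides unchanged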
+ q_t = ggml_view_4d(ctx, q_tokens, 1, S_k, H_k, n_seqs, q_tokens->nb[1], q_tokens->nb[2], q_tokens->nb[3], i * ggml_element_size(q_tokens));
+ k_t = ggml_view_4d(ctx, k_tokens, 1, S_k, H_k, n_seqs, k_tokens->nb[1], k_tokens->nb[2], k_tokens->nb[3], i * ggml_element_size(k_tokens));
+ v_t = ggml_view_4d(ctx, v_tokens, 1, S_v, H_k, n_seqs, v_tokens->nb[1], v_tokens->nb[2], v_tokens->nb[3], i * ggml_element_size(v_tokens));
+ g_t_exp = ggml_view_4d(ctx, g_tokens_exp, 1, 1, H_k, n_seqs, g_tokens_exp->nb[1], g_tokens_exp->nb[2], g_tokens_exp->nb[3], i * ggml_element_size(g_tokens_exp));
+ beta_t = ggml_view_4d(ctx, beta_tokens, 1, 1, H_k, n_seqs, beta_tokens->nb[1], beta_tokens->nb[2], beta_tokens->nb[3], i * ggml_element_size(beta_tokens));
+ }

- // Apply exponential to gate: exp(g)
- ggml_tensor * g_exp = ggml_exp(ctx, g_t);
- cb(g_exp, "g_exp", il);
+ // Apply gate to state: state = state * exp(g)
+ ggml_tensor * gated_state = ggml_mul(ctx, state, g_t_exp);
+ cb(gated_state, "gated_state", il);

- // Apply gate to state: state = state * exp(g)
- ggml_tensor * gated_state = ggml_mul(ctx, state, g_exp);
- cb(gated_state, "gated_state", il);
+ // Compute kv_memory from state and key
+ // kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
+
+ // Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
+ // to make it broadcastable against k_t for the element-wise multiplication
+ ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
+ cb(gated_state_reshaped, "gated_state_reshaped", il);
+
+ ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
+ cb(state_k_product, "state_k_product", il);

- // Compute kv_memory from state and key
- // kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
-
- // Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
- // to make it compatible with k_expanded for element-wise multiplication
- ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
- cb(gated_state_reshaped, "gated_state_reshaped", il);
-
- ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
- cb(state_k_product, "state_k_product", il);
+ ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
+ cb(kv_memory, "kv_memory", il);
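+ // ggml_sum_rows reduces dim 0, so transpose first to reduce dim 1 instead;
+ // this implements the reference's sum(dim=-2)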

- ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
- cb(kv_memory, "kv_memory", il);
+ // Compute delta = (v - kv_memory) * beta
+ ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
+ ggml_tensor * delta = ggml_mul(ctx, v_diff, beta_t);
+ cb(delta, "delta", il);

- // Compute delta = (v - kv_memory) * beta
- ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
- ggml_tensor * delta = ggml_mul(ctx, v_diff, beta_t);
- cb(delta, "delta", il);
+ // Update state = state + k * delta
+ // In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+ ggml_tensor * delta_t = ggml_transpose(ctx, delta);

- // Update state = state + k * delta
- // In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
- ggml_tensor * delta_t = ggml_transpose(ctx, delta);
+ // Broadcast both operands explicitly, since ggml_mul only broadcasts one side at a time
+ ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
+ ggml_tensor * k_t_broadcast = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
+ ggml_tensor * k_delta_product = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
+ cb(k_delta_product, "k_delta", il);
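+ // k_t varies along dim 1 and delta along dim 0, so their product is the
+ // outer product d_t k_t^T used for the rank-1 state update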

- // Will need to broadcast here since GGML doesn't support auto-double-broadcasting on mul
- ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
- ggml_tensor * k_t_broadcast = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
- ggml_tensor * k_delta_product = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
- cb(k_delta_product, "k_delta", il);
+ state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
+ cb(state, "updated_state", il);

- ggml_tensor * updated_state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
- cb(updated_state, "updated_state", il);
-
- ggml_tensor * state_q_product = ggml_mul(ctx, updated_state, q_t);
- cb(state_q_product, "state_q_product", il);
- ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
- cb(output, "output", il);
+ ggml_tensor * state_q_product = ggml_mul(ctx, state, q_t);
+ cb(state_q_product, "state_q_product", il);
+
+ ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
+ cb(output, "output", il);
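
+ // accumulate the per-token outputs along dim 0 -> [n_tokens, S_v, H_v, n_seqs]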
+ if (final_output == nullptr) {
+ final_output = output;
+ } else {
+ final_output = ggml_concat(ctx, final_output, output, 0);
+ }
+ }
+
// Concatenate output and updated_state into a single tensor
// First, flatten both tensors to 1D
- ggml_tensor * output_1d = ggml_cont_1d(ctx, output, ggml_nelements(output));
- ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, updated_state, ggml_nelements(updated_state));
+ ggml_tensor * output_1d = ggml_cont_1d(ctx, final_output, ggml_nelements(final_output));
+ ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, state, ggml_nelements(state));

// Concatenate them: [output, updated_state]
ggml_tensor * result = ggml_concat(ctx, output_1d, updated_state_1d, 0);