@@ -382,6 +382,144 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     return result;
 }

+// delta_net_recurrent
+// Recurrent version of delta_net for sequence_length = 1
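+//
+// Per head and per step this implements the gated delta rule:
+//   S      = exp(g) * S             (decay the previous state)
+//   delta  = beta * (v - S k)       (error between the new value and the current readout)
+//   S      = S + delta * k^T        (rank-1 state update)
+//   output = S q                    (read out with the query)
+// where S is the per-head [S_v x S_v] recurrent state matrix.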
+struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state,
+        bool                  use_qk_l2norm,
+        float                 eps_norm,
+        const int             il
+    ) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
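+    // note: ggml ne[] runs from the fastest-varying dimension to the slowest,
+    //       so q/k/v are laid out as [head_dim, n_heads, n_tokens, n_seqs]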
+
+    GGML_ASSERT(n_tokens == 1); // the recurrent version only supports sequence_length == 1
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v); // the caller repeats the q/k heads so that H_k == H_v
+
+    cb(q, "q_prenorm", il);
+    cb(k, "k_prenorm", il);
+
+    if (use_qk_l2norm) {
+        q = ggml_l2_norm(ctx, q, eps_norm);
+        k = ggml_l2_norm(ctx, k, eps_norm);
+    }
+
+    cb(k, "k_postnorm", il);
+    cb(q, "q_prescale", il);
+
+    // scale the queries by 1/sqrt(head_k_dim)
+    const float scale = 1.0f / sqrtf((float) S_k);
+    q = ggml_scale(ctx, q, scale);
+
+    cb(beta, "beta_raw", il);
+    beta = ggml_sigmoid(ctx, beta);
+
+    cb(q, "q_postscale", il);
+    cb(beta, "beta_sigmoid", il);
+
+    // Reshape tensors for recurrent computation:
+    // from [S_k, H_k, n_tokens, n_seqs] to [S_k, n_tokens, H_k, n_seqs]
+    q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
+    cb(q, "q_reshape", il);
+    k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
+    cb(k, "k_reshape", il);
+    v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
+    cb(v, "v_reshape", il);
+
+    beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
+    cb(beta, "beta_reshape", il);
+
+    g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
+    cb(g, "g_permute", il);
+
+    ggml_tensor * q_t    = ggml_cont_4d(ctx, q,    1, S_k, H_k, n_seqs);
+    ggml_tensor * k_t    = ggml_cont_4d(ctx, k,    1, S_k, H_k, n_seqs);
+    ggml_tensor * v_t    = ggml_cont_4d(ctx, v,    1, S_v, H_k, n_seqs);
+    ggml_tensor * g_t    = ggml_cont_4d(ctx, g,    1, 1,   H_k, n_seqs);
+    ggml_tensor * beta_t = ggml_cont_4d(ctx, beta, 1, 1,   H_k, n_seqs);
+    state = ggml_cont_4d(ctx, state, S_v, S_v, H_k, n_seqs);
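+    // Per-step shapes (n_tokens == 1):
+    //   q_t, k_t    : [1,   S_k, H_k, n_seqs]
+    //   v_t         : [1,   S_v, H_k, n_seqs]
+    //   g_t, beta_t : [1,   1,   H_k, n_seqs]
+    //   state       : [S_v, S_v, H_k, n_seqs]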
+
+    // Apply exponential to the gate: exp(g)
+    ggml_tensor * g_exp = ggml_exp(ctx, g_t);
+    cb(g_exp, "g_exp", il);
+
+    // Apply the gate to the state: state = state * exp(g)
+    ggml_tensor * gated_state = ggml_mul(ctx, state, g_exp);
+    cb(gated_state, "gated_state", il);
+
+    // Compute kv_memory from the state and the key
+    // reference: kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
+
+    // gated_state is already [S_v, S_v, H_v, n_seqs] (see the ggml_cont_4d above); the explicit
+    // reshape just pins down the shape used for the element-wise multiplication with k_t
+    ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
+    cb(gated_state_reshaped, "gated_state_reshaped", il);
+
+    ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
+    cb(state_k_product, "state_k_product", il);
+
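+    // Reduce over the key dimension: kv_memory[i] = sum_j gated_state[i, j] * k[j]
+    // (transpose first because ggml_sum_rows reduces along ne[0])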
+    ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
+    cb(kv_memory, "kv_memory", il);
+
+    // Compute delta = (v - kv_memory) * beta
+    ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
+    ggml_tensor * delta  = ggml_mul(ctx, v_diff, beta_t);
+    cb(delta, "delta", il);
+
+    // Update the state: state = state + k * delta
+    // reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
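+    // i.e. a rank-1 update: state[i, j] += delta[i] * k[j]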
+    ggml_tensor * delta_t = ggml_transpose(ctx, delta);
+
+    // Broadcast both operands explicitly: ggml_mul only broadcasts the second operand
+    // into the first, it does not broadcast both sides at once
+    ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
+    ggml_tensor * k_t_broadcast     = ggml_repeat_4d(ctx, k_t,     S_v, S_v, H_v, n_seqs);
+    ggml_tensor * k_delta_product   = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
+    cb(k_delta_product, "k_delta", il);
+
+    ggml_tensor * updated_state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
+    cb(updated_state, "updated_state", il);
+
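+    // Read out the output: output[i] = sum_j updated_state[i, j] * q[j]
+    // (same transpose + sum_rows pattern as for kv_memory above)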
+    ggml_tensor * state_q_product = ggml_mul(ctx, updated_state, q_t);
+    cb(state_q_product, "state_q_product", il);
+    ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
+    cb(output, "output", il);
+
+    // Concatenate output and updated_state into a single tensor:
+    // first, flatten both tensors to 1D
+    ggml_tensor * output_1d        = ggml_cont_1d(ctx, output, ggml_nelements(output));
+    ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, updated_state, ggml_nelements(updated_state));
+
+    // then concatenate them: [output, updated_state]
+    ggml_tensor * result = ggml_concat(ctx, output_1d, updated_state_1d, 0);
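+    // the caller splits this back into the attention output and the updated state using 1d views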
+    return result;
+}
+

 ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
                                                                      ggml_tensor * cur,
@@ -402,6 +540,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i

     const int64_t n_seq_tokens = ubatch.n_seq_tokens;

+    const auto kv_head = mctx_cur->get_head();
+
     GGML_ASSERT(n_seqs != 0);
     GGML_ASSERT(ubatch.equal_seqs());
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
@@ -494,6 +634,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

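+    // get_rs_z() < 0 is taken as the signal that no sequence in this ubatch starts from a
+    // freshly cleared recurrent state, i.e. that we are doing autoregressive decoding;
+    // the recurrent delta rule additionally requires exactly one token per sequence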
+    bool is_generation = mctx_cur->get_rs_z() < 0 && n_seq_tokens == 1;
+
     // Build the convolution states tensor
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
@@ -528,7 +670,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(last_conv_states, "last_conv_states", il);

     ggml_tensor * state_update_target = ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
-            mctx_cur->get_head() * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+            kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
     cb(state_update_target, "state_update_target", il);

     ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
@@ -584,6 +726,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i

     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
     state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
+    cb(state, "state_predelta", il);

     // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
@@ -598,8 +741,15 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);

-    // Call the new delta_net function with the corrected flow
-    ggml_tensor * attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
+    // Choose between delta_net and delta_net_recurrent based on the generation mode
+    ggml_tensor * attn_out;
+    if (is_generation) {
+        // Use delta_net_recurrent for single-token generation
+        attn_out = delta_net_recurrent(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
+    } else {
+        // Use the regular delta_net for prompt processing
+        attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
+    }
     cb(attn_out, "attn_out", il);

     // The tensors were concatenated 1d, so we need to extract them 1d as well
@@ -621,7 +771,9 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     // Update the recurrent states
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0, state_1d, ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
-            hparams.n_embd_s() * mctx_cur->get_head() * ggml_element_size(ssm_states_all))));
+            kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+    GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));

     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]