@@ -386,14 +386,13 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     const auto * mctx_cur = inp->mctx;
 
     const int64_t d_inner = hparams.ssm_d_inner;
-    const int64_t n_heads = hparams.ssm_dt_rank;
-    const int64_t head_dim = d_inner / n_heads;
+
     const int64_t n_seqs = ubatch.n_seqs;
 
     const int64_t head_k_dim = hparams.ssm_d_state;
-    const int64_t head_v_dim = hparams.ssm_d_state;
     const int64_t num_k_heads = hparams.ssm_n_group;
     const int64_t num_v_heads = hparams.ssm_dt_rank;
+    const int64_t head_v_dim = d_inner / num_v_heads;
 
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
@@ -408,7 +407,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
-    int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
+    int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
     ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
 
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
@@ -441,15 +440,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
-    // Get convolution states from cache
-    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
-    ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
-
-    // Build the convolution states tensor
-    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
-    cb(conv_states, "conv_states", il);
-
-    // Split mixed_qkvz into query, key, value, z
+    // Split mixed_qkvz into query, key, value, z
     int64_t split_sizes_qkvz[4] = {
         head_k_dim, // query size
         head_k_dim, // key size
@@ -457,47 +448,50 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         head_v_dim * num_v_heads / num_k_heads // z size
     };
 
-    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
-            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
+    ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
+            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
     cb(query, "q", il);
 
-    ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
             mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-            split_sizes_qkvz[0] * sizeof(float)));
+            split_sizes_qkvz[0] * sizeof(float));
     cb(key, "k", il);
 
-    ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
             mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-            (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
+            (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
     cb(value, "v", il);
 
-    ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
             mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-            (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
+            (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
     cb(z, "z", il);
 
-    // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
-    ggml_tensor * value_reshaped =
-        ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-    ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-
-    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
-                ggml_nelements(z_reshaped) ==
+    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) +
+                ggml_nelements(z) ==
                 ggml_nelements(mixed_qkvz));
 
     // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
     // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(query_flat, "query_flat", il);
 
     // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(key_flat, "key_flat", il);
 
     // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
-    ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
     cb(value_flat, "value_flat", il);
 
+    // Get convolution states from cache
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+
+    // Build the convolution states tensor
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
     // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
     ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
     qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
@@ -578,7 +572,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
 
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
-    state = ggml_reshape_4d(ctx0, state, head_dim, head_dim * n_heads, 1, n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
@@ -598,17 +592,17 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well
-    const int64_t output_flat_size = head_dim * n_heads * n_seq_tokens * n_seqs;
+    const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
     ggml_tensor * attn_out_1d =
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);
 
-    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_seq_tokens, n_heads, n_seqs), 0, 2, 1, 3));
+    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, n_seq_tokens, num_v_heads, n_seqs), 0, 2, 1, 3));
     cb(attn_out_final, "attn_out_final", il);
 
     // Extract the state part (second part of the concatenated tensor)
     // State starts after n_tokens elements along dimension 1
-    const int64_t state_flat_size = head_dim * head_dim * n_heads * n_seqs;
+    const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
 
     ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
     cb(state_1d, "state_1d", il);
@@ -620,19 +614,19 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
     ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
 
     // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
-    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_seq_tokens, n_seqs);
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
     cb(final_output, "final_output", il);
 
     // Output projection