@@ -374,7 +374,6 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     const int64_t num_v_heads = hparams.ssm_dt_rank;
 
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-    const int64_t n_tokens = ubatch.n_tokens;
 
     GGML_ASSERT(n_seqs != 0);
     GGML_ASSERT(ubatch.equal_seqs());
@@ -388,11 +387,11 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
     int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
-    ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
 
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
     int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
-    ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
 
     // Split mixed_ba into b and a (beta and alpha parameters)
     int64_t split_sizes_ba[2] = {
@@ -400,18 +399,18 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         num_v_heads / num_k_heads // alpha size
     };
 
-    ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs,
            mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
     cb(b, "b", il);
 
-    ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs,
            mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
            split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
     cb(a, "a", il);
 
     // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
-    ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_tokens, n_seqs);
-    ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tokens, n_seqs);
+    ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
 
     GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
 
@@ -436,29 +435,29 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         head_v_dim * num_v_heads / num_k_heads // z size
     };
 
-    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
     cb(query, "q", il);
 
-    ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
           split_sizes_qkvz[0] * sizeof(float)));
     cb(key, "k", il);
 
-    ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
           (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
     cb(value, "v", il);
 
-    ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
           (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
     cb(z, "z", il);
 
     // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
     ggml_tensor * value_reshaped =
-        ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
-    ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
+        ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
                 ggml_nelements(z_reshaped) ==
@@ -466,15 +465,15 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
     // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(query_flat, "query_flat", il);
 
     // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(key_flat, "key_flat", il);
 
     // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
-    ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_tokens, n_seqs);
+    ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
     cb(value_flat, "value_flat", il);
 
     // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
@@ -512,7 +511,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(conv_output_proper, "conv_output_proper", il);
 
     conv_output_proper = ggml_permute(ctx0, conv_output_proper, 0, 1, 3, 2);
-    conv_output_proper = ggml_cont_4d(ctx0, conv_output_proper, qkv_dim, 1, n_tokens, n_seqs);
+    conv_output_proper = ggml_cont_4d(ctx0, conv_output_proper, qkv_dim, 1, n_seq_tokens, n_seqs);
 
     ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
     cb(conv_output_silu, "conv_output_silu", il);
@@ -530,32 +529,32 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
                 ggml_element_size(conv_states_all))));
     cb(conv_states_all, "conv_states_updated", il);
 
-    conv_output_proper = ggml_reshape_2d(ctx0, conv_output_silu, n_tokens * n_seqs, qkv_dim);
+    conv_output_proper = ggml_reshape_2d(ctx0, conv_output_silu, n_seq_tokens * n_seqs, qkv_dim);
     cb(conv_output_proper, "conv_output_final", il);
 
     ggml_tensor * conv_transposed = ggml_transpose(ctx0, conv_output_proper);
     cb(conv_transposed, "conv_transposed", il);
 
-    ggml_tensor * conv_qkv_mix = ggml_cont_2d(ctx0, conv_transposed, qkv_dim, n_tokens * n_seqs);
+    ggml_tensor * conv_qkv_mix = ggml_cont_2d(ctx0, conv_transposed, qkv_dim, n_seq_tokens * n_seqs);
     cb(conv_qkv_mix, "conv_qkv_mix", il);
 
     // Extract the convolved Q, K, V from conv_output
-    ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_tokens * n_seqs,
+    ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs,
            conv_qkv_mix->nb[1], 0);
     cb(q_conv, "q_conv", il);
-    ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_tokens * n_seqs,
+    ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs,
           conv_qkv_mix->nb[1], head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(k_conv, "k_conv", il);
-    ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_tokens * n_seqs,
+    ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs,
           conv_qkv_mix->nb[1], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(v_conv, "v_conv", il);
 
     // Unsqueeze them
-    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
-    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
-    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tokens, n_seqs);
+    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
-    beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tokens, n_seqs);
+    beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
 
     ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
 
@@ -564,8 +563,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
         int64_t repeat_factor = num_v_heads / num_k_heads;
 
-        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_k_heads * repeat_factor, n_tokens, n_seqs);
-        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_k_heads * repeat_factor, n_tokens, n_seqs);
+        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
     }
 
     cb(q_conv, "q_conv_predelta", il);
@@ -577,12 +576,12 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well
-    const int64_t output_flat_size = head_dim * n_heads * n_tokens * n_seqs;
+    const int64_t output_flat_size = head_dim * n_heads * n_seq_tokens * n_seqs;
     ggml_tensor * attn_out_1d =
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);
 
-    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_tokens, n_heads, n_seqs), 0, 2, 1, 3));
+    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_seq_tokens, n_heads, n_seqs), 0, 2, 1, 3));
     cb(attn_out_final, "attn_out_final", il);
 
     // Extract the state part (second part of the concatenated tensor)
@@ -600,19 +599,19 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_tokens * n_seqs);
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_seq_tokens * n_seqs);
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_seq_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
     ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
 
     // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_tokens, n_seqs);
+    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_seq_tokens, n_seqs);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
-    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_seq_tokens, n_seqs);
     cb(final_output, "final_output", il);
 
     // Output projection
@@ -620,7 +619,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(cur, "linear_attn_out", il);
 
     // Reshape back to original dimensions
-    cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens));
+    cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens));
     return cur;
 }
 
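Why this rename is correct: the function asserts ubatch.equal_seqs(), so the ubatch factors into n_seqs sequences of n_seq_tokens tokens each, i.e. n_tokens == n_seq_tokens * n_seqs. Every tensor built here already carries n_seqs as its own trailing dimension, so the per-sequence token dimension must be n_seq_tokens; using the total n_tokens there overcounts elements by a factor of n_seqs whenever n_seqs > 1 (the two values only coincide in the single-sequence case). Below is a minimal standalone sketch of this shape accounting, using a hypothetical ubatch_dims struct and example numbers rather than the real llama.cpp types:

    // sketch.cpp -- illustrates the n_tokens vs. n_seq_tokens accounting only;
    // ubatch_dims is a stand-in for the fields read from ubatch above.
    #include <cassert>
    #include <cstdint>

    struct ubatch_dims {
        int64_t n_tokens;     // total tokens in the ubatch
        int64_t n_seq_tokens; // tokens per sequence
        int64_t n_seqs;       // number of equal-length sequences
    };

    int main() {
        const ubatch_dims u = {8, 4, 2}; // e.g. 2 sequences of 4 tokens each
        assert(u.n_tokens == u.n_seq_tokens * u.n_seqs); // equal_seqs() invariant

        const int64_t head_dim = 64;
        const int64_t n_heads  = 16;
        // element count of a [head_dim, n_heads, n_seq_tokens, n_seqs] tensor (correct):
        const int64_t ok  = head_dim * n_heads * u.n_seq_tokens * u.n_seqs;
        // the old [head_dim, n_heads, n_tokens, n_seqs] shape asked for too many:
        const int64_t bad = head_dim * n_heads * u.n_tokens * u.n_seqs;
        assert(bad == ok * u.n_seqs); // overcounts by exactly a factor of n_seqs
        return 0;
    }

With n_seqs == 1 both shapes request the same number of elements, which is why the original code appeared to work for single-sequence batches and only mis-sized its views and reshapes once multiple sequences were batched together.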