|
|
@@ -23,19 +23,7 @@ static ggml_tensor* ggml_conv_1d_dw_f32(
|
|
|
|
|
|
// Reshape the result following ggml_conv_1d_dw: [result->ne[0], result->ne[2], 1]
|
|
|
ggml_tensor* output_3d = ggml_reshape_3d(ctx, mul_result, mul_result->ne[0], mul_result->ne[2], 1);
|
|
|
-
|
|
|
- // Use ggml_permute to reorder dimensions from [length, channels, batch] to [batch, channels, length]
|
|
|
- // Current: [length, channels, batch] - axes 0,1,2
|
|
|
- // Need: [batch, channels, length] - should come from axes 2,1,0
|
|
|
- // ggml_permute(ctx, tensor, axis0, axis1, axis2, axis3) - where axisN specifies which original axis becomes new axis N
|
|
|
- // So to get [length,channels,batch] -> [batch,channels,length], we want: new_dim0=old_dim2, new_dim1=old_dim1, new_dim2=old_dim0
|
|
|
- // This means: permute(2,1,0,3) - new axis 0 comes from old axis 2, new axis 1 from old axis 1, new axis 2 from old axis 0
|
|
|
- ggml_tensor* output_permuted = ggml_permute(ctx, output_3d, 2, 1, 0, 3);
|
|
|
-
|
|
|
- // Use ggml_cont to ensure contiguous layout
|
|
|
- ggml_tensor* output = ggml_cont(ctx, output_permuted);
|
|
|
-
|
|
|
- return output;
|
|
|
+ return output_3d;
|
|
|
}
|
|
|
|
|
|
llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
|
|
|
@@ -111,9 +99,9 @@ struct ggml_tensor * llm_build_qwen3next::build_q3n_norm(struct ggml_tensor * in
|
|
|
}
|
|
|
|
|
|
// ggml_delta_net
|
|
|
-struct ggml_tensor * llm_build_qwen3next::ggml_delta_net(struct ggml_tensor * k,
|
|
|
+struct ggml_tensor * llm_build_qwen3next::ggml_delta_net(struct ggml_tensor * q,
|
|
|
+ struct ggml_tensor * k,
|
|
|
struct ggml_tensor * v,
|
|
|
- struct ggml_tensor * q,
|
|
|
struct ggml_tensor * g,
|
|
|
struct ggml_tensor * beta,
|
|
|
struct ggml_tensor * state,
|
|
|
@@ -127,6 +115,13 @@ struct ggml_tensor * llm_build_qwen3next::ggml_delta_net(struct ggml_tensor * k,
|
|
|
GGML_ASSERT(ggml_is_contiguous(beta));
|
|
|
GGML_ASSERT(ggml_is_contiguous(state));
|
|
|
|
|
|
+ cb(k, "k_delta_in", il);
|
|
|
+ cb(v, "v_delta_in", il);
|
|
|
+ cb(q, "q_delta_in", il);
|
|
|
+ cb(g, "g_delta_in", il);
|
|
|
+ cb(beta, "beta_delta_in", il);
|
|
|
+ cb(state, "state_delta_in", il);
|
|
|
+
|
|
|
const int64_t S_k = k->ne[0];
|
|
|
const int64_t H_k = k->ne[1];
|
|
|
const int64_t n_tokens = k->ne[2];
|
|
|
@@ -137,7 +132,7 @@ struct ggml_tensor * llm_build_qwen3next::ggml_delta_net(struct ggml_tensor * k,
|
|
|
|
|
|
GGML_ASSERT(v->ne[2] == n_tokens);
|
|
|
GGML_ASSERT(q->ne[2] == n_tokens);
|
|
|
- GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && beta->ne[3] == n_seqs);
|
|
|
+ GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
|
|
|
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seqs && state->ne[3] == n_tokens);
|
|
|
|
|
|
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
|
|
|
@@ -228,13 +223,11 @@ struct ggml_tensor * llm_build_qwen3next::ggml_delta_net_op(struct ggml_tensor *
|
|
|
struct ggml_tensor * kv_mem_presum = ggml_mul(ctx0, state_decay, k);
|
|
|
|
|
|
// Gotta do some squeezing here...
|
|
|
- struct ggml_tensor * kv_mem_presum_squeeze = ggml_reshape_4d(ctx0, kv_mem_presum, S_v, S_v, H_v, n_seq * n_tokens);
|
|
|
-
|
|
|
- struct ggml_tensor * kv_mem = ggml_permute(
|
|
|
- ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, kv_mem_presum_squeeze, 1, 2, 0, 3))), 2, 0, 1, 3);
|
|
|
+ struct ggml_tensor * kv_mem_presum_squeeze = ggml_cont_4d(ctx0, kv_mem_presum, S_v, S_v, H_v, n_seq * n_tokens);
|
|
|
+ struct ggml_tensor * kv_mem = ggml_permute(ctx0, ggml_sum_rows(ctx0, kv_mem_presum_squeeze), 3, 0, 1, 2);
|
|
|
cb(kv_mem, "kv_mem", il);
|
|
|
- struct ggml_tensor * kv_mem_reshape = ggml_reshape_4d(ctx0, kv_mem, S_v, S_v, n_seq, n_tokens);
|
|
|
- struct ggml_tensor * delta = ggml_mul(ctx0, ggml_sub(ctx0, kv_mem_reshape, v), beta);
|
|
|
+ struct ggml_tensor * kv_mem_reshape = ggml_reshape_4d(ctx0, kv_mem, S_v, H_v, n_seq, n_tokens);
|
|
|
+ struct ggml_tensor * delta = ggml_mul(ctx0, ggml_sub(ctx0, v, kv_mem_reshape), beta);
|
|
|
cb(delta, "delta", il);
|
|
|
struct ggml_tensor * delta_kt = ggml_mul(ctx0, delta, k);
|
|
|
cb(delta_kt, "delta_kt", il);
|
|
|
@@ -456,16 +449,20 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
// Apply convolution
|
|
|
ggml_tensor * conv_output = ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, n_seqs);
|
|
|
cb(conv_output, "conv_output_raw", il);
|
|
|
- conv_output = ggml_permute(ctx0, conv_output, 0, 1, 3, 2);
|
|
|
|
|
|
- // Take only the values slice - offset the size of the convolution states
|
|
|
- ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output, conv_output->ne[0], conv_output->ne[1], conv_output->ne[2], n_tokens * n_seqs,
|
|
|
+ // Remove the padding
|
|
|
+ ggml_tensor * conv_output_no_padding = ggml_view_4d(ctx0, conv_output, conv_output->ne[0] - (conv_kernel_size - 1), conv_output->ne[1], conv_output->ne[2], conv_output->ne[3],
|
|
|
conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
|
|
|
- conv_output->ne[0] * conv_output->ne[1] * conv_output->ne[2] *
|
|
|
- (conv_output->ne[3] - (n_tokens * n_seqs)) * ggml_element_size(conv_output));
|
|
|
+ (conv_kernel_size - 1) * ggml_element_size(conv_output));
|
|
|
+ cb(conv_output_no_padding, "conv_output_no_padding", il);
|
|
|
+
|
|
|
+ // Take only the first (n_tokens * n_seqs) values
|
|
|
+ ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_tokens * n_seqs, conv_output_no_padding->ne[1], conv_output_no_padding->ne[2], conv_output_no_padding->ne[3],
|
|
|
+ conv_output_no_padding->nb[1], conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], 0);
|
|
|
cb(conv_output_proper, "conv_output_proper", il);
|
|
|
|
|
|
- conv_output_proper = ggml_reshape_4d(ctx0, conv_output_proper, qkv_dim, 1, n_tokens, n_seqs);
|
|
|
+ conv_output_proper = ggml_permute(ctx0, conv_output_proper, 0, 1, 3, 2);
|
|
|
+ conv_output_proper = ggml_cont_4d(ctx0, conv_output_proper, qkv_dim, 1, n_tokens, n_seqs);
|
|
|
|
|
|
ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
|
|
|
cb(conv_output_silu, "conv_output_silu", il);
|
|
|
@@ -483,26 +480,30 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
ggml_element_size(conv_states_all))));
|
|
|
cb(conv_states_all, "conv_states_updated", il);
|
|
|
|
|
|
- // Reshape conv_output back to proper dimensions
|
|
|
- conv_output_proper = ggml_cont_4d(ctx0, conv_output_silu, qkv_dim, n_seqs, n_seq_tokens, 1);
|
|
|
- cb(conv_output_proper, "conv_output_reshaped", il);
|
|
|
- conv_output_proper = ggml_permute(ctx0, conv_output_proper, 0, 2, 1, 3);
|
|
|
+ conv_output_proper = ggml_reshape_2d(ctx0, conv_output_silu, n_tokens * n_seqs, qkv_dim);
|
|
|
cb(conv_output_proper, "conv_output_final", il);
|
|
|
|
|
|
+ ggml_tensor * conv_transposed = ggml_transpose(ctx0, conv_output_proper);
|
|
|
+ cb(conv_transposed, "conv_transposed", il);
|
|
|
+
|
|
|
+ ggml_tensor * conv_qkv_mix = ggml_cont_2d(ctx0, conv_transposed, qkv_dim, n_tokens * n_seqs);
|
|
|
+ cb(conv_qkv_mix, "conv_qkv_mix", il);
|
|
|
+
|
|
|
// Extract the convolved Q, K, V from conv_output
|
|
|
- ggml_tensor * q_conv = ggml_cont_4d(ctx0, ggml_view_4d(ctx0, conv_output_proper, head_k_dim * num_k_heads, 1, n_tokens, n_seqs,
|
|
|
- conv_output_proper->nb[1], conv_output_proper->nb[2], conv_output_proper->nb[3], 0), head_k_dim, num_k_heads, n_tokens, n_seqs);
|
|
|
+ ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_tokens * n_seqs,
|
|
|
+ conv_qkv_mix->nb[1], 0);
|
|
|
cb(q_conv, "q_conv", il);
|
|
|
- ggml_tensor * k_conv = ggml_cont_4d(ctx0, ggml_view_4d(ctx0, conv_output_proper, head_k_dim * num_k_heads, 1, n_tokens, n_seqs,
|
|
|
- conv_output_proper->nb[1], conv_output_proper->nb[2], conv_output_proper->nb[3], head_k_dim * num_k_heads * ggml_element_size(conv_output_proper)),
|
|
|
- head_k_dim, num_k_heads, n_tokens, n_seqs);
|
|
|
- cb(q_conv, "k_conv", il);
|
|
|
- ggml_tensor * v_conv = ggml_cont_4d(ctx0, ggml_view_4d(ctx0, conv_output_proper, head_v_dim, num_v_heads, n_tokens, n_seqs,
|
|
|
- conv_output_proper->nb[1], conv_output_proper->nb[2], conv_output_proper->nb[3], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_output_proper)),
|
|
|
- head_v_dim, num_v_heads, n_tokens, n_seqs);
|
|
|
- cb(q_conv, "v_conv", il);
|
|
|
-
|
|
|
- ggml_build_forward_expand(gf, ssm_states_all);
|
|
|
+ ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_tokens * n_seqs,
|
|
|
+ conv_qkv_mix->nb[1], head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
|
|
|
+ cb(k_conv, "k_conv", il);
|
|
|
+ ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_tokens * n_seqs,
|
|
|
+ conv_qkv_mix->nb[1], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
|
|
|
+ cb(v_conv, "v_conv", il);
|
|
|
+
|
|
|
+ // Unsqueeze them
|
|
|
+ q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
|
|
|
+ k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
|
|
|
+ v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tokens, n_seqs);
|
|
|
|
|
|
// Beta tensor
|
|
|
beta = ggml_reshape_3d(ctx0, beta, n_heads, n_tokens, n_seqs);
|
|
|
@@ -514,7 +515,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
gate = ggml_repeat(ctx0, gate_broadcast, target_gate);
|
|
|
|
|
|
// Call the new ggml_delta_net function with the corrected flow
|
|
|
- ggml_tensor * output = ggml_delta_net(k_conv, v_conv, q_conv, gate, beta, state_broadcast, true, 1.0f, il);
|
|
|
+ ggml_tensor * output = ggml_delta_net(q_conv, k_conv, v_conv, gate, beta, state_broadcast, true, 1.0f, il);
|
|
|
-    cb(q_conv, "delta_output", il);
+    cb(output, "delta_output", il);
|
|
|
|
|
|
// Extract the output part
|
|
|
@@ -548,13 +549,13 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
// Apply gated normalization: self.norm(core_attn_out, z)
|
|
|
// This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
|
|
|
ggml_tensor * attn_out_norm = build_norm(attn_out_2d, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
|
|
|
- cb(output, "attn_out_norm", il);
|
|
|
+ cb(attn_out_norm, "attn_out_norm", il);
|
|
|
|
|
|
// Apply silu gate: attn_out_norm * silu(z_2d)
|
|
|
ggml_tensor * z_silu = ggml_silu(ctx0, z_2d);
|
|
|
- cb(output, "z_silu", il);
|
|
|
+ cb(z_silu, "z_silu", il);
|
|
|
ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
|
|
|
- cb(output, "gated_output", il);
|
|
|
+ cb(gated_output, "gated_output", il);
|
|
|
|
|
|
// Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
|
|
|
ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);
|
|
|
@@ -569,7 +570,6 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
|
|
|
// Reshape back to original dimensions
|
|
|
cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens));
|
|
|
-
|
|
|
return cur;
|
|
|
}
|
|
|
|