|
|
@@ -2,6 +2,42 @@
|
|
|
|
|
|
#include <cmath>
|
|
|
|
|
|
+// Implementation of depthwise 1D convolution using F32 to avoid F16 limitations
|
|
|
+static ggml_tensor* ggml_conv_1d_dw_f32(
|
|
|
+ ggml_context * ctx,
|
|
|
+ ggml_tensor * kernel,
|
|
|
+ ggml_tensor * input,
|
|
|
+ int stride,
|
|
|
+ int padding,
|
|
|
+ int dilation) {
|
|
|
+ // Following the pattern from ggml_conv_1d_dw but using F32
|
|
|
+ // Reshape input from [length, channels, batch, dummy] to [length, 1, channels, batch]
|
|
|
+ ggml_tensor* reshaped_input = ggml_reshape_4d(ctx, input, input->ne[0], 1, input->ne[1], input->ne[2]);
|
|
|
+
|
|
|
+ // Apply im2col with F32 destination type to avoid F16 requirement
|
|
|
+ ggml_tensor* im2col_result = ggml_im2col(ctx, kernel, reshaped_input, stride, 0, padding, 0, dilation, 0, false, GGML_TYPE_F32);
|
|
|
+
|
|
|
+ // Now multiply: im2col_result * kernel (following the exact pattern from ggml_conv_1d_dw)
|
|
|
+ // In ggml_conv_1d_dw: ggml_mul_mat(ctx, im2col, a) where a is the kernel
|
|
|
+ ggml_tensor* mul_result = ggml_mul_mat(ctx, im2col_result, kernel);
|
|
|
+
|
|
|
+ // Reshape the result following ggml_conv_1d_dw: [result->ne[0], result->ne[2], 1]
|
|
|
+ ggml_tensor* output_3d = ggml_reshape_3d(ctx, mul_result, mul_result->ne[0], mul_result->ne[2], 1);
|
|
|
+
|
|
|
+ // Use ggml_permute to reorder dimensions from [length, channels, batch] to [batch, channels, length]
|
|
|
+ // Current: [length, channels, batch] - axes 0,1,2
|
|
|
+ // Need: [batch, channels, length] - should come from axes 2,1,0
|
|
|
+ // ggml_permute(ctx, tensor, axis0, axis1, axis2, axis3) - where axisN specifies which original axis becomes new axis N
|
|
|
+ // So to get [length,channels,batch] -> [batch,channels,length], we want: new_dim0=old_dim2, new_dim1=old_dim1, new_dim2=old_dim0
|
|
|
+ // This means: permute(2,1,0,3) - new axis 0 comes from old axis 2, new axis 1 from old axis 1, new axis 2 from old axis 0
|
|
|
+ ggml_tensor* output_permuted = ggml_permute(ctx, output_3d, 2, 1, 0, 3);
|
|
|
+
|
|
|
+ // Use ggml_cont to ensure contiguous layout
|
|
|
+ ggml_tensor* output = ggml_cont(ctx, output_permuted);
|
|
|
+
|
|
|
+ return output;
|
|
|
+}
|
|
|
+
|
|
|
llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
|
|
|
llm_graph_context_mamba(params) {
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
|
@@ -400,18 +436,17 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
|
|
|
// Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
|
|
|
ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
|
|
|
- qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
|
|
|
+ qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
|
|
|
+ qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
|
|
|
cb(qkv_mixed, "qkv_mixed_concatenated", il);
|
|
|
|
|
|
// Calculate the total conv dimension
|
|
|
int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
|
|
|
|
|
|
- // Reshape to [n_tokens, qkv_dim, n_seqs] for proper convolution input format
|
|
|
- qkv_mixed = ggml_cont_3d(ctx0, ggml_transpose(ctx0, qkv_mixed), n_tokens, qkv_dim, n_seqs);
|
|
|
- cb(qkv_mixed, "qkv_mixed_for_conv", il);
|
|
|
-
|
|
|
// Calculate convolution kernel size
|
|
|
- const int64_t conv_kernel_size = model.layers[il].ssm_conv1d->ne[0];
|
|
|
+ ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
|
|
|
+ const int64_t conv_kernel_size = conv_kernel->ne[0];
|
|
|
+ conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
|
|
|
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state, n_seqs);
|
|
|
cb(conv_states, "conv_states_reshaped", il);
|
|
|
|
|
|
@@ -419,15 +454,21 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
cb(conv_input, "conv_input", il);
|
|
|
|
|
|
// Apply convolution
|
|
|
- ggml_tensor * conv_output = ggml_ssm_conv(ctx0, conv_input, model.layers[il].ssm_conv1d);
|
|
|
+ ggml_tensor * conv_output = ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, n_seqs);
|
|
|
cb(conv_output, "conv_output_raw", il);
|
|
|
+ conv_output = ggml_permute(ctx0, conv_output, 0, 1, 3, 2);
|
|
|
|
|
|
- if (model.layers[il].ssm_conv1d_b) {
|
|
|
- conv_output = ggml_add(ctx0, conv_output, model.layers[il].ssm_conv1d_b);
|
|
|
- cb(conv_output, "conv_output_bias", il);
|
|
|
- }
|
|
|
- conv_output = ggml_silu(ctx0, conv_output);
|
|
|
- cb(conv_output, "conv_output_silu", il);
|
|
|
+ // Take only the values slice - offset the size of the convolution states
|
|
|
+ ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output, conv_output->ne[0], conv_output->ne[1], conv_output->ne[2], n_tokens * n_seqs,
|
|
|
+ conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
|
|
|
+ conv_output->ne[0] * conv_output->ne[1] * conv_output->ne[2] *
|
|
|
+ (conv_output->ne[3] - (n_tokens * n_seqs)) * ggml_element_size(conv_output));
|
|
|
+ cb(conv_output_proper, "conv_output_proper", il);
|
|
|
+
|
|
|
+ conv_output_proper = ggml_reshape_4d(ctx0, conv_output_proper, qkv_dim, 1, n_tokens, n_seqs);
|
|
|
+
|
|
|
+ ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
|
|
|
+ cb(conv_output_silu, "conv_output_silu", il);
|
|
|
|
|
|
// Update convolution state cache
|
|
|
// Extract the last (conv_kernel_size - 1) states from conv_input
|
|
|
@@ -443,24 +484,22 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
cb(conv_states_all, "conv_states_updated", il);
|
|
|
|
|
|
// Reshape conv_output back to proper dimensions
|
|
|
- conv_output = ggml_reshape_4d(ctx0, conv_output, qkv_dim, n_seqs, n_seq_tokens, 1);
|
|
|
- cb(conv_output, "conv_output_reshaped", il);
|
|
|
- conv_output = ggml_permute(ctx0, conv_output, 0, 2, 1, 3);
|
|
|
- cb(conv_output, "conv_output_final", il);
|
|
|
+ conv_output_proper = ggml_cont_4d(ctx0, conv_output_silu, qkv_dim, n_seqs, n_seq_tokens, 1);
|
|
|
+ cb(conv_output_proper, "conv_output_reshaped", il);
|
|
|
+ conv_output_proper = ggml_permute(ctx0, conv_output_proper, 0, 2, 1, 3);
|
|
|
+ cb(conv_output_proper, "conv_output_final", il);
|
|
|
|
|
|
// Extract the convolved Q, K, V from conv_output
|
|
|
- ggml_tensor * q_conv = ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
|
|
|
- conv_output->nb[1], conv_output->nb[2], conv_output->nb[3], 0));
|
|
|
+ ggml_tensor * q_conv = ggml_cont_4d(ctx0, ggml_view_4d(ctx0, conv_output_proper, head_k_dim * num_k_heads, 1, n_tokens, n_seqs,
|
|
|
+ conv_output_proper->nb[1], conv_output_proper->nb[2], conv_output_proper->nb[3], 0), head_k_dim, num_k_heads, n_tokens, n_seqs);
|
|
|
cb(q_conv, "q_conv", il);
|
|
|
- ggml_tensor * k_conv = ggml_cont(
|
|
|
- ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
|
|
|
- conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
|
|
|
- head_k_dim * num_k_heads * ggml_element_size(conv_output)));
|
|
|
+ ggml_tensor * k_conv = ggml_cont_4d(ctx0, ggml_view_4d(ctx0, conv_output_proper, head_k_dim * num_k_heads, 1, n_tokens, n_seqs,
|
|
|
+ conv_output_proper->nb[1], conv_output_proper->nb[2], conv_output_proper->nb[3], head_k_dim * num_k_heads * ggml_element_size(conv_output_proper)),
|
|
|
+ head_k_dim, num_k_heads, n_tokens, n_seqs);
|
|
|
cb(q_conv, "k_conv", il);
|
|
|
- ggml_tensor * v_conv = ggml_cont(
|
|
|
- ctx0, ggml_view_4d(ctx0, conv_output, head_v_dim, num_v_heads, n_tokens, n_seqs,
|
|
|
- conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
|
|
|
- 2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
|
|
|
+ ggml_tensor * v_conv = ggml_cont_4d(ctx0, ggml_view_4d(ctx0, conv_output_proper, head_v_dim, num_v_heads, n_tokens, n_seqs,
|
|
|
+ conv_output_proper->nb[1], conv_output_proper->nb[2], conv_output_proper->nb[3], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_output_proper)),
|
|
|
+ head_v_dim, num_v_heads, n_tokens, n_seqs);
|
|
|
cb(q_conv, "v_conv", il);
|
|
|
|
|
|
ggml_build_forward_expand(gf, ssm_states_all);
|
|
|
@@ -476,6 +515,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
|
|
|
// Call the new ggml_delta_net function with the corrected flow
|
|
|
ggml_tensor * output = ggml_delta_net(k_conv, v_conv, q_conv, gate, beta, state_broadcast, true, 1.0f, il);
|
|
|
+ cb(q_conv, "delta_output", il);
|
|
|
|
|
|
// Extract the output part
|
|
|
ggml_tensor * attn_out =
|