Linear layer output convergence

Piotr Wilkin, 3 months ago
Parent
Commit adcbd9428f
2 changed files with 11 additions and 12 deletions
  1. src/models/llm_build_qwen3next.cpp (+10 −12)
  2. src/models/llm_build_qwen3next.h (+1 −0)

+ 10 - 12
src/models/llm_build_qwen3next.cpp

@@ -99,6 +99,12 @@ struct ggml_tensor * llm_build_qwen3next::build_q3n_norm(struct ggml_tensor * in
     return build_norm(input, input_norm, nullptr, LLM_NORM_RMS, layer);
 }

+struct ggml_tensor * llm_build_qwen3next::build_q3n_gated_norm(struct ggml_tensor * input, struct ggml_tensor * weights, struct ggml_tensor * gate, int layer) {
+    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);    
+    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
+    return ggml_mul(ctx0, normalized, gated_silu);
+}
+
 struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor *             cur,
                                                                    ggml_tensor *             inp_pos,
                                                                    llm_graph_input_attn_kv * inp_attn,
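The new build_q3n_gated_norm helper factors out the Qwen3NextRMSNormGated pattern that the last hunk of this file used to build inline: RMSNorm(x) * SiLU(gate). As a reference for what the assembled graph computes, here is a minimal scalar sketch; the function name and the eps value are illustrative, not code from the commit (in the real helper, build_norm applies the model's rms-eps and the learned weight internally):

```cpp
// Hypothetical scalar reference for the graph that build_q3n_gated_norm assembles:
// y = RMSNorm(x) * w * SiLU(z), with SiLU(z) = z * sigmoid(z).
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> gated_rms_norm_ref(const std::vector<float> & x,
                                             const std::vector<float> & w,
                                             const std::vector<float> & z,
                                             float eps = 1e-6f) { // eps is illustrative
    // RMS statistic over the normalized dimension
    float ss = 0.0f;
    for (float v : x) ss += v * v;
    const float scale = 1.0f / std::sqrt(ss / x.size() + eps);

    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        const float silu = z[i] / (1.0f + std::exp(-z[i])); // SiLU gate, as ggml_silu
        y[i] = x[i] * scale * w[i] * silu;                  // RMSNorm(x) * w, then gate
    }
    return y;
}

int main() {
    for (float v : gated_rms_norm_ref({1, 2, 3, 4}, {1, 1, 1, 1}, {0.5f, -0.5f, 1, 2})) {
        std::printf("%f\n", v);
    }
    return 0;
}
```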
@@ -550,7 +556,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);

-    ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_tokens, n_heads, n_seqs);
+    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_tokens, n_heads, n_seqs), 0, 2, 1, 3));
     cb(attn_out_final, "attn_out_final", il);

     // Extract the state part (second part of the concatenated tensor)
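The one-line change above is the behavioral fix in this commit: the 4-D view of the linear-attention output was previously left in [head_dim, n_tokens, n_heads, n_seqs] order, while the reshapes further down assume [head_dim, n_heads, n_tokens, n_seqs]. ggml_permute(ctx0, t, 0, 2, 1, 3) swaps axes 1 and 2, and the outer ggml_cont materializes the permuted view so it can be reshaped. A standalone sketch of the same calls (dimension sizes are made up, and the CPU-only ggml context setup is an assumption):

```cpp
// Demonstrates the axis swap performed by the patched line, outside the model graph.
#include <cstdio>
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // Illustrative sizes only.
    const int head_dim = 4, n_tokens = 3, n_heads = 2, n_seqs = 1;

    // Old layout produced by ggml_cont_4d: ne = [head_dim, n_tokens, n_heads, n_seqs].
    struct ggml_tensor * t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_heads, n_seqs);

    // Swap axes 1 and 2, then make the result contiguous:
    // ne = [head_dim, n_heads, n_tokens, n_seqs].
    struct ggml_tensor * p = ggml_cont(ctx, ggml_permute(ctx, t, 0, 2, 1, 3));

    std::printf("before: [%lld, %lld, %lld, %lld]\n",
                (long long) t->ne[0], (long long) t->ne[1], (long long) t->ne[2], (long long) t->ne[3]);
    std::printf("after:  [%lld, %lld, %lld, %lld]\n",
                (long long) p->ne[0], (long long) p->ne[1], (long long) p->ne[2], (long long) p->ne[3]);

    ggml_free(ctx);
    return 0;
}
```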
@@ -574,18 +580,10 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);

     // Apply gated normalization: self.norm(core_attn_out, z)
-    // This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
-    ggml_tensor * attn_out_norm = build_norm(attn_out_2d_final, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
-    cb(attn_out_norm, "attn_out_norm", il);
-
-    // Apply silu gate: attn_out_norm * silu(z_2d)
-    ggml_tensor * z_silu = ggml_silu(ctx0, z_2d);
-    cb(z_silu, "z_silu", il);
-    ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
-    cb(gated_output, "gated_output", il);
-
+    ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
+    
     // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);
+    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_tokens, n_seqs);

     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
     ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
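The rest of this hunk is a pure refactor: build_q3n_gated_norm assembles the same RMS norm, ggml_silu, and ggml_mul nodes that the removed lines built inline, so the math is unchanged. One side effect is that the intermediate cb() debug callbacks (attn_out_norm, z_silu, gated_output) are dropped, so those tensor names no longer show up when tracing the graph.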

+ 1 - 0
src/models/llm_build_qwen3next.h

@@ -41,5 +41,6 @@ private:
     ggml_tensor * softplus(ggml_tensor * alpha, ggml_tensor * dt_bias);

     ggml_tensor * build_q3n_norm(struct ggml_tensor * input, struct ggml_tensor * weights, int layer);
+    ggml_tensor * build_q3n_gated_norm(struct ggml_tensor * input, struct ggml_tensor * weights, struct ggml_tensor * gate, int layer);

 };