llm_build_qwen3next.cpp 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. #include "llm_build_qwen3next.h"
  2. #include <cmath>
  3. // Implementation of depthwise 1D convolution using F32 to avoid F16 limitations
  4. static ggml_tensor* ggml_conv_1d_dw_f32(
  5. ggml_context * ctx,
  6. ggml_tensor * kernel,
  7. ggml_tensor * input,
  8. int stride,
  9. int padding,
  10. int dilation) {
  11. // Following the pattern from ggml_conv_1d_dw but using F32
  12. // Reshape input from [length, channels, batch, dummy] to [length, 1, channels, batch]
  13. ggml_tensor* reshaped_input = ggml_reshape_4d(ctx, input, input->ne[0], 1, input->ne[1], input->ne[2]);
  14. // Apply im2col with F32 destination type to avoid F16 requirement
  15. ggml_tensor* im2col_result = ggml_im2col(ctx, kernel, reshaped_input, stride, 0, padding, 0, dilation, 0, false, GGML_TYPE_F32);
  16. // Now multiply: im2col_result * kernel (following the exact pattern from ggml_conv_1d_dw)
  17. // In ggml_conv_1d_dw: ggml_mul_mat(ctx, im2col, a) where a is the kernel
  18. ggml_tensor* mul_result = ggml_mul_mat(ctx, im2col_result, kernel);
  19. // Reshape the result following ggml_conv_1d_dw: [result->ne[0], result->ne[2], 1]
  20. ggml_tensor* output_3d = ggml_reshape_3d(ctx, mul_result, mul_result->ne[0], mul_result->ne[2], 1);
  21. return output_3d;
  22. }
  23. llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
  24. llm_graph_context_mamba(params) {
  25. const int64_t n_embd_head = hparams.n_embd_head_v;
  26. GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
  27. ggml_tensor * cur;
  28. ggml_tensor * inpL;
  29. inpL = build_inp_embd(model.tok_embd);
  30. cb(inpL, "model.embed_tokens", -1);
  31. auto * inp = build_inp_mem_hybrid();
  32. ggml_tensor * inp_pos = build_inp_pos();
  33. ggml_tensor * inp_out_ids = build_inp_out_ids();
  34. for (int il = 0; il < n_layer; ++il) {
  35. struct ggml_tensor * inpSA = inpL;
  36. cur = build_q3n_norm(inpL, model.layers[il].attn_norm, il);
  37. cb(cur, "attn_norm", il);
  38. // Determine layer type and build appropriate attention mechanism
  39. if (hparams.is_recurrent(il)) {
  40. // Linear attention layer (gated delta net)
  41. cur = build_qwen3next_linear_attn_layer(inp->get_recr(), cur, model, ubatch, il);
  42. } else {
  43. // Full attention layer
  44. cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il);
  45. }
  46. // Post-attention norm
  47. cur = build_q3n_norm(cur, model.layers[il].attn_post_norm, il);
  48. cb(cur, "attn_post_norm", il);
  49. if (il == n_layer - 1 && inp_out_ids) {
  50. cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  51. inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  52. }
  53. // Residual connection
  54. cur = ggml_add(ctx0, cur, inpSA);
  55. cb(cur, "attn_residual", il);
  56. // FFN layer (MoE or dense)
  57. cur = build_layer_ffn(cur, model, il);
  58. cb(cur, "post_moe", il);
  59. // Input for next layer
  60. inpL = cur;
  61. }
  62. cur = inpL;
  63. // Final norm
  64. cur = build_q3n_norm(cur, model.output_norm, -1);
  65. cb(cur, "result_norm", -1);
  66. res->t_embd = cur;
  67. // LM head
  68. cur = build_lora_mm(model.output, cur);
  69. cb(cur, "result_output", -1);
  70. ggml_set_output(cur);
  71. res->t_logits = cur;
  72. ggml_build_forward_expand(gf, cur);
  73. }
  74. struct ggml_tensor * llm_build_qwen3next::build_q3n_norm(struct ggml_tensor * input, struct ggml_tensor * weights, int layer) {
  75. ggml_tensor * input_norm = ggml_scale_bias(ctx0, weights, 1.0f, 1.0f);
  76. return build_norm(input, input_norm, nullptr, LLM_NORM_RMS, layer);
  77. }
  78. ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor * cur,
  79. ggml_tensor * inp_pos,
  80. llm_graph_input_attn_kv * inp_attn,
  81. const llama_model & model,
  82. const int64_t n_embd_head,
  83. const int il) {
  84. ggml_tensor * gate = build_lora_mm(model.layers[il].wq_gate, cur);
  85. // compute Q and K and RoPE them
  86. struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
  87. cb(Qcur, "Qcur", il);
  88. struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
  89. cb(Kcur, "Kcur", il);
  90. struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
  91. cb(Vcur, "Vcur", il);
  92. Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
  93. Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
  94. Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
  95. // Apply Q/K normalization
  96. Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
  97. Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
  98. cb(Kcur, "Qcur_normed", il);
  99. cb(Kcur, "Kcur_normed", il);
  100. // Apply RoPE
  101. Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
  102. attn_factor, beta_fast, beta_slow);
  103. Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
  104. attn_factor, beta_fast, beta_slow);
  105. cb(Qcur, "Qcur", il);
  106. cb(Kcur, "Kcur", il);
  107. cb(Vcur, "Vcur", il);
  108. // Attention computation
  109. const float kq_scale =
  110. hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
  111. cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
  112. // Apply gating
  113. cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
  114. cb(cur, "attn_gated", il);
  115. cur = build_lora_mm(model.layers[il].wo, cur);
  116. cb(cur, "attn_output", il);
  117. return cur;
  118. }
  119. ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
  120. ggml_tensor * cur,
  121. const llama_model & model,
  122. const llama_ubatch & ubatch,
  123. int il) {
  124. // Gated Delta Net implementation using the new ggml_delta_net function
  125. const auto * mctx_cur = inp->mctx;
  126. const int64_t d_inner = hparams.ssm_d_inner;
  127. const int64_t n_heads = hparams.ssm_dt_rank;
  128. const int64_t head_dim = d_inner / n_heads;
  129. const int64_t n_seqs = ubatch.n_seqs;
  130. const int64_t head_k_dim = hparams.ssm_d_state;
  131. const int64_t head_v_dim = hparams.ssm_d_state;
  132. const int64_t num_k_heads = hparams.ssm_n_group;
  133. const int64_t num_v_heads = hparams.ssm_dt_rank;
  134. const int64_t n_seq_tokens = ubatch.n_seq_tokens;
  135. const int64_t n_tokens = ubatch.n_tokens;
  136. GGML_ASSERT(n_seqs != 0);
  137. GGML_ASSERT(ubatch.equal_seqs());
  138. GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
  139. // Input projections
  140. ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
  141. cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
  142. ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
  143. cb(mixed_ba, "linear_attn_mixed_ba", il);
  144. int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
  145. ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
  146. // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
  147. int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
  148. ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
  149. // Split mixed_ba into b and a (beta and alpha parameters)
  150. int64_t split_sizes_ba[2] = {
  151. num_v_heads / num_k_heads, // beta size
  152. num_v_heads / num_k_heads // alpha size
  153. };
  154. ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, n_seqs,
  155. mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
  156. cb(b, "b", il);
  157. ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, n_seqs,
  158. mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
  159. split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
  160. cb(a, "a", il);
  161. // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
  162. ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_tokens, n_seqs);
  163. ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tokens, n_seqs);
  164. GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
  165. ggml_tensor * alpha_softplus = softplus(alpha, model.layers[il].ssm_dt);
  166. cb(alpha_softplus, "a_softplus", il);
  167. ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
  168. cb(gate, "gate", il);
  169. // Get convolution states from cache
  170. ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
  171. ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
  172. // Build the convolution states tensor
  173. ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
  174. cb(conv_states, "conv_states", il);
  175. // Split mixed_qkvz into query, key, value, z
  176. int64_t split_sizes_qkvz[4] = {
  177. head_k_dim, // query size
  178. head_k_dim, // key size
  179. head_v_dim * num_v_heads / num_k_heads, // value size
  180. head_v_dim * num_v_heads / num_k_heads // z size
  181. };
  182. ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens, n_seqs,
  183. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
  184. cb(query, "q", il);
  185. ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
  186. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  187. split_sizes_qkvz[0] * sizeof(float)));
  188. cb(key, "k", il);
  189. ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
  190. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  191. (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
  192. cb(value, "v", il);
  193. ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
  194. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  195. (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
  196. cb(z, "z", il);
  197. // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
  198. ggml_tensor * value_reshaped =
  199. ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
  200. ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
  201. GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
  202. ggml_nelements(z_reshaped) ==
  203. ggml_nelements(mixed_qkvz));
  204. // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
  205. // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
  206. ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_tokens, n_seqs);
  207. cb(query_flat, "query_flat", il);
  208. // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
  209. ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_tokens, n_seqs);
  210. cb(key_flat, "key_flat", il);
  211. // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
  212. ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_tokens, n_seqs);
  213. cb(value_flat, "value_flat", il);
  214. // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
  215. ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
  216. qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
  217. qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
  218. cb(qkv_mixed, "qkv_mixed_concatenated", il);
  219. // Calculate the total conv dimension
  220. int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
  221. // Calculate convolution kernel size
  222. ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
  223. const int64_t conv_kernel_size = conv_kernel->ne[0];
  224. conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
  225. conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state, n_seqs);
  226. cb(conv_states, "conv_states_reshaped", il);
  227. ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
  228. cb(conv_input, "conv_input", il);
  229. // Apply convolution
  230. ggml_tensor * conv_output = ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, n_seqs);
  231. cb(conv_output, "conv_output_raw", il);
  232. // Remove the padding
  233. ggml_tensor * conv_output_no_padding = ggml_view_4d(ctx0, conv_output, conv_output->ne[0] - (conv_kernel_size - 1), conv_output->ne[1], conv_output->ne[2], conv_output->ne[3],
  234. conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
  235. (conv_kernel_size - 1) * ggml_element_size(conv_output));
  236. cb(conv_output_no_padding, "conv_output_no_padding", il);
  237. // Take only the first (n_tokens * n_seqs) values
  238. ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_tokens * n_seqs, conv_output_no_padding->ne[1], conv_output_no_padding->ne[2], conv_output_no_padding->ne[3],
  239. conv_output_no_padding->nb[1], conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], 0);
  240. cb(conv_output_proper, "conv_output_proper", il);
  241. conv_output_proper = ggml_permute(ctx0, conv_output_proper, 0, 1, 3, 2);
  242. conv_output_proper = ggml_cont_4d(ctx0, conv_output_proper, qkv_dim, 1, n_tokens, n_seqs);
  243. ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
  244. cb(conv_output_silu, "conv_output_silu", il);
  245. // Update convolution state cache
  246. // Extract the last (conv_kernel_size - 1) states from conv_input
  247. ggml_tensor * last_conv_states =
  248. ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, qkv_dim, n_seqs, conv_input->nb[1], conv_input->nb[2],
  249. n_seq_tokens * conv_input->nb[0]);
  250. ggml_build_forward_expand(gf,
  251. ggml_cpy(ctx0, last_conv_states,
  252. ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * qkv_dim * n_seqs,
  253. mctx_cur->get_head() * (conv_kernel_size - 1) * qkv_dim *
  254. ggml_element_size(conv_states_all))));
  255. cb(conv_states_all, "conv_states_updated", il);
  256. conv_output_proper = ggml_reshape_2d(ctx0, conv_output_silu, n_tokens * n_seqs, qkv_dim);
  257. cb(conv_output_proper, "conv_output_final", il);
  258. ggml_tensor * conv_transposed = ggml_transpose(ctx0, conv_output_proper);
  259. cb(conv_transposed, "conv_transposed", il);
  260. ggml_tensor * conv_qkv_mix = ggml_cont_2d(ctx0, conv_transposed, qkv_dim, n_tokens * n_seqs);
  261. cb(conv_qkv_mix, "conv_qkv_mix", il);
  262. // Extract the convolved Q, K, V from conv_output
  263. ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_tokens * n_seqs,
  264. conv_qkv_mix->nb[1], 0);
  265. cb(q_conv, "q_conv", il);
  266. ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_tokens * n_seqs,
  267. conv_qkv_mix->nb[1], head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
  268. cb(k_conv, "k_conv", il);
  269. ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_tokens * n_seqs,
  270. conv_qkv_mix->nb[1], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
  271. cb(v_conv, "v_conv", il);
  272. // Unsqueeze them
  273. q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
  274. k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
  275. v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tokens, n_seqs);
  276. beta = ggml_cont_4d(ctx0, b, 1, num_v_heads, n_tokens, n_seqs);
  277. alpha = ggml_cont_4d(ctx0, a, 1, num_v_heads, n_tokens, n_seqs);
  278. ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
  279. gate = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
  280. // if head keys and value keys are different, repeat to force tensors into matching shapes
  281. if (num_k_heads != num_v_heads) {
  282. GGML_ASSERT(num_v_heads % num_k_heads == 0);
  283. int64_t repeat_factor = num_v_heads / num_k_heads;
  284. q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, n_tokens, num_k_heads, n_seqs);
  285. k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, n_tokens, num_k_heads, n_seqs);
  286. q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, n_tokens * repeat_factor, num_k_heads, n_seqs);
  287. k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, n_tokens * repeat_factor, num_k_heads, n_seqs);
  288. // Fix dimension order: last two should be [tokens, batches]
  289. q_conv = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_tokens, n_seqs);
  290. k_conv = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_tokens, n_seqs);
  291. }
  292. // Call the new ggml_delta_net function with the corrected flow
  293. const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(num_k_heads)) : hparams.f_attention_scale;
  294. ggml_tensor * attn_out = ggml_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, kq_scale, hparams.f_norm_rms_eps);
  295. cb(attn_out, "attn_out", il);
  296. // The tensors were concatenated 1d, so we need to extract them 1d as well
  297. const int64_t output_flat_size = head_dim * n_heads * n_tokens * n_seqs;
  298. ggml_tensor * attn_out_1d =
  299. ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
  300. cb(attn_out_1d, "attn_out_1d", il);
  301. ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_heads, n_tokens, n_seqs);
  302. cb(attn_out_final, "attn_out_final", il);
  303. // Extract the state part (second part of the concatenated tensor)
  304. // State starts after n_tokens elements along dimension 1
  305. const int64_t state_flat_size = head_dim * head_dim * n_heads * n_seqs;
  306. ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
  307. cb(state_1d, "state_1d", il);
  308. ggml_tensor * new_state = ggml_reshape_4d(ctx0, state_1d, head_dim, head_dim, n_heads, n_seqs);
  309. cb(new_state, "new_state", il);
  310. // Update the recurrent states - we use the new_state directly since it's already the last state
  311. ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, ssm_states_all));
  312. // Reshape both attn_out_final and z to 2D tensors for normalization
  313. // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
  314. ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_tokens * n_seqs);
  315. // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
  316. ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
  317. // Apply gated normalization: self.norm(core_attn_out, z)
  318. // This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
  319. ggml_tensor * attn_out_norm = build_norm(attn_out_2d_final, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
  320. cb(attn_out_norm, "attn_out_norm", il);
  321. // Apply silu gate: attn_out_norm * silu(z_2d)
  322. ggml_tensor * z_silu = ggml_silu(ctx0, z_2d);
  323. cb(z_silu, "z_silu", il);
  324. ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
  325. cb(gated_output, "gated_output", il);
  326. // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
  327. ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);
  328. // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
  329. ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
  330. cb(final_output, "final_output", il);
  331. // Output projection
  332. cur = build_lora_mm(model.layers[il].ssm_out, final_output);
  333. cb(cur, "linear_attn_out", il);
  334. // Reshape back to original dimensions
  335. cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens));
  336. return cur;
  337. }
  338. ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
  339. // Check if this is an MoE layer
  340. if (model.layers[il].ffn_gate_inp != nullptr) {
  341. // MoE branch
  342. ggml_tensor * moe_out =
  343. build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
  344. model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert,
  345. n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
  346. cb(moe_out, "ffn_moe_out", il);
  347. // Add shared experts if present
  348. if (model.layers[il].ffn_up_shexp != nullptr) {
  349. ggml_tensor * ffn_shexp =
  350. build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL,
  351. model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
  352. cb(ffn_shexp, "ffn_shexp", il);
  353. cur = ggml_add(ctx0, moe_out, ffn_shexp);
  354. cb(cur, "ffn_out", il);
  355. } else {
  356. cur = moe_out;
  357. }
  358. } else {
  359. // Dense FFN branch
  360. cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
  361. model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
  362. cb(cur, "ffn_out", il);
  363. }
  364. // Residual connection
  365. cur = ggml_add(ctx0, cur, cur); // This should be the residual from before FFN
  366. cb(cur, "ffn_residual", il);
  367. return cur;
  368. };
  369. ggml_tensor * llm_build_qwen3next::softplus(ggml_tensor * alpha, ggml_tensor * dt_bias) {
  370. ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, dt_bias); // a + dt_bias
  371. ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased); // exp(a + dt_bias)
  372. ggml_tensor * one_plus_exp = ggml_scale_bias(ctx0, alpha_exp, 1.0f, 1.0f); // 1 + exp(a + dt_bias)
  373. ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp); // log(1 + exp(...))
  374. return alpha_softplus;
  375. }