llm_build_qwen3next.cpp 39 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832
  1. #include "llm_build_qwen3next.h"
  2. #include "../../ggml/src/ggml-impl.h"
  3. #include <cmath>
  4. // Implementation of depthwise 1D convolution using F32 to avoid F16 limitations
  5. static ggml_tensor* ggml_conv_1d_dw_f32(
  6. ggml_context * ctx,
  7. ggml_tensor * kernel,
  8. ggml_tensor * input,
  9. int stride,
  10. int padding,
  11. int dilation) {
  12. const int64_t n_seqs = input->ne[2];
  13. const int64_t channels = input->ne[1];
  14. if (n_seqs > 1) {
  15. ggml_tensor* reshaped_input = ggml_reshape_4d(ctx, input, input->ne[0], 1, input->ne[1] * input->ne[2], 1);
  16. // For the kernel in [kernel_size, 1, channel_size, 1] format:
  17. // 1. First reshape to 2D: [kernel_size, channel_size]
  18. ggml_tensor* kernel_2d = ggml_reshape_2d(ctx, kernel, kernel->ne[0], kernel->ne[2]);
  19. // 2. Repeat the kernel for each sequence: [kernel_size, channel_size * n_seqs]
  20. ggml_tensor* repeated_kernel = ggml_repeat(ctx, kernel_2d,
  21. ggml_new_tensor_2d(ctx, kernel->type, kernel->ne[0], channels * n_seqs));
  22. // 3. Reshape back to 4D for im2col: [kernel_size, 1, channels * n_seqs, 1]
  23. ggml_tensor* reshaped_kernel = ggml_reshape_4d(ctx, repeated_kernel, kernel->ne[0], 1, channels * n_seqs, 1);
  24. // Apply im2col
  25. ggml_tensor* im2col_result = ggml_im2col(ctx, reshaped_kernel, reshaped_input, stride, 0, padding, 0, dilation, 0, false, GGML_TYPE_F32);
  26. // Multiply with the kernel
  27. ggml_tensor* mul_result = ggml_mul_mat(ctx, im2col_result, reshaped_kernel);
  28. // Reshape result back to [output_len, channels, n_seqs]
  29. ggml_tensor* output_3d = ggml_reshape_3d(ctx, mul_result, mul_result->ne[0], channels, n_seqs);
  30. return output_3d;
  31. } else {
  32. // Single sequence case - kernel is already in [kernel_size, 1, channel_size, 1]
  33. ggml_tensor* reshaped_input = ggml_reshape_4d(ctx, input, input->ne[0], 1, input->ne[1], 1);
  34. ggml_tensor* im2col_result = ggml_im2col(ctx, kernel, reshaped_input, stride, 0, padding, 0, dilation, 0, false, GGML_TYPE_F32);
  35. ggml_tensor* mul_result = ggml_mul_mat(ctx, im2col_result, kernel);
  36. ggml_tensor* output_3d = ggml_reshape_3d(ctx, mul_result, mul_result->ne[0], mul_result->ne[2], 1);
  37. return output_3d;
  38. }
  39. }
  40. llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
  41. llm_graph_context_mamba(params) {
  42. const int64_t n_embd_head = hparams.n_embd_head_v;
  43. GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
  44. //GGML_ASSERT(n_embd_head == hparams.n_rot);
  45. ggml_tensor * cur;
  46. ggml_tensor * inpL;
  47. inpL = build_inp_embd(model.tok_embd);
  48. cb(inpL, "model.embed_tokens", -1);
  49. auto * inp = build_inp_mem_hybrid();
  50. ggml_tensor * inp_pos = build_inp_pos();
  51. ggml_tensor * inp_out_ids = build_inp_out_ids();
  52. for (int il = 0; il < n_layer; ++il) {
  53. struct ggml_tensor * inpSA = inpL;
  54. cur = build_q3n_norm(inpL, model.layers[il].attn_norm, il);
  55. cb(cur, "attn_norm", il);
  56. // Determine layer type and build appropriate attention mechanism
  57. if (hparams.is_recurrent(il)) {
  58. // Linear attention layer (gated delta net)
  59. cur = build_qwen3next_linear_attn_layer(inp->get_recr(), cur, model, ubatch, il);
  60. } else {
  61. // Full attention layer
  62. cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il);
  63. }
  64. if (il == n_layer - 1 && inp_out_ids) {
  65. cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  66. inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  67. }
  68. // Residual connection
  69. cur = ggml_add(ctx0, cur, inpSA);
  70. cb(cur, "attn_residual", il);
  71. // Save the tensor before post-attention norm for residual connection
  72. ggml_tensor * ffn_residual = cur;
  73. // Post-attention norm
  74. ggml_tensor * attn_post_norm = build_q3n_norm(cur, model.layers[il].attn_post_norm, il);
  75. cb(attn_post_norm, "attn_post_norm", il);
  76. // FFN layer (MoE or dense) - without residual connection
  77. cur = build_layer_ffn(attn_post_norm, model, il);
  78. cb(cur, "ffn_out", il);
  79. // Residual connection for FFN - add to the tensor from before post_attention_layernorm
  80. cur = ggml_add(ctx0, cur, ffn_residual);
  81. cb(cur, "post_moe", il);
  82. // Input for next layer
  83. inpL = cur;
  84. }
  85. cur = inpL;
  86. // Final norm
  87. cur = build_q3n_norm(cur, model.output_norm, -1);
  88. cb(cur, "result_norm", -1);
  89. res->t_embd = cur;
  90. // LM head
  91. cur = build_lora_mm(model.output, cur);
  92. cb(cur, "result_output", -1);
  93. res->t_logits = cur;
  94. ggml_build_forward_expand(gf, cur);
  95. }
  96. struct ggml_tensor * llm_build_qwen3next::build_q3n_norm(struct ggml_tensor * input, struct ggml_tensor * weights, int layer) {
  97. // ggml_tensor * input_norm = ggml_scale_bias(ctx0, weights, 1.0f, 1.0f);
  98. // EDIT: we moved the shifting part to the conversion, so we just call normal build_norm
  99. return build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
  100. }
  101. struct ggml_tensor * llm_build_qwen3next::build_q3n_gated_norm(struct ggml_tensor * input, struct ggml_tensor * weights, struct ggml_tensor * gate, int layer) {
  102. ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
  103. ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
  104. return ggml_mul(ctx0, normalized, gated_silu);
  105. }
  106. struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor * cur,
  107. ggml_tensor * inp_pos,
  108. llm_graph_input_attn_kv * inp_attn,
  109. const llama_model & model,
  110. const int64_t n_embd_head,
  111. const int il) {
  112. // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
  113. // Qwen3Next uses a single Q projection that outputs query + gate
  114. struct ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
  115. cb(Qcur_full, "Qcur_full", il);
  116. Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
  117. // Split Q projection into query and gate
  118. // The split should be along dimension 0 (the feature dimension)
  119. struct ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
  120. struct ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3],
  121. n_embd_head * ggml_element_size(Qcur_full));
  122. cb(Qcur, "Qcur", il);
  123. cb(gate, "gate", il);
  124. // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
  125. Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
  126. cb(Qcur, "Qcur_reshaped", il);
  127. // Apply Q normalization
  128. Qcur = build_q3n_norm(Qcur, model.layers[il].attn_q_norm, il);
  129. cb(Qcur, "Qcur_normed", il);
  130. struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
  131. cb(Kcur, "Kcur", il);
  132. struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
  133. cb(Vcur, "Vcur", il);
  134. // Apply K normalization
  135. Kcur = build_q3n_norm(Kcur, model.layers[il].attn_k_norm, il);
  136. cb(Kcur, "Kcur_normed", il);
  137. // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
  138. gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
  139. cb(gate, "gate_reshaped", il);
  140. Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
  141. Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
  142. Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
  143. // Apply RoPE
  144. Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
  145. attn_factor, beta_fast, beta_slow);
  146. Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
  147. attn_factor, beta_fast, beta_slow);
  148. cb(Qcur, "Qcur", il);
  149. cb(Kcur, "Kcur", il);
  150. cb(Vcur, "Vcur", il);
  151. // Attention computation
  152. const float kq_scale =
  153. hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
  154. cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
  155. cb(cur, "attn_pregate", il);
  156. struct ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
  157. cb(gate_sigmoid, "gate_sigmoid", il);
  158. // Apply gating directly using the original gate tensor
  159. cur = ggml_mul(ctx0, cur, gate_sigmoid);
  160. cb(cur, "attn_gated", il);
  161. cur = build_lora_mm(model.layers[il].wo, cur);
  162. cb(cur, "attn_output", il);
  163. return cur;
  164. }
  165. // delta_net
  166. // prepare all the tensor data for the operation so we only
  167. // do the absolutely necessary steps in the op itself
  168. struct ggml_tensor * llm_build_qwen3next::delta_net(
  169. struct ggml_context * ctx,
  170. struct ggml_tensor * q,
  171. struct ggml_tensor * k,
  172. struct ggml_tensor * v,
  173. struct ggml_tensor * g,
  174. struct ggml_tensor * beta,
  175. struct ggml_tensor * state,
  176. bool use_qk_l2norm,
  177. float eps_norm,
  178. const int il
  179. ) {
  180. GGML_ASSERT(ggml_is_contiguous(q));
  181. GGML_ASSERT(ggml_is_contiguous(k));
  182. GGML_ASSERT(ggml_is_contiguous(v));
  183. GGML_ASSERT(ggml_is_contiguous(g));
  184. GGML_ASSERT(ggml_is_contiguous(beta));
  185. GGML_ASSERT(ggml_is_contiguous(state));
  186. const int64_t S_k = q->ne[0];
  187. const int64_t H_k = q->ne[1];
  188. const int64_t n_tokens = q->ne[2];
  189. const int64_t n_seqs = q->ne[3];
  190. const int64_t S_v = v->ne[0];
  191. const int64_t H_v = v->ne[1];
  192. GGML_ASSERT(v->ne[2] == n_tokens);
  193. GGML_ASSERT(k->ne[2] == n_tokens);
  194. GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
  195. GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
  196. GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
  197. GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
  198. GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
  199. GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
  200. cb(q, "q_prenorm", il);
  201. cb(k, "k_prenorm", il);
  202. if (use_qk_l2norm) {
  203. q = ggml_l2_norm(ctx, q, eps_norm);
  204. k = ggml_l2_norm(ctx, k, eps_norm);
  205. }
  206. cb(k, "k_postnorm", il);
  207. cb(q, "q_prescale", il);
  208. int64_t pad_size = (GGML_DELTA_NET_CHUNK - n_tokens % GGML_DELTA_NET_CHUNK) % GGML_DELTA_NET_CHUNK;
  209. // yes, n_tokens, not H_k, the reference implementation has wrong naming
  210. int64_t num_chunks = (n_tokens + pad_size) / GGML_DELTA_NET_CHUNK;
  211. float scale = 1.0f / sqrtf(S_v);
  212. q = ggml_scale(ctx, q, scale);
  213. cb(beta, "beta_raw", il);
  214. beta = ggml_sigmoid(ctx, beta);
  215. cb(q, "q_postscale", il);
  216. cb(beta, "beta_sigmoid", il);
  217. q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
  218. k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
  219. v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
  220. q = ggml_pad(ctx, q, 0, pad_size, 0, 0);
  221. k = ggml_pad(ctx, k, 0, pad_size, 0, 0);
  222. v = ggml_pad(ctx, v, 0, pad_size, 0, 0);
  223. beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
  224. cb(beta, "beta_reshape", il);
  225. g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
  226. cb(g, "g_permute", il);
  227. // ... except for beta and g, where we pad the last dimension
  228. beta = ggml_pad(ctx, beta, pad_size, 0, 0, 0);
  229. g = ggml_pad(ctx, g, pad_size, 0, 0, 0);
  230. cb(q, "q_pad", il);
  231. cb(k, "k_pad", il);
  232. cb(v, "v_pad", il);
  233. cb(beta, "beta_pad", il);
  234. cb(g, "g_pad", il);
  235. GGML_ASSERT(q->ne[1] % GGML_DELTA_NET_CHUNK == 0 && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
  236. GGML_ASSERT(k->ne[1] % GGML_DELTA_NET_CHUNK == 0 && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
  237. GGML_ASSERT(v->ne[1] % GGML_DELTA_NET_CHUNK == 0 && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
  238. GGML_ASSERT(beta->ne[0] % GGML_DELTA_NET_CHUNK == 0 && beta->ne[1] == H_k && beta->ne[2] == 1 && beta->ne[3] == n_seqs);
  239. GGML_ASSERT(g->ne[0] % GGML_DELTA_NET_CHUNK == 0 && g->ne[2] == H_k && g->ne[1] == 1 && g->ne[3] == n_seqs);
  240. ggml_tensor * beta_unsq = ggml_cont_4d(ctx, beta, 1, GGML_DELTA_NET_CHUNK * num_chunks, H_k, n_seqs);
  241. ggml_tensor * beta_bcast = ggml_repeat_4d(ctx, beta_unsq, S_v, GGML_DELTA_NET_CHUNK * num_chunks, H_k, n_seqs);
  242. cb(beta_unsq, "beta_unsq", il);
  243. cb(beta_bcast, "beta_bcast", il);
  244. struct ggml_tensor * v_beta = ggml_mul(ctx, v, beta_bcast);
  245. v_beta = ggml_reshape_4d(ctx, v_beta, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
  246. cb(v_beta, "v_beta", il);
  247. struct ggml_tensor * k_beta = ggml_mul(ctx, k, beta_bcast);
  248. k_beta = ggml_reshape_4d(ctx, k_beta, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
  249. cb(k_beta, "k_beta", il);
  250. k = ggml_reshape_4d(ctx, k, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
  251. cb(k_beta, "k_reshape", il);
  252. g = ggml_reshape_4d(ctx, g, GGML_DELTA_NET_CHUNK, 1, H_k * num_chunks, n_seqs);
  253. cb(g, "g_reshape", il);
  254. struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
  255. cb(g_cumsum, "g_cumsum", il);
  256. struct ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, GGML_DELTA_NET_CHUNK, 1, num_chunks * H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs]
  257. struct ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs]
  258. // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
  259. struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d(ctx, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
  260. struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx, gcs_j, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
  261. struct ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i_broadcast);
  262. cb(decay_mask, "sub", il);
  263. // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
  264. decay_mask = ggml_tri_keep(ctx, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
  265. cb(decay_mask, "sub_tri", il);
  266. // Apply exponential to get the decay mask values
  267. decay_mask = ggml_exp(ctx, decay_mask);
  268. cb(decay_mask, "sub_tri_exp", il);
  269. // Apply lower triangular mask again to ensure only lower triangular values remain
  270. decay_mask = ggml_tri_keep(ctx, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
  271. cb(decay_mask, "decay_mask", il);
  272. struct ggml_tensor * kmulkbeta = ggml_mul_mat(ctx, ggml_cont(ctx, k), ggml_cont(ctx, k_beta));
  273. cb(kmulkbeta, "k_beta @ k_t ", il);
  274. struct ggml_tensor * k_decay = ggml_mul(ctx, kmulkbeta, decay_mask);
  275. cb(k_decay, "(k_beta @ k_t) * decay_mask", il);
  276. struct ggml_tensor * attn = ggml_neg(ctx, ggml_tri_keep(ctx, k_decay, GGML_TRI_TYPE_LOWER));
  277. cb(attn, "attn_in", il);
  278. // We'll be returning the result as a 1D tensor due to the dimensions mismatch of the state and output tensors
  279. const int64_t total_dims = (S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs);
  280. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, total_dims);
  281. ggml_set_op_params_i32(result, 0, H_v);
  282. ggml_set_op_params_i32(result, 1, S_k);
  283. ggml_set_op_params_i32(result, 2, S_v);
  284. ggml_set_op_params_i32(result, 3, n_tokens); // Pass original n_tokens
  285. result->op = GGML_OP_DELTA_NET;
  286. result->src[0] = q;
  287. result->src[1] = k;
  288. result->src[2] = v;
  289. result->src[3] = g_cumsum;
  290. result->src[4] = state;
  291. result->src[5] = decay_mask;
  292. result->src[6] = v_beta;
  293. result->src[7] = k_beta;
  294. result->src[8] = attn;
  295. return result;
  296. }
  297. // delta_net_recurrent
  298. struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
  299. struct ggml_context * ctx,
  300. struct ggml_tensor * q,
  301. struct ggml_tensor * k,
  302. struct ggml_tensor * v,
  303. struct ggml_tensor * g,
  304. struct ggml_tensor * beta,
  305. struct ggml_tensor * state,
  306. bool use_qk_l2norm,
  307. float eps_norm,
  308. const int il
  309. ) {
  310. GGML_ASSERT(ggml_is_contiguous(q));
  311. GGML_ASSERT(ggml_is_contiguous(k));
  312. GGML_ASSERT(ggml_is_contiguous(v));
  313. GGML_ASSERT(ggml_is_contiguous(g));
  314. GGML_ASSERT(ggml_is_contiguous(beta));
  315. GGML_ASSERT(ggml_is_contiguous(state));
  316. const int64_t S_k = q->ne[0];
  317. const int64_t H_k = q->ne[1];
  318. const int64_t n_tokens = q->ne[2];
  319. const int64_t n_seqs = q->ne[3];
  320. const int64_t S_v = v->ne[0];
  321. const int64_t H_v = v->ne[1];
  322. GGML_ASSERT(v->ne[2] == n_tokens);
  323. GGML_ASSERT(k->ne[2] == n_tokens);
  324. GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
  325. GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
  326. GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
  327. GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
  328. GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && q->ne[3] == n_seqs);
  329. GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
  330. cb(q, "q_prenorm", il);
  331. cb(k, "k_prenorm", il);
  332. if (use_qk_l2norm) {
  333. q = ggml_l2_norm(ctx, q, eps_norm);
  334. k = ggml_l2_norm(ctx, k, eps_norm);
  335. }
  336. cb(k, "k_postnorm", il);
  337. cb(q, "q_prescale", il);
  338. float scale = 1.0f / sqrtf(S_v);
  339. q = ggml_scale(ctx, q, scale);
  340. cb(beta, "beta_raw", il);
  341. beta = ggml_sigmoid(ctx, beta);
  342. cb(q, "q_postscale", il);
  343. cb(beta, "beta_sigmoid", il);
  344. // Reshape tensors for recurrent computation
  345. // From [S_k, H_k, n_tokens, n_seqs] to [S_k, n_tokens, H_k, n_seqs]
  346. q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
  347. cb(q, "q_reshape", il);
  348. k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
  349. cb(k, "k_reshape", il);
  350. v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
  351. cb(v, "v_reshape", il);
  352. beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
  353. cb(beta, "beta_reshape", il);
  354. g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
  355. cb(g, "g_permute", il);
  356. ggml_tensor * q_tokens = ggml_cont_4d(ctx, q, n_tokens, S_k, H_k, n_seqs);
  357. ggml_tensor * k_tokens = ggml_cont_4d(ctx, k, n_tokens, S_k, H_k, n_seqs);
  358. ggml_tensor * v_tokens = ggml_cont_4d(ctx, v, n_tokens, S_v, H_k, n_seqs);
  359. ggml_tensor * g_tokens = ggml_cont_4d(ctx, g, n_tokens, 1, H_k, n_seqs);
  360. ggml_tensor * beta_tokens = ggml_cont_4d(ctx, beta, n_tokens, 1, H_k, n_seqs);
  361. state = ggml_cont_4d(ctx, state, S_v, S_v, H_k, n_seqs);
  362. ggml_tensor * g_tokens_exp = ggml_exp(ctx, g_tokens);
  363. // Create result tensor with the same dimensions as delta_net
  364. const int64_t total_dims = (S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs);
  365. ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, total_dims);
  366. cb(q_tokens, "q_tokens", il);
  367. cb(k_tokens, "k_tokens", il);
  368. cb(v_tokens, "v_tokens", il);
  369. cb(g_tokens, "g_tokens", il);
  370. cb(beta_tokens, "beta_tokens", il);
  371. cb(g_tokens_exp, "g_tokens_exp", il);
  372. cb(state, "state_pre", il);
  373. // Set operation parameters
  374. ggml_set_op_params_i32(result, 0, H_v);
  375. ggml_set_op_params_i32(result, 1, S_k);
  376. ggml_set_op_params_i32(result, 2, S_v);
  377. ggml_set_op_params_i32(result, 3, n_tokens); // Pass original n_tokens
  378. // Set operation and source tensors
  379. result->op = GGML_OP_DELTA_NET_RECURRENT;
  380. result->src[0] = q_tokens;
  381. result->src[1] = k_tokens;
  382. result->src[2] = v_tokens;
  383. result->src[3] = g_tokens_exp;
  384. result->src[4] = beta_tokens;
  385. result->src[5] = state;
  386. return result;
  387. }
  388. ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
  389. ggml_tensor * cur,
  390. const llama_model & model,
  391. const llama_ubatch & ubatch,
  392. int il) {
  393. // Gated Delta Net implementation using the new delta_net function
  394. const auto * mctx_cur = inp->mctx;
  395. const int64_t d_inner = hparams.ssm_d_inner;
  396. const int64_t n_seqs = ubatch.n_seqs;
  397. const int64_t head_k_dim = hparams.ssm_d_state;
  398. const int64_t num_k_heads = hparams.ssm_n_group;
  399. const int64_t num_v_heads = hparams.ssm_dt_rank;
  400. const int64_t head_v_dim = d_inner / num_v_heads;
  401. const int64_t n_seq_tokens = ubatch.n_seq_tokens;
  402. const auto kv_head = mctx_cur->get_head();
  403. GGML_ASSERT(n_seqs != 0);
  404. GGML_ASSERT(ubatch.equal_seqs());
  405. GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
  406. // Input projections
  407. ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
  408. cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
  409. ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
  410. cb(mixed_ba, "linear_attn_mixed_ba", il);
  411. int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
  412. ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
  413. // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
  414. int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
  415. ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
  416. // Split mixed_ba into b and a (beta and alpha parameters)
  417. int64_t split_sizes_ba[2] = {
  418. num_v_heads / num_k_heads, // beta size
  419. num_v_heads / num_k_heads // alpha size
  420. };
  421. ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs,
  422. mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
  423. cb(b, "b", il);
  424. ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs,
  425. mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
  426. split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
  427. cb(a, "a", il);
  428. // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
  429. ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
  430. ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
  431. GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
  432. ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
  433. ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
  434. cb(alpha_softplus, "a_softplus", il);
  435. ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
  436. cb(gate, "gate", il);
  437. // Split mixed_qkvz into query, key, value, z
  438. int64_t split_sizes_qkvz[4] = {
  439. head_k_dim, // query size
  440. head_k_dim, // key size
  441. head_v_dim * num_v_heads / num_k_heads, // value size
  442. head_v_dim * num_v_heads / num_k_heads // z size
  443. };
  444. ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
  445. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
  446. cb(query, "q", il);
  447. ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
  448. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  449. split_sizes_qkvz[0] * sizeof(float));
  450. cb(key, "k", il);
  451. ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
  452. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  453. (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
  454. cb(value, "v", il);
  455. ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
  456. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  457. (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
  458. cb(z, "z", il);
  459. GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) +
  460. ggml_nelements(z) ==
  461. ggml_nelements(mixed_qkvz));
  462. // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
  463. // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
  464. ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
  465. cb(query_flat, "query_flat", il);
  466. // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
  467. ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
  468. cb(key_flat, "key_flat", il);
  469. // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
  470. ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
  471. cb(value_flat, "value_flat", il);
  472. // Get convolution states from cache
  473. ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
  474. ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
  475. bool is_generation = mctx_cur->get_rs_z() < 0;
  476. // Build the convolution states tensor
  477. ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
  478. cb(conv_states, "conv_states", il);
  479. // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
  480. ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
  481. qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
  482. cb(qkv_mixed, "qkv_mixed", il);
  483. qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
  484. cb(qkv_mixed, "qkv_mixed_permuted", il);
  485. // Calculate the total conv dimension
  486. int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
  487. // Calculate convolution kernel size
  488. ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
  489. const int64_t conv_kernel_size = conv_kernel->ne[0];
  490. const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
  491. //conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
  492. conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
  493. cb(conv_states, "conv_states_reshaped", il);
  494. ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
  495. cb(conv_input, "conv_input", il);
  496. // Update convolution state cache
  497. // Extract the last (conv_kernel_size - 1) states from conv_input
  498. ggml_tensor * last_conv_states =
  499. ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], conv_input->nb[2],
  500. (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
  501. cb(last_conv_states, "last_conv_states", il);
  502. ggml_tensor * state_update_target = ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
  503. kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
  504. cb(state_update_target, "state_update_target", il);
  505. ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
  506. cb(conv_states_all, "conv_states_updated", il);
  507. // Apply convolution
  508. ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); //ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, 1);
  509. cb(conv_output_proper, "conv_output_raw", il);
  510. // Remove the padding
  511. // ggml_tensor * conv_output_no_padding = ggml_view_4d(ctx0, conv_output, conv_output->ne[0] - (conv_kernel_size - 1), conv_output->ne[1], conv_output->ne[2], conv_output->ne[3],
  512. // conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
  513. // (conv_kernel_size - 1) * ggml_element_size(conv_output));
  514. // cb(conv_output_no_padding, "conv_output_no_padding", il);
  515. // // Take only the first n_seq_tokens values
  516. // ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_seq_tokens, conv_output_no_padding->ne[1], conv_output_no_padding->ne[2], conv_output_no_padding->ne[3],
  517. // conv_output_no_padding->nb[1], conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], 0);
  518. // cb(conv_output_proper, "conv_output_proper", il);
  519. conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper));
  520. cb(conv_output_proper, "conv_output_pre_silu", il);
  521. ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
  522. cb(conv_output_silu, "conv_output_silu", il);
  523. ggml_tensor * conv_qkv_mix = ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs);
  524. cb(conv_qkv_mix, "conv_qkv_mix", il);
  525. // Extract the convolved Q, K, V from conv_output
  526. ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs,
  527. conv_qkv_mix->nb[1], 0);
  528. cb(q_conv, "q_conv", il);
  529. ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs,
  530. conv_qkv_mix->nb[1], head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
  531. cb(k_conv, "k_conv", il);
  532. ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs,
  533. conv_qkv_mix->nb[1], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
  534. cb(v_conv, "v_conv", il);
  535. // Unsqueeze them
  536. q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
  537. k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
  538. v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
  539. beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
  540. ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
  541. state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
  542. cb(state, "state_predelta", il);
  543. // if head keys and value keys are different, repeat to force tensors into matching shapes
  544. if (num_k_heads != num_v_heads) {
  545. GGML_ASSERT(num_v_heads % num_k_heads == 0);
  546. int64_t repeat_factor = num_v_heads / num_k_heads;
  547. ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
  548. ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
  549. // Repeat along the third dimension (the new dimension with size 1)
  550. ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
  551. ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
  552. // Reshape back to merge the head and repeat dimensions
  553. // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]
  554. // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
  555. q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
  556. k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
  557. }
  558. cb(q_conv, "q_conv_predelta", il);
  559. cb(k_conv, "k_conv_predelta", il);
  560. cb(v_conv, "v_conv_predelta", il);
  561. // Choose between delta_net and delta_net_recurrent based on generation mode
  562. ggml_tensor * attn_out;
  563. if (is_generation) {
  564. // Use delta_net_recurrent for single token generation
  565. attn_out = delta_net_recurrent(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
  566. } else {
  567. // Use regular delta_net for prompt processing
  568. attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
  569. }
  570. cb(attn_out, "attn_out", il);
  571. // The tensors were concatenated 1d, so we need to extract them 1d as well
  572. const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
  573. ggml_tensor * attn_out_1d =
  574. ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
  575. cb(attn_out_1d, "attn_out_1d", il);
  576. ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, n_seq_tokens, num_v_heads, n_seqs), 0, 2, 1, 3));
  577. cb(attn_out_final, "attn_out_final", il);
  578. // Extract the state part (second part of the concatenated tensor)
  579. // State starts after n_tokens elements along dimension 1
  580. const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
  581. ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
  582. cb(state_1d, "state_1d", il);
  583. // Update the recurrent states
  584. ggml_build_forward_expand(gf,
  585. ggml_cpy(ctx0, state_1d, ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
  586. kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
  587. GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
  588. // Reshape both attn_out_final and z to 2D tensors for normalization
  589. // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
  590. ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
  591. // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
  592. ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
  593. // Apply gated normalization: self.norm(core_attn_out, z)
  594. ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
  595. // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
  596. ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
  597. // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
  598. ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
  599. cb(final_output, "final_output", il);
  600. // Output projection
  601. cur = build_lora_mm(model.layers[il].ssm_out, final_output);
  602. cb(cur, "linear_attn_out", il);
  603. // Reshape back to original dimensions
  604. cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs));
  605. return cur;
  606. }
  607. ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
  608. // Check if this is an MoE layer
  609. if (model.layers[il].ffn_gate_inp != nullptr) {
  610. // MoE branch
  611. ggml_tensor * moe_out =
  612. build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
  613. model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert,
  614. n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
  615. cb(moe_out, "ffn_moe_out", il);
  616. // Add shared experts if present - following Qwen3Next reference implementation
  617. if (model.layers[il].ffn_up_shexp != nullptr) {
  618. ggml_tensor * ffn_shexp =
  619. build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL,
  620. model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
  621. cb(ffn_shexp, "ffn_shexp", il);
  622. // Apply shared expert gating as in the reference implementation
  623. // The shared expert has its own gate that is sigmoided
  624. // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
  625. ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
  626. cb(shared_gate, "shared_expert_gate", il);
  627. // Apply sigmoid to the gate
  628. shared_gate = ggml_sigmoid(ctx0, shared_gate);
  629. cb(shared_gate, "shared_expert_gate_sigmoid", il);
  630. // The gate needs to be broadcast to match the dimensions of ffn_shexp
  631. // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
  632. // We need to repeat the gate along the feature dimension
  633. shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
  634. cb(shared_gate, "shared_expert_gate_broadcast", il);
  635. // Apply the gate to the shared expert output
  636. ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
  637. cb(ffn_shexp, "ffn_shexp_gated", il);
  638. cur = ggml_add(ctx0, moe_out, ffn_shexp);
  639. cb(cur, "ffn_out", il);
  640. } else {
  641. cur = moe_out;
  642. }
  643. } else {
  644. // Dense FFN branch
  645. cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
  646. model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
  647. cb(cur, "ffn_out", il);
  648. }
  649. return cur;
  650. }