llm_build_qwen3next.cpp 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578
  1. #include "llm_build_qwen3next.h"
  2. #include <cmath>
  3. llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
  4. llm_graph_context_mamba(params) {
  5. const int64_t n_embd_head = hparams.n_embd_head_v;
  6. GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
  7. ggml_tensor * cur;
  8. ggml_tensor * inpL;
  9. inpL = build_inp_embd(model.tok_embd);
  10. cb(inpL, "model.embed_tokens", -1);
  11. auto * inp = build_inp_mem_hybrid();
  12. ggml_tensor * inp_pos = build_inp_pos();
  13. ggml_tensor * inp_out_ids = build_inp_out_ids();
  14. for (int il = 0; il < n_layer; ++il) {
  15. struct ggml_tensor * inpSA = inpL;
  16. cur = build_q3n_norm(inpL, model.layers[il].attn_norm, il);
  17. cb(cur, "attn_norm", il);
  18. // Determine layer type and build appropriate attention mechanism
  19. if (hparams.is_recurrent(il)) {
  20. // Linear attention layer (gated delta net)
  21. cur = build_qwen3next_linear_attn_layer(inp->get_recr(), cur, model, ubatch, il);
  22. } else {
  23. // Full attention layer
  24. cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il);
  25. }
  26. // Post-attention norm
  27. cur = build_q3n_norm(cur, model.layers[il].attn_post_norm, il);
  28. cb(cur, "attn_post_norm", il);
  29. if (il == n_layer - 1 && inp_out_ids) {
  30. cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  31. inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  32. }
  33. // Residual connection
  34. cur = ggml_add(ctx0, cur, inpSA);
  35. cb(cur, "attn_residual", il);
  36. // FFN layer (MoE or dense)
  37. cur = build_layer_ffn(cur, model, il);
  38. cb(cur, "post_moe", il);
  39. // Input for next layer
  40. inpL = cur;
  41. }
  42. cur = inpL;
  43. // Final norm
  44. cur = build_q3n_norm(cur, model.output_norm, -1);
  45. cb(cur, "result_norm", -1);
  46. res->t_embd = cur;
  47. // LM head
  48. cur = build_lora_mm(model.output, cur);
  49. cb(cur, "result_output", -1);
  50. ggml_set_output(cur);
  51. res->t_logits = cur;
  52. ggml_build_forward_expand(gf, cur);
  53. }
  54. struct ggml_tensor * llm_build_qwen3next::build_q3n_norm(struct ggml_tensor * input, struct ggml_tensor * weights, int layer) {
  55. ggml_tensor * input_norm = ggml_scale_bias(ctx0, weights, 1.0f, 1.0f);
  56. return build_norm(input, input_norm, nullptr, LLM_NORM_RMS, layer);
  57. }
  58. // ggml_delta_net
  59. struct ggml_tensor * llm_build_qwen3next::ggml_delta_net(struct ggml_tensor * k,
  60. struct ggml_tensor * v,
  61. struct ggml_tensor * q,
  62. struct ggml_tensor * g,
  63. struct ggml_tensor * beta,
  64. struct ggml_tensor * state,
  65. bool use_qk_l2norm,
  66. float scale,
  67. int il) {
  68. GGML_ASSERT(ggml_is_contiguous(k));
  69. GGML_ASSERT(ggml_is_contiguous(v));
  70. GGML_ASSERT(ggml_is_contiguous(q));
  71. GGML_ASSERT(ggml_is_contiguous(g));
  72. GGML_ASSERT(ggml_is_contiguous(beta));
  73. GGML_ASSERT(ggml_is_contiguous(state));
  74. const int64_t S_k = k->ne[0];
  75. const int64_t H_k = k->ne[1];
  76. const int64_t n_tokens = k->ne[2];
  77. const int64_t n_seqs = k->ne[3];
  78. const int64_t S_v = v->ne[0];
  79. const int64_t H_v = v->ne[1];
  80. GGML_ASSERT(v->ne[2] == n_tokens);
  81. GGML_ASSERT(q->ne[2] == n_tokens);
  82. GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && beta->ne[3] == n_seqs);
  83. GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seqs && state->ne[3] == n_tokens);
  84. GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
  85. GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens);
  86. GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
  87. // Beta sigmoid
  88. struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx0, beta);
  89. cb(beta_sigmoid, "beta_sigmoid", il);
  90. // Gate calculations are done elsewhere in llama-model.cpp
  91. struct ggml_tensor * q_broadcast = q;
  92. struct ggml_tensor * k_broadcast = k;
  93. // if head keys and value keys are different, repeat to force tensors into matching shapes
  94. if (H_k != H_v) {
  95. GGML_ASSERT(H_v % H_k == 0);
  96. int64_t repeat_factor = H_v / H_k;
  97. q_broadcast = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs);
  98. k_broadcast = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs);
  99. q_broadcast = ggml_repeat_4d(ctx0, q_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
  100. k_broadcast = ggml_repeat_4d(ctx0, k_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
  101. q_broadcast = ggml_reshape_4d(ctx0, q_broadcast, S_k, H_v, n_seqs, n_tokens);
  102. k_broadcast = ggml_reshape_4d(ctx0, k_broadcast, S_k, H_v, n_seqs, n_tokens);
  103. }
  104. struct ggml_tensor * v_reshape = ggml_cont_4d(ctx0, v, S_v, H_v, n_seqs, n_tokens);
  105. struct ggml_tensor * g_reshape = ggml_cont_4d(ctx0, g, S_v, H_v, n_seqs, n_tokens);
  106. struct ggml_tensor * beta_broadcast = ggml_cont_4d(ctx0, beta_sigmoid, 1, H_v, n_seqs, n_tokens);
  107. struct ggml_tensor * state_broadcast = ggml_cont(ctx0, state);
  108. return ggml_delta_net_op(q_broadcast, k_broadcast, v_reshape, g_reshape, beta_broadcast, state_broadcast,
  109. use_qk_l2norm, scale, il);
  110. }
  111. struct ggml_tensor * llm_build_qwen3next::ggml_delta_net_op(struct ggml_tensor * q,
  112. struct ggml_tensor * k,
  113. struct ggml_tensor * v,
  114. struct ggml_tensor * g,
  115. struct ggml_tensor * beta,
  116. struct ggml_tensor * state,
  117. bool use_qk_l2norm,
  118. float scale,
  119. int il) {
  120. GGML_ASSERT(ggml_is_contiguous(q));
  121. GGML_ASSERT(ggml_is_contiguous(k));
  122. GGML_ASSERT(ggml_is_contiguous(v));
  123. GGML_ASSERT(ggml_is_contiguous(g));
  124. GGML_ASSERT(ggml_is_contiguous(beta));
  125. GGML_ASSERT(ggml_is_contiguous(state));
  126. const int64_t S_k = q->ne[0];
  127. const int64_t H_k = q->ne[1];
  128. const int64_t n_seq = q->ne[2];
  129. const int64_t n_tokens = q->ne[3];
  130. const int64_t S_v = v->ne[0];
  131. const int64_t H_v = v->ne[1];
  132. GGML_ASSERT(H_k == H_v); // we broadcasted the tensors in the main function to guarantee this
  133. GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_seq && k->ne[3] == n_tokens);
  134. GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_seq && v->ne[3] == n_tokens);
  135. GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_seq && g->ne[3] == n_tokens);
  136. GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == H_v && beta->ne[2] == n_seq && beta->ne[3] == n_tokens);
  137. GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seq && state->ne[3] == n_tokens);
  138. struct ggml_tensor * new_state = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, S_v, S_v * H_v, n_seq, n_tokens);
  139. new_state = ggml_cpy(ctx0, state, new_state);
  140. cb(new_state, "new_state", il);
  141. if (use_qk_l2norm) {
  142. q = ggml_l2_norm(ctx0, q, 1e-6f);
  143. cb(q, "q_l2_norm", il);
  144. k = ggml_l2_norm(ctx0, k, 1e-6f);
  145. cb(q, "k_l2_norm", il);
  146. }
  147. q = ggml_scale(ctx0, q, scale);
  148. cb(q, "q_scaled", il);
  149. struct ggml_tensor * state_decay = ggml_mul(ctx0, state, g);
  150. cb(state_decay, "state_decay", il);
  151. struct ggml_tensor * kv_mem_presum = ggml_mul(ctx0, state_decay, k);
  152. // Gotta do some squeezing here...
  153. struct ggml_tensor * kv_mem_presum_squeeze = ggml_reshape_4d(ctx0, kv_mem_presum, S_v, S_v, H_v, n_seq * n_tokens);
  154. struct ggml_tensor * kv_mem = ggml_permute(
  155. ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, kv_mem_presum_squeeze, 1, 2, 0, 3))), 2, 0, 1, 3);
  156. cb(kv_mem, "kv_mem", il);
  157. struct ggml_tensor * kv_mem_reshape = ggml_reshape_4d(ctx0, kv_mem, S_v, S_v, n_seq, n_tokens);
  158. struct ggml_tensor * delta = ggml_mul(ctx0, ggml_sub(ctx0, kv_mem_reshape, v), beta);
  159. cb(delta, "delta", il);
  160. struct ggml_tensor * delta_kt = ggml_mul(ctx0, delta, k);
  161. cb(delta_kt, "delta_kt", il);
  162. struct ggml_tensor * state_plus_k_delta = ggml_add(ctx0, state_decay, delta_kt);
  163. cb(state_plus_k_delta, "state_plus_k_delta", il);
  164. struct ggml_tensor * state_q = ggml_mul(ctx0, state_plus_k_delta, q);
  165. cb(state_q, "state_q", il);
  166. // And here...
  167. state_q = ggml_reshape_4d(ctx0, state_q, S_v, S_v, H_v, n_seq * n_tokens);
  168. struct ggml_tensor * output = ggml_permute(ctx0, ggml_sum_rows(ctx0, state_q), 2, 0, 1, 3);
  169. output = ggml_reshape_4d(ctx0, output, S_v, H_v, n_seq, n_tokens);
  170. cb(output, "delta_net_output", il);
  171. struct ggml_tensor * result = ggml_concat(ctx0, output, state_plus_k_delta, 1);
  172. cb(result, "delta_net_result", il);
  173. return result;
  174. }
  175. ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor * cur,
  176. ggml_tensor * inp_pos,
  177. llm_graph_input_attn_kv * inp_attn,
  178. const llama_model & model,
  179. const int64_t n_embd_head,
  180. const int il) {
  181. ggml_tensor * gate = build_lora_mm(model.layers[il].wq_gate, cur);
  182. // compute Q and K and RoPE them
  183. struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
  184. cb(Qcur, "Qcur", il);
  185. struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
  186. cb(Kcur, "Kcur", il);
  187. struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
  188. cb(Vcur, "Vcur", il);
  189. Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
  190. Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
  191. Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
  192. // Apply Q/K normalization
  193. Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
  194. Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
  195. cb(Kcur, "Qcur_normed", il);
  196. cb(Kcur, "Kcur_normed", il);
  197. // Apply RoPE
  198. Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
  199. attn_factor, beta_fast, beta_slow);
  200. Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
  201. attn_factor, beta_fast, beta_slow);
  202. cb(Qcur, "Qcur", il);
  203. cb(Kcur, "Kcur", il);
  204. cb(Vcur, "Vcur", il);
  205. // Attention computation
  206. const float kq_scale =
  207. hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
  208. cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
  209. // Apply gating
  210. cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
  211. cb(cur, "attn_gated", il);
  212. cur = build_lora_mm(model.layers[il].wo, cur);
  213. cb(cur, "attn_output", il);
  214. return cur;
  215. }
  216. ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
  217. ggml_tensor * cur,
  218. const llama_model & model,
  219. const llama_ubatch & ubatch,
  220. int il) {
  221. // Gated Delta Net implementation using the new ggml_delta_net function
  222. const auto * mctx_cur = inp->mctx;
  223. const int64_t d_inner = hparams.ssm_d_inner;
  224. const int64_t n_heads = hparams.ssm_dt_rank;
  225. const int64_t head_dim = d_inner / n_heads;
  226. const int64_t n_seqs = ubatch.n_seqs;
  227. const int64_t head_k_dim = hparams.ssm_d_state;
  228. const int64_t head_v_dim = hparams.ssm_d_state;
  229. const int64_t num_k_heads = hparams.ssm_n_group;
  230. const int64_t num_v_heads = hparams.ssm_dt_rank;
  231. const int64_t n_seq_tokens = ubatch.n_seq_tokens;
  232. const int64_t n_tokens = ubatch.n_tokens;
  233. GGML_ASSERT(n_seqs != 0);
  234. GGML_ASSERT(ubatch.equal_seqs());
  235. GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
  236. // Input projections
  237. ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
  238. cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
  239. ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
  240. cb(mixed_ba, "linear_attn_mixed_ba", il);
  241. int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
  242. ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
  243. // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
  244. int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
  245. ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
  246. // Split mixed_ba into b and a (beta and alpha parameters)
  247. int64_t split_sizes_ba[2] = {
  248. num_v_heads / num_k_heads, // beta size
  249. num_v_heads / num_k_heads // alpha size
  250. };
  251. ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, n_seqs,
  252. mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
  253. cb(b, "b", il);
  254. ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, n_seqs,
  255. mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
  256. split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
  257. cb(a, "a", il);
  258. // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
  259. ggml_tensor * beta = ggml_reshape_3d(ctx0, ggml_cont(ctx0, b), num_v_heads, n_tokens, n_seqs);
  260. ggml_tensor * alpha = ggml_reshape_3d(ctx0, ggml_cont(ctx0, a), num_v_heads, n_tokens, n_seqs);
  261. GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
  262. ggml_tensor * alpha_softplus = softplus(alpha, model.layers[il].ssm_dt);
  263. cb(alpha_softplus, "a_softplus", il);
  264. ggml_tensor * A_log_exp = ggml_exp(ctx0, model.layers[il].ssm_a); // A_log.exp()
  265. cb(A_log_exp, "a_logexp", il);
  266. ggml_tensor * gate_scaled = ggml_mul(ctx0, alpha_softplus, A_log_exp); // A_log.exp() * softplus
  267. cb(gate_scaled, "gate_scaled", il);
  268. ggml_tensor * gate = ggml_scale(ctx0, gate_scaled, -1.0f); // - (A_log.exp() * softplus)
  269. cb(gate, "gate", il);
  270. // Get convolution states from cache
  271. ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
  272. ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
  273. // Build the convolution states tensor
  274. ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
  275. cb(conv_states, "conv_states", il);
  276. // Split mixed_qkvz into query, key, value, z
  277. int64_t split_sizes_qkvz[4] = {
  278. head_k_dim, // query size
  279. head_k_dim, // key size
  280. head_v_dim * num_v_heads / num_k_heads, // value size
  281. head_v_dim * num_v_heads / num_k_heads // z size
  282. };
  283. ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens, n_seqs,
  284. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
  285. cb(query, "q", il);
  286. ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
  287. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  288. split_sizes_qkvz[0] * sizeof(float)));
  289. cb(key, "k", il);
  290. ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
  291. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  292. (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
  293. cb(value, "v", il);
  294. ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
  295. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  296. (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
  297. cb(z, "z", il);
  298. // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
  299. ggml_tensor * value_reshaped =
  300. ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
  301. ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
  302. GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
  303. ggml_nelements(z_reshaped) ==
  304. ggml_nelements(mixed_qkvz));
  305. // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
  306. // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
  307. ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_tokens, n_seqs);
  308. cb(query_flat, "query_flat", il);
  309. // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
  310. ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_tokens, n_seqs);
  311. cb(key_flat, "key_flat", il);
  312. // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
  313. ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_tokens, n_seqs);
  314. cb(value_flat, "value_flat", il);
  315. // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
  316. ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
  317. qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
  318. cb(qkv_mixed, "qkv_mixed_concatenated", il);
  319. // Calculate the total conv dimension
  320. int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
  321. // Reshape to [n_tokens, qkv_dim, n_seqs] for proper convolution input format
  322. qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, n_tokens, qkv_dim, n_seqs);
  323. cb(qkv_mixed, "qkv_mixed_for_conv", il);
  324. // Calculate convolution kernel size
  325. const int64_t conv_kernel_size = model.layers[il].ssm_conv1d->ne[0];
  326. conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state, n_seqs);
  327. cb(conv_states, "conv_states_reshaped", il);
  328. // Now concatenate along the sequence dimension (dim 0 in Llama.cpp)
  329. ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
  330. cb(conv_input, "conv_input", il);
  331. // Apply convolution
  332. ggml_tensor * conv_output = ggml_ssm_conv(ctx0, conv_input, model.layers[il].ssm_conv1d);
  333. cb(conv_output, "conv_output_raw", il);
  334. if (model.layers[il].ssm_conv1d_b) {
  335. conv_output = ggml_add(ctx0, conv_output, model.layers[il].ssm_conv1d_b);
  336. cb(conv_output, "conv_output_bias", il);
  337. }
  338. conv_output = ggml_silu(ctx0, conv_output);
  339. cb(conv_output, "conv_output_silu", il);
  340. // Update convolution state cache
  341. // Extract the last (conv_kernel_size - 1) states from conv_input
  342. ggml_tensor * last_conv_states =
  343. ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, qkv_dim, n_seqs, conv_input->nb[1], conv_input->nb[2],
  344. n_seq_tokens * conv_input->nb[0]);
  345. ggml_build_forward_expand(gf,
  346. ggml_cpy(ctx0, last_conv_states,
  347. ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * qkv_dim * n_seqs,
  348. mctx_cur->get_head() * (conv_kernel_size - 1) * qkv_dim *
  349. ggml_element_size(conv_states_all))));
  350. cb(conv_states_all, "conv_states_updated", il);
  351. // Reshape conv_output back to proper dimensions
  352. conv_output = ggml_reshape_4d(ctx0, conv_output, qkv_dim, n_seqs, n_seq_tokens, 1);
  353. cb(conv_output, "conv_output_reshaped", il);
  354. conv_output = ggml_permute(ctx0, conv_output, 0, 2, 1, 3);
  355. cb(conv_output, "conv_output_final", il);
  356. // Extract the convolved Q, K, V from conv_output
  357. ggml_tensor * q_conv = ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
  358. conv_output->nb[1], conv_output->nb[2], conv_output->nb[3], 0));
  359. cb(q_conv, "q_conv", il);
  360. ggml_tensor * k_conv = ggml_cont(
  361. ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
  362. conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
  363. head_k_dim * num_k_heads * ggml_element_size(conv_output)));
  364. cb(q_conv, "k_conv", il);
  365. ggml_tensor * v_conv = ggml_cont(
  366. ctx0, ggml_view_4d(ctx0, conv_output, head_v_dim, num_v_heads, n_tokens, n_seqs,
  367. conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
  368. 2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
  369. cb(q_conv, "v_conv", il);
  370. ggml_build_forward_expand(gf, ssm_states_all);
  371. // Beta tensor
  372. beta = ggml_reshape_3d(ctx0, beta, n_heads, n_tokens, n_seqs);
  373. ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
  374. ggml_tensor * state_broadcast = ggml_repeat_4d(ctx0, state, head_dim, head_dim * n_heads, n_seqs, n_tokens);
  375. ggml_tensor * target_gate = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
  376. ggml_tensor * gate_broadcast = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
  377. gate = ggml_repeat(ctx0, gate_broadcast, target_gate);
  378. // Call the new ggml_delta_net function with the corrected flow
  379. ggml_tensor * output = ggml_delta_net(k_conv, v_conv, q_conv, gate, beta, state_broadcast, true, 1.0f, il);
  380. // Extract the output part
  381. ggml_tensor * attn_out =
  382. ggml_view_4d(ctx0, output, head_dim, n_heads, n_tokens, n_seqs, output->nb[0], output->nb[1], output->nb[2], 0);
  383. cb(output, "attn_out", il);
  384. // Extract the new state
  385. ggml_tensor * new_state =
  386. ggml_view_4d(ctx0, output, head_dim, head_dim * n_heads, n_tokens, n_seqs, output->nb[0], output->nb[1],
  387. output->nb[2], n_tokens * n_seqs * head_dim * n_heads * ggml_element_size(output));
  388. cb(output, "new_state", il);
  389. // Only return the last recurrent state
  390. struct ggml_tensor * state_reshaped = ggml_cont_4d(ctx0, new_state, head_dim, head_dim, n_heads, n_tokens * n_seqs);
  391. struct ggml_tensor * state_last =
  392. ggml_view_4d(ctx0, state_reshaped, head_dim, head_dim, n_heads, 1, state_reshaped->nb[1], state_reshaped->nb[2],
  393. state_reshaped->nb[3], head_dim * head_dim * n_heads * ((n_seqs * n_tokens) - 1));
  394. cb(output, "new_state_last", il);
  395. // Update the recurrent states
  396. ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_last, ssm_states_all));
  397. // Reshape both attn_out and z to 2D tensors for normalization
  398. // attn_out: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
  399. ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out), head_dim, n_heads * n_tokens * n_seqs);
  400. // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
  401. ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
  402. // Apply gated normalization: self.norm(core_attn_out, z)
  403. // This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
  404. ggml_tensor * attn_out_norm = build_norm(attn_out_2d, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
  405. cb(output, "attn_out_norm", il);
  406. // Apply silu gate: attn_out_norm * silu(z_2d)
  407. ggml_tensor * z_silu = ggml_silu(ctx0, z_2d);
  408. cb(output, "z_silu", il);
  409. ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
  410. cb(output, "gated_output", il);
  411. // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
  412. ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);
  413. // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
  414. ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
  415. cb(output, "final_output", il);
  416. // Output projection
  417. cur = build_lora_mm(model.layers[il].ssm_out, final_output);
  418. cb(cur, "linear_attn_out", il);
  419. // Reshape back to original dimensions
  420. cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens));
  421. return cur;
  422. }
  423. ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
  424. // Check if this is an MoE layer
  425. if (model.layers[il].ffn_gate_inp != nullptr) {
  426. // MoE branch
  427. ggml_tensor * moe_out =
  428. build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
  429. model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert,
  430. n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
  431. cb(moe_out, "ffn_moe_out", il);
  432. // Add shared experts if present
  433. if (model.layers[il].ffn_up_shexp != nullptr) {
  434. ggml_tensor * ffn_shexp =
  435. build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL,
  436. model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
  437. cb(ffn_shexp, "ffn_shexp", il);
  438. cur = ggml_add(ctx0, moe_out, ffn_shexp);
  439. cb(cur, "ffn_out", il);
  440. } else {
  441. cur = moe_out;
  442. }
  443. } else {
  444. // Dense FFN branch
  445. cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
  446. model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
  447. cb(cur, "ffn_out", il);
  448. }
  449. // Residual connection
  450. cur = ggml_add(ctx0, cur, cur); // This should be the residual from before FFN
  451. cb(cur, "ffn_residual", il);
  452. return cur;
  453. };
  454. ggml_tensor * llm_build_qwen3next::softplus(ggml_tensor * alpha, ggml_tensor * dt_bias) {
  455. ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, dt_bias); // a + dt_bias
  456. ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased); // exp(a + dt_bias)
  457. ggml_tensor * one_plus_exp = ggml_scale_bias(ctx0, alpha_exp, 1.0f, 1.0f); // 1 + exp(a + dt_bias)
  458. ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp); // log(1 + exp(...))
  459. return alpha_softplus;
  460. }