llm_build_qwen3next.cpp 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694
  1. #include "llm_build_qwen3next.h"
  2. #include "../../ggml/src/ggml-impl.h"
  3. #include <cmath>
  4. // Implementation of depthwise 1D convolution using F32 to avoid F16 limitations
  5. static ggml_tensor* ggml_conv_1d_dw_f32(
  6. ggml_context * ctx,
  7. ggml_tensor * kernel,
  8. ggml_tensor * input,
  9. int stride,
  10. int padding,
  11. int dilation) {
  12. // Following the pattern from ggml_conv_1d_dw but using F32
  13. // Reshape input from [length, channels, batch, dummy] to [length, 1, channels, batch]
  14. ggml_tensor* reshaped_input = ggml_reshape_4d(ctx, input, input->ne[0], 1, input->ne[1], input->ne[2]);
  15. // Apply im2col with F32 destination type to avoid F16 requirement
  16. ggml_tensor* im2col_result = ggml_im2col(ctx, kernel, reshaped_input, stride, 0, padding, 0, dilation, 0, false, GGML_TYPE_F32);
  17. // Now multiply: im2col_result * kernel (following the exact pattern from ggml_conv_1d_dw)
  18. // In ggml_conv_1d_dw: ggml_mul_mat(ctx, im2col, a) where a is the kernel
  19. ggml_tensor* mul_result = ggml_mul_mat(ctx, im2col_result, kernel);
  20. // Reshape the result following ggml_conv_1d_dw: [result->ne[0], result->ne[2], 1]
  21. ggml_tensor* output_3d = ggml_reshape_3d(ctx, mul_result, mul_result->ne[0], mul_result->ne[2], 1);
  22. return output_3d;
  23. }
  24. llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
  25. llm_graph_context_mamba(params) {
  26. const int64_t n_embd_head = hparams.n_embd_head_v;
  27. GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
  28. ggml_tensor * cur;
  29. ggml_tensor * inpL;
  30. inpL = build_inp_embd(model.tok_embd);
  31. cb(inpL, "model.embed_tokens", -1);
  32. auto * inp = build_inp_mem_hybrid();
  33. ggml_tensor * inp_pos = build_inp_pos();
  34. ggml_tensor * inp_out_ids = build_inp_out_ids();
  35. for (int il = 0; il < n_layer; ++il) {
  36. struct ggml_tensor * inpSA = inpL;
  37. cur = build_q3n_norm(inpL, model.layers[il].attn_norm, il);
  38. cb(cur, "attn_norm", il);
  39. // Determine layer type and build appropriate attention mechanism
  40. if (hparams.is_recurrent(il)) {
  41. // Linear attention layer (gated delta net)
  42. cur = build_qwen3next_linear_attn_layer(inp->get_recr(), cur, model, ubatch, il);
  43. } else {
  44. // Full attention layer
  45. cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il);
  46. }
  47. if (il == n_layer - 1 && inp_out_ids) {
  48. cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  49. inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  50. }
  51. // Residual connection
  52. cur = ggml_add(ctx0, cur, inpSA);
  53. cb(cur, "attn_residual", il);
  54. // Save the tensor before post-attention norm for residual connection
  55. ggml_tensor * ffn_residual = cur;
  56. // Post-attention norm
  57. ggml_tensor * attn_post_norm = build_q3n_norm(cur, model.layers[il].attn_post_norm, il);
  58. cb(attn_post_norm, "attn_post_norm", il);
  59. // FFN layer (MoE or dense) - without residual connection
  60. cur = build_layer_ffn(attn_post_norm, model, il, false);
  61. cb(cur, "ffn_out", il);
  62. // Residual connection for FFN - add to the tensor BEFORE post_attention_layernorm
  63. cur = ggml_add(ctx0, cur, ffn_residual);
  64. cb(cur, "post_moe", il);
  65. // Input for next layer
  66. inpL = cur;
  67. }
  68. cur = inpL;
  69. // Final norm
  70. cur = build_q3n_norm(cur, model.output_norm, -1);
  71. cb(cur, "result_norm", -1);
  72. res->t_embd = cur;
  73. // LM head
  74. cur = build_lora_mm(model.output, cur);
  75. cb(cur, "result_output", -1);
  76. ggml_set_output(cur);
  77. res->t_logits = cur;
  78. ggml_build_forward_expand(gf, cur);
  79. }
  80. struct ggml_tensor * llm_build_qwen3next::build_q3n_norm(struct ggml_tensor * input, struct ggml_tensor * weights, int layer) {
  81. ggml_tensor * input_norm = ggml_scale_bias(ctx0, weights, 1.0f, 1.0f);
  82. return build_norm(input, input_norm, nullptr, LLM_NORM_RMS, layer);
  83. }
  84. struct ggml_tensor * llm_build_qwen3next::build_q3n_gated_norm(struct ggml_tensor * input, struct ggml_tensor * weights, struct ggml_tensor * gate, int layer) {
  85. ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
  86. ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
  87. return ggml_mul(ctx0, normalized, gated_silu);
  88. }
  89. struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor * cur,
  90. ggml_tensor * inp_pos,
  91. llm_graph_input_attn_kv * inp_attn,
  92. const llama_model & model,
  93. const int64_t n_embd_head,
  94. const int il) {
  95. // compute Q and K and RoPE them
  96. // Qwen3Next uses a single Q projection that outputs query + gate
  97. struct ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
  98. cb(Qcur_full, "Qcur_full", il);
  99. Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
  100. // Split Q projection into query and gate
  101. // The split should be along dimension 0 (the feature dimension)
  102. struct ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
  103. struct ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3],
  104. n_embd_head * ggml_element_size(Qcur_full));
  105. cb(Qcur, "Qcur", il);
  106. cb(gate, "gate", il);
  107. // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
  108. Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
  109. cb(Qcur, "Qcur_reshaped", il);
  110. // Apply Q normalization only to the query part
  111. Qcur = build_q3n_norm(Qcur, model.layers[il].attn_q_norm, il);
  112. cb(Qcur, "Qcur_normed", il);
  113. // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
  114. gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
  115. cb(gate, "gate_reshaped", il);
  116. struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
  117. cb(Kcur, "Kcur", il);
  118. struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
  119. cb(Vcur, "Vcur", il);
  120. Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
  121. Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
  122. Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
  123. // Apply Q/K normalization
  124. Kcur = build_q3n_norm(Kcur, model.layers[il].attn_k_norm, il);
  125. cb(Kcur, "Kcur_normed", il);
  126. // Apply RoPE
  127. Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
  128. attn_factor, beta_fast, beta_slow);
  129. Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
  130. attn_factor, beta_fast, beta_slow);
  131. cb(Qcur, "Qcur", il);
  132. cb(Kcur, "Kcur", il);
  133. cb(Vcur, "Vcur", il);
  134. // Attention computation
  135. const float kq_scale =
  136. hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
  137. cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
  138. // Apply gating directly using the original gate tensor
  139. cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
  140. cb(cur, "attn_gated", il);
  141. cur = build_lora_mm(model.layers[il].wo, cur);
  142. cb(cur, "attn_output", il);
  143. return cur;
  144. }
  145. // delta_net
  146. // prepare all the tensor data for the operation so we only
  147. // do the absolutely necessary steps in the op itself
  148. struct ggml_tensor * llm_build_qwen3next::delta_net(
  149. struct ggml_context * ctx,
  150. struct ggml_tensor * q,
  151. struct ggml_tensor * k,
  152. struct ggml_tensor * v,
  153. struct ggml_tensor * g,
  154. struct ggml_tensor * beta,
  155. struct ggml_tensor * state,
  156. bool use_qk_l2norm,
  157. float eps_norm,
  158. const int il
  159. ) {
  160. GGML_ASSERT(ggml_is_contiguous(q));
  161. GGML_ASSERT(ggml_is_contiguous(k));
  162. GGML_ASSERT(ggml_is_contiguous(v));
  163. GGML_ASSERT(ggml_is_contiguous(g));
  164. GGML_ASSERT(ggml_is_contiguous(beta));
  165. GGML_ASSERT(ggml_is_contiguous(state));
  166. const int64_t S_k = q->ne[0];
  167. const int64_t H_k = q->ne[1];
  168. const int64_t n_tokens = q->ne[2];
  169. const int64_t n_seqs = q->ne[3];
  170. const int64_t S_v = v->ne[0];
  171. const int64_t H_v = v->ne[1];
  172. GGML_ASSERT(v->ne[2] == n_tokens);
  173. GGML_ASSERT(k->ne[2] == n_tokens);
  174. GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
  175. GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
  176. GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == 1);
  177. GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
  178. GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
  179. GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
  180. cb(q, "q_prenorm", il);
  181. cb(k, "k_prenorm", il);
  182. if (use_qk_l2norm) {
  183. q = ggml_l2_norm(ctx, q, eps_norm);
  184. k = ggml_l2_norm(ctx, k, eps_norm);
  185. }
  186. cb(k, "k_postnorm", il);
  187. cb(q, "q_prescale", il);
  188. int64_t pad_size = (GGML_DELTA_NET_CHUNK - n_tokens % GGML_DELTA_NET_CHUNK) % GGML_DELTA_NET_CHUNK;
  189. // yes, n_tokens, not H_k, the reference implementation has wrong naming
  190. int64_t num_chunks = (n_tokens + pad_size) / GGML_DELTA_NET_CHUNK;
  191. float scale = 1.0f / sqrtf(S_v);
  192. q = ggml_scale(ctx, q, scale);
  193. cb(beta, "beta_raw", il);
  194. beta = ggml_sigmoid(ctx, beta);
  195. cb(q, "q_postscale", il);
  196. cb(beta, "beta_sigmoid", il);
  197. // First, permute to chunked format: [S_k, n_tokens, H_k, n_seqs]
  198. q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
  199. cb(q, "q_reshape", il);
  200. k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
  201. cb(k, "k_reshape", il);
  202. v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
  203. cb(v, "v_reshape", il);
  204. beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
  205. cb(beta, "beta_reshape", il);
  206. g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
  207. cb(g, "g_permute", il);
  208. // Then, pad the second dimension (n_tokens) to chunk_size
  209. q = ggml_pad(ctx, q, 0, pad_size, 0, 0);
  210. k = ggml_pad(ctx, k, 0, pad_size, 0, 0);
  211. v = ggml_pad(ctx, v, 0, pad_size, 0, 0);
  212. // ... except for beta and g, where we pad the last dimension
  213. beta = ggml_pad(ctx, beta, pad_size, 0, 0, 0);
  214. g = ggml_pad(ctx, g, pad_size, 0, 0, 0);
  215. cb(q, "q_pad", il);
  216. cb(k, "k_pad", il);
  217. cb(v, "v_pad", il);
  218. cb(beta, "beta_pad", il);
  219. cb(g, "g_pad", il);
  220. GGML_ASSERT(q->ne[1] % GGML_DELTA_NET_CHUNK == 0 && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
  221. GGML_ASSERT(k->ne[1] % GGML_DELTA_NET_CHUNK == 0 && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
  222. GGML_ASSERT(v->ne[1] % GGML_DELTA_NET_CHUNK == 0 && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
  223. GGML_ASSERT(beta->ne[0] % GGML_DELTA_NET_CHUNK == 0 && beta->ne[1] == H_k && beta->ne[2] == 1 && beta->ne[3] == n_seqs);
  224. GGML_ASSERT(g->ne[0] % GGML_DELTA_NET_CHUNK == 0 && g->ne[2] == H_k && g->ne[1] == 1 && g->ne[3] == n_seqs);
  225. ggml_tensor * beta_unsq = ggml_cont_4d(ctx, beta, 1, GGML_DELTA_NET_CHUNK * num_chunks, H_k, n_seqs);
  226. ggml_tensor * beta_bcast = ggml_repeat_4d(ctx, beta_unsq, S_v, GGML_DELTA_NET_CHUNK * num_chunks, H_k, n_seqs);
  227. cb(beta_unsq, "beta_unsq", il);
  228. cb(beta_bcast, "beta_bcast", il);
  229. struct ggml_tensor * v_beta = ggml_mul(ctx, v, beta_bcast);
  230. v_beta = ggml_reshape_4d(ctx, v_beta, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
  231. cb(v_beta, "v_beta", il);
  232. struct ggml_tensor * k_beta = ggml_mul(ctx, k, beta_bcast);
  233. k_beta = ggml_reshape_4d(ctx, k_beta, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
  234. cb(k_beta, "k_beta", il);
  235. k = ggml_reshape_4d(ctx, k, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
  236. cb(k_beta, "k_reshape", il);
  237. g = ggml_reshape_4d(ctx, g, GGML_DELTA_NET_CHUNK, 1, H_k * num_chunks, n_seqs);
  238. cb(g, "g_reshape", il);
  239. struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
  240. cb(g_cumsum, "g_cumsum", il);
  241. struct ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, GGML_DELTA_NET_CHUNK, 1, num_chunks * H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs]
  242. struct ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs]
  243. // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
  244. struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d(ctx, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
  245. struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx, gcs_j, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
  246. struct ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i_broadcast);
  247. cb(decay_mask, "sub", il);
  248. // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
  249. decay_mask = ggml_tri_keep(ctx, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
  250. cb(decay_mask, "sub_tri", il);
  251. // Apply exponential to get the decay mask values
  252. decay_mask = ggml_exp(ctx, decay_mask);
  253. cb(decay_mask, "sub_tri_exp", il);
  254. // Apply lower triangular mask again to ensure only lower triangular values remain
  255. decay_mask = ggml_tri_keep(ctx, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
  256. cb(decay_mask, "decay_mask", il);
  257. struct ggml_tensor * kmulkbeta = ggml_mul_mat(ctx, ggml_cont(ctx, k), ggml_cont(ctx, k_beta));
  258. cb(kmulkbeta, "k_beta @ k_t ", il);
  259. struct ggml_tensor * k_decay = ggml_mul(ctx, kmulkbeta, decay_mask);
  260. cb(k_decay, "(k_beta @ k_t) * decay_mask", il);
  261. struct ggml_tensor * attn = ggml_neg(ctx, ggml_tri_keep(ctx, k_decay, GGML_TRI_TYPE_LOWER));
  262. cb(attn, "attn_in", il);
  263. // We'll be returning the result as a 1D tensor due to the dimensions mismatch of the state and output tensors
  264. const int64_t ne[1] = { (S_v * H_v * n_tokens * n_seqs ) + (S_v * S_v * H_v * n_seqs) };
  265. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 1, ne);
  266. ggml_set_op_params_i32(result, 0, H_v);
  267. ggml_set_op_params_i32(result, 1, S_k);
  268. ggml_set_op_params_i32(result, 2, S_v);
  269. ggml_set_op_params_i32(result, 3, n_tokens); // Pass original n_tokens
  270. result->op = GGML_OP_DELTA_NET;
  271. result->src[0] = q;
  272. result->src[1] = k;
  273. result->src[2] = v;
  274. result->src[3] = g_cumsum;
  275. result->src[4] = state;
  276. result->src[5] = decay_mask;
  277. result->src[6] = v_beta;
  278. result->src[7] = k_beta;
  279. result->src[8] = attn;
  280. return result;
  281. }
  282. ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
  283. ggml_tensor * cur,
  284. const llama_model & model,
  285. const llama_ubatch & ubatch,
  286. int il) {
  287. // Gated Delta Net implementation using the new delta_net function
  288. const auto * mctx_cur = inp->mctx;
  289. const int64_t d_inner = hparams.ssm_d_inner;
  290. const int64_t n_heads = hparams.ssm_dt_rank;
  291. const int64_t head_dim = d_inner / n_heads;
  292. const int64_t n_seqs = ubatch.n_seqs;
  293. const int64_t head_k_dim = hparams.ssm_d_state;
  294. const int64_t head_v_dim = hparams.ssm_d_state;
  295. const int64_t num_k_heads = hparams.ssm_n_group;
  296. const int64_t num_v_heads = hparams.ssm_dt_rank;
  297. const int64_t n_seq_tokens = ubatch.n_seq_tokens;
  298. const int64_t n_tokens = ubatch.n_tokens;
  299. GGML_ASSERT(n_seqs != 0);
  300. GGML_ASSERT(ubatch.equal_seqs());
  301. GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
  302. // Input projections
  303. ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
  304. cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
  305. ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
  306. cb(mixed_ba, "linear_attn_mixed_ba", il);
  307. int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
  308. ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
  309. // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
  310. int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
  311. ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
  312. // Split mixed_ba into b and a (beta and alpha parameters)
  313. int64_t split_sizes_ba[2] = {
  314. num_v_heads / num_k_heads, // beta size
  315. num_v_heads / num_k_heads // alpha size
  316. };
  317. ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, n_seqs,
  318. mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
  319. cb(b, "b", il);
  320. ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, n_seqs,
  321. mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
  322. split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
  323. cb(a, "a", il);
  324. // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
  325. ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_tokens, n_seqs);
  326. ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tokens, n_seqs);
  327. GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
  328. ggml_tensor * alpha_softplus = softplus(alpha, model.layers[il].ssm_dt);
  329. cb(alpha_softplus, "a_softplus", il);
  330. ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
  331. cb(gate, "gate", il);
  332. // Get convolution states from cache
  333. ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
  334. ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
  335. // Build the convolution states tensor
  336. ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
  337. cb(conv_states, "conv_states", il);
  338. // Split mixed_qkvz into query, key, value, z
  339. int64_t split_sizes_qkvz[4] = {
  340. head_k_dim, // query size
  341. head_k_dim, // key size
  342. head_v_dim * num_v_heads / num_k_heads, // value size
  343. head_v_dim * num_v_heads / num_k_heads // z size
  344. };
  345. ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens, n_seqs,
  346. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
  347. cb(query, "q", il);
  348. ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
  349. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  350. split_sizes_qkvz[0] * sizeof(float)));
  351. cb(key, "k", il);
  352. ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
  353. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  354. (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
  355. cb(value, "v", il);
  356. ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
  357. mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
  358. (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
  359. cb(z, "z", il);
  360. // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
  361. ggml_tensor * value_reshaped =
  362. ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
  363. ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
  364. GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
  365. ggml_nelements(z_reshaped) ==
  366. ggml_nelements(mixed_qkvz));
  367. // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
  368. // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
  369. ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_tokens, n_seqs);
  370. cb(query_flat, "query_flat", il);
  371. // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
  372. ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_tokens, n_seqs);
  373. cb(key_flat, "key_flat", il);
  374. // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
  375. ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_tokens, n_seqs);
  376. cb(value_flat, "value_flat", il);
  377. // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
  378. ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
  379. qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
  380. qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
  381. cb(qkv_mixed, "qkv_mixed_concatenated", il);
  382. // Calculate the total conv dimension
  383. int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
  384. // Calculate convolution kernel size
  385. ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
  386. const int64_t conv_kernel_size = conv_kernel->ne[0];
  387. conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
  388. conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state, n_seqs);
  389. cb(conv_states, "conv_states_reshaped", il);
  390. ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
  391. cb(conv_input, "conv_input", il);
  392. // Apply convolution
  393. ggml_tensor * conv_output = ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, n_seqs);
  394. cb(conv_output, "conv_output_raw", il);
  395. // Remove the padding
  396. ggml_tensor * conv_output_no_padding = ggml_view_4d(ctx0, conv_output, conv_output->ne[0] - (conv_kernel_size - 1), conv_output->ne[1], conv_output->ne[2], conv_output->ne[3],
  397. conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
  398. (conv_kernel_size - 1) * ggml_element_size(conv_output));
  399. cb(conv_output_no_padding, "conv_output_no_padding", il);
  400. // Take only the first (n_tokens * n_seqs) values
  401. ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_tokens * n_seqs, conv_output_no_padding->ne[1], conv_output_no_padding->ne[2], conv_output_no_padding->ne[3],
  402. conv_output_no_padding->nb[1], conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], 0);
  403. cb(conv_output_proper, "conv_output_proper", il);
  404. conv_output_proper = ggml_permute(ctx0, conv_output_proper, 0, 1, 3, 2);
  405. conv_output_proper = ggml_cont_4d(ctx0, conv_output_proper, qkv_dim, 1, n_tokens, n_seqs);
  406. ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
  407. cb(conv_output_silu, "conv_output_silu", il);
  408. // Update convolution state cache
  409. // Extract the last (conv_kernel_size - 1) states from conv_input
  410. ggml_tensor * last_conv_states =
  411. ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, qkv_dim, n_seqs, conv_input->nb[1], conv_input->nb[2],
  412. n_seq_tokens * conv_input->nb[0]);
  413. ggml_build_forward_expand(gf,
  414. ggml_cpy(ctx0, last_conv_states,
  415. ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * qkv_dim * n_seqs,
  416. mctx_cur->get_head() * (conv_kernel_size - 1) * qkv_dim *
  417. ggml_element_size(conv_states_all))));
  418. cb(conv_states_all, "conv_states_updated", il);
  419. conv_output_proper = ggml_reshape_2d(ctx0, conv_output_silu, n_tokens * n_seqs, qkv_dim);
  420. cb(conv_output_proper, "conv_output_final", il);
  421. ggml_tensor * conv_transposed = ggml_transpose(ctx0, conv_output_proper);
  422. cb(conv_transposed, "conv_transposed", il);
  423. ggml_tensor * conv_qkv_mix = ggml_cont_2d(ctx0, conv_transposed, qkv_dim, n_tokens * n_seqs);
  424. cb(conv_qkv_mix, "conv_qkv_mix", il);
  425. // Extract the convolved Q, K, V from conv_output
  426. ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_tokens * n_seqs,
  427. conv_qkv_mix->nb[1], 0);
  428. cb(q_conv, "q_conv", il);
  429. ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_tokens * n_seqs,
  430. conv_qkv_mix->nb[1], head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
  431. cb(k_conv, "k_conv", il);
  432. ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_tokens * n_seqs,
  433. conv_qkv_mix->nb[1], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
  434. cb(v_conv, "v_conv", il);
  435. // Unsqueeze them
  436. q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
  437. k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
  438. v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tokens, n_seqs);
  439. beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tokens, n_seqs);
  440. ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
  441. // if head keys and value keys are different, repeat to force tensors into matching shapes
  442. if (num_k_heads != num_v_heads) {
  443. GGML_ASSERT(num_v_heads % num_k_heads == 0);
  444. int64_t repeat_factor = num_v_heads / num_k_heads;
  445. q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_k_heads * repeat_factor, n_tokens, n_seqs);
  446. k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_k_heads * repeat_factor, n_tokens, n_seqs);
  447. }
  448. cb(q_conv, "q_conv_predelta", il);
  449. cb(k_conv, "k_conv_predelta", il);
  450. cb(v_conv, "v_conv_predelta", il);
  451. // Call the new delta_net function with the corrected flow
  452. ggml_tensor * attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
  453. cb(attn_out, "attn_out", il);
  454. // The tensors were concatenated 1d, so we need to extract them 1d as well
  455. const int64_t output_flat_size = head_dim * n_heads * n_tokens * n_seqs;
  456. ggml_tensor * attn_out_1d =
  457. ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
  458. cb(attn_out_1d, "attn_out_1d", il);
  459. ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_tokens, n_heads, n_seqs), 0, 2, 1, 3));
  460. cb(attn_out_final, "attn_out_final", il);
  461. // Extract the state part (second part of the concatenated tensor)
  462. // State starts after n_tokens elements along dimension 1
  463. const int64_t state_flat_size = head_dim * head_dim * n_heads * n_seqs;
  464. ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
  465. cb(state_1d, "state_1d", il);
  466. ggml_tensor * new_state = ggml_reshape_4d(ctx0, state_1d, head_dim, head_dim, n_heads, n_seqs);
  467. cb(new_state, "new_state", il);
  468. // Update the recurrent states - we use the new_state directly since it's already the last state
  469. ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, ssm_states_all));
  470. // Reshape both attn_out_final and z to 2D tensors for normalization
  471. // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
  472. ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_tokens * n_seqs);
  473. // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
  474. ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
  475. // Apply gated normalization: self.norm(core_attn_out, z)
  476. ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
  477. // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
  478. ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_tokens, n_seqs);
  479. // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
  480. ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
  481. cb(final_output, "final_output", il);
  482. // Output projection
  483. cur = build_lora_mm(model.layers[il].ssm_out, final_output);
  484. cb(cur, "linear_attn_out", il);
  485. // Reshape back to original dimensions
  486. cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens));
  487. return cur;
  488. }
  489. ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il, bool do_residual) {
  490. // Check if this is an MoE layer
  491. if (model.layers[il].ffn_gate_inp != nullptr) {
  492. // MoE branch
  493. ggml_tensor * moe_out =
  494. build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
  495. model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert,
  496. n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
  497. cb(moe_out, "ffn_moe_out", il);
  498. // Add shared experts if present - following Qwen3Next reference implementation
  499. if (model.layers[il].ffn_up_shexp != nullptr) {
  500. ggml_tensor * ffn_shexp =
  501. build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL,
  502. model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
  503. cb(ffn_shexp, "ffn_shexp", il);
  504. // Apply shared expert gating as in the reference implementation
  505. // The shared expert has its own gate that is sigmoided
  506. // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
  507. ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
  508. cb(shared_gate, "shared_expert_gate", il);
  509. // Apply sigmoid to the gate
  510. shared_gate = ggml_sigmoid(ctx0, shared_gate);
  511. cb(shared_gate, "shared_expert_gate_sigmoid", il);
  512. // The gate needs to be broadcast to match the dimensions of ffn_shexp
  513. // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
  514. // We need to repeat the gate along the feature dimension
  515. shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
  516. cb(shared_gate, "shared_expert_gate_broadcast", il);
  517. // Apply the gate to the shared expert output
  518. ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
  519. cb(ffn_shexp, "ffn_shexp_gated", il);
  520. cur = ggml_add(ctx0, moe_out, ffn_shexp);
  521. cb(cur, "ffn_out", il);
  522. } else {
  523. cur = moe_out;
  524. }
  525. } else {
  526. // Dense FFN branch
  527. cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
  528. model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
  529. cb(cur, "ffn_out", il);
  530. }
  531. // Residual connection (only if requested)
  532. if (do_residual) {
  533. cur = ggml_add(ctx0, cur, cur);
  534. cb(cur, "ffn_residual", il);
  535. }
  536. cur = build_cvec(cur, il);
  537. cb(cur, "l_out", il);
  538. return cur;
  539. };
  540. ggml_tensor * llm_build_qwen3next::softplus(ggml_tensor * alpha, ggml_tensor * dt_bias) {
  541. ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, dt_bias); // a + dt_bias
  542. ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased); // exp(a + dt_bias)
  543. ggml_tensor * one_plus_exp = ggml_scale_bias(ctx0, alpha_exp, 1.0f, 1.0f); // 1 + exp(a + dt_bias)
  544. ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp); // log(1 + exp(...))
  545. return alpha_softplus;
  546. }