afmoe.cpp 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. #include "models.h"
  2. llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
  3. const int64_t n_embd_head = hparams.n_embd_head_v;
  4. GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
  5. ggml_tensor * cur;
  6. ggml_tensor * inpL;
  7. inpL = build_inp_embd(model.tok_embd);
  8. // MuP scaling: embeddings * sqrt(hidden_size)
  9. // mup_enabled = true, hidden_size = 1024, scale = 32.0
  10. inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd)));
  11. cb(inpL, "inp_embd_scaled", -1);
  12. // inp_pos - contains the positions
  13. ggml_tensor * inp_pos = build_inp_pos();
  14. auto * inp_attn = build_attn_inp_kv_iswa();
  15. ggml_tensor * inp_out_ids = build_inp_out_ids();
  16. const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
  17. for (int il = 0; il < n_layer; ++il) {
  18. const float freq_base_l = model.get_rope_freq_base (cparams, il);
  19. const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
  20. ggml_tensor * inpSA = inpL;
  21. // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous
  22. const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
  23. (il + 1) % hparams.n_no_rope_layer_step != 0;
  24. // dual attention normalization (pre)
  25. cur = build_norm(inpL,
  26. model.layers[il].attn_norm, NULL,
  27. LLM_NORM_RMS, il);
  28. cb(cur, "attn_norm", il);
  29. // self-attention
  30. {
  31. ggml_tensor * attn_inp = cur; // save input for gate computation
  32. ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
  33. cb(Qcur, "Qcur", il);
  34. ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
  35. cb(Kcur, "Kcur", il);
  36. ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
  37. cb(Vcur, "Vcur", il);
  38. // compute gate from input
  39. ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp);
  40. cb(gate, "attn_gate_proj", il);
  41. Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
  42. Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
  43. // Q/K normalization
  44. Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
  45. Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
  46. cb(Qcur, "Qcur_normed", il);
  47. cb(Kcur, "Kcur_normed", il);
  48. if (use_rope) {
  49. Qcur = ggml_rope_ext(
  50. ctx0, Qcur, inp_pos, nullptr,
  51. n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
  52. ext_factor, attn_factor, beta_fast, beta_slow);
  53. cb(Qcur, "Qcur_rope", il);
  54. Kcur = ggml_rope_ext(
  55. ctx0, Kcur, inp_pos, nullptr,
  56. n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
  57. ext_factor, attn_factor, beta_fast, beta_slow);
  58. cb(Kcur, "Kcur_rope", il);
  59. }
  60. Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
  61. cur = build_attn(inp_attn,
  62. NULL, NULL, // wo will be applied after gating
  63. Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
  64. cb(cur, "attn_out", il);
  65. // attention gating: attn_out * sigmoid(gate) BEFORE o_proj
  66. gate = ggml_sigmoid(ctx0, gate);
  67. cb(gate, "attn_gate_sig", il);
  68. cur = ggml_mul(ctx0, cur, gate);
  69. cb(cur, "attn_gated", il);
  70. // now apply output projection
  71. cur = build_lora_mm(model.layers[il].wo, cur);
  72. cb(cur, "attn_o_proj", il);
  73. }
  74. // dual attention normalization (post)
  75. cur = build_norm(cur,
  76. model.layers[il].attn_post_norm, NULL,
  77. LLM_NORM_RMS, il);
  78. cb(cur, "attn_post_norm", il);
  79. if (il == n_layer - 1 && inp_out_ids) {
  80. cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  81. inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  82. }
  83. ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
  84. cb(ffn_inp, "ffn_inp", il);
  85. // dual ffn normalization (pre)
  86. cur = build_norm(ffn_inp,
  87. model.layers[il].ffn_norm, NULL,
  88. LLM_NORM_RMS, il);
  89. cb(cur, "ffn_norm", il);
  90. // MoE or dense FFN
  91. if ((uint32_t)il >= hparams.n_layer_dense_lead) {
  92. // MoE layer with sigmoid routing, normalization, and scaling
  93. ggml_tensor * moe_out = build_moe_ffn(cur,
  94. model.layers[il].ffn_gate_inp,
  95. model.layers[il].ffn_up_exps,
  96. model.layers[il].ffn_gate_exps,
  97. model.layers[il].ffn_down_exps,
  98. model.layers[il].ffn_exp_probs_b,
  99. n_expert, n_expert_used,
  100. LLM_FFN_SILU,
  101. hparams.expert_weights_norm, // norm_w (route_norm=True)
  102. hparams.expert_weights_scale, // scale_w
  103. hparams.expert_weights_scale, // w_scale (route_scale=2.826)
  104. (llama_expert_gating_func_type) hparams.expert_gating_func,
  105. il);
  106. cb(moe_out, "ffn_moe_out", il);
  107. // shared expert
  108. if (hparams.n_expert_shared > 0) {
  109. ggml_tensor * ffn_shexp = build_ffn(cur,
  110. model.layers[il].ffn_up_shexp, NULL, NULL,
  111. model.layers[il].ffn_gate_shexp, NULL, NULL,
  112. model.layers[il].ffn_down_shexp, NULL, NULL,
  113. NULL,
  114. LLM_FFN_SILU, LLM_FFN_PAR, il);
  115. cb(ffn_shexp, "ffn_shexp", il);
  116. cur = ggml_add(ctx0, moe_out, ffn_shexp);
  117. cb(cur, "ffn_out", il);
  118. } else {
  119. cur = moe_out;
  120. }
  121. } else {
  122. // dense layer
  123. cur = build_ffn(cur,
  124. model.layers[il].ffn_up, NULL, NULL,
  125. model.layers[il].ffn_gate, NULL, NULL,
  126. model.layers[il].ffn_down, NULL, NULL,
  127. NULL,
  128. LLM_FFN_SILU, LLM_FFN_PAR, il);
  129. cb(cur, "ffn_out", il);
  130. }
  131. // dual ffn normalization (post)
  132. cur = build_norm(cur,
  133. model.layers[il].ffn_post_norm, NULL,
  134. LLM_NORM_RMS, il);
  135. cb(cur, "ffn_post_norm", il);
  136. cur = ggml_add(ctx0, cur, ffn_inp);
  137. cur = build_cvec(cur, il);
  138. cb(cur, "l_out", il);
  139. // input for next layer
  140. inpL = cur;
  141. }
  142. cur = inpL;
  143. cur = build_norm(cur,
  144. model.output_norm, NULL,
  145. LLM_NORM_RMS, -1);
  146. cb(cur, "result_norm", -1);
  147. res->t_embd = cur;
  148. // lm_head
  149. cur = build_lora_mm(model.output, cur);
  150. cb(cur, "result_output", -1);
  151. res->t_logits = cur;
  152. ggml_build_forward_expand(gf, cur);
  153. }