llm_build_rwkv7_base.cpp

#include "../llama-graph.h"
#include "../llama-model.h"
#include "llm_build_rwkv_base.h"

#include <cmath>

llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model) {}

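// Channel mixing (the RWKV7 feed-forward block): interpolate between the current
// hidden state and the token-shifted previous state, then apply key -> squared ReLU -> value.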
ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer * layer,
                                                            ggml_tensor * cur,
                                                            ggml_tensor * x_prev,
                                                            llm_arch arch) const {
    ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);

    switch (arch) {
        case LLM_ARCH_RWKV7:
            {
                ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);

                ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk)));

                cur = build_lora_mm(layer->channel_mix_value, k);
            }
            break;
        default:
            GGML_ABORT("fatal error");
    }

    return cur;
}

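// Time mixing (the RWKV7 attention-like block). `x_prev` holds the token-shifted hidden
// state, `first_layer_value` carries the first layer's value tensor for the value-residual
// connection, and the recurrent wkv state is read from and written back to the recurrent
// memory context.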
ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * inp,
                                                         ggml_tensor * cur,
                                                         ggml_tensor * x_prev,
                                                         ggml_tensor *& first_layer_value,
                                                         const llama_ubatch & ubatch,
                                                         int il) const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    const auto n_tokens     = ubatch.n_tokens;
    const auto n_seqs       = ubatch.n_seqs;
    const auto n_embd       = hparams.n_embd;
    const auto head_size    = hparams.wkv_head_size;
    const auto head_count   = n_embd / head_size;
    const auto n_seq_tokens = ubatch.n_seq_tokens;
    const auto kv_head      = mctx_cur->get_head();

    const auto & layer = model.layers[il];

    bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;

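    // Token shift: broadcast the shifted-state delta to one copy per projection
    // (r, w, k, v, a and optionally g), then interpolate with the fused lerp weights.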
    ggml_tensor * sx    = ggml_sub(ctx0, x_prev, cur);
    ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
    sx                  = ggml_repeat(ctx0, sx, dummy);

    ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);

    ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
    ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
    ggml_tensor * xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
    ggml_tensor * xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
    ggml_tensor * xa = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
    ggml_tensor * xg = has_gating ?
        ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) :
        nullptr;

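    // Per-token projections: receptance r, decay w, key k and value v.
    // The decay is a low-rank tanh MLP around time_mix_w0, squashed so that
    // w = exp(-0.606531 * sigmoid(.)) stays in (exp(-0.606531), 1); 0.606531 ~= exp(-0.5).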
    ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);

    ggml_tensor * w = ggml_add(
        ctx0, ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))),
        layer.time_mix_w0);
    w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531));

    ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
    ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);

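    // Value residual: the first layer records its value tensor; later layers blend
    // their value with it, gated by a sigmoid low-rank projection (time_mix_v0/v1/v2).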
    if (first_layer_value == nullptr) {
        first_layer_value = v;
    } else {
        // Add the first layer value as a residual connection.
        v = ggml_add(ctx0, v,
                     ggml_mul(ctx0, ggml_sub(ctx0, first_layer_value, v),
                              ggml_sigmoid(ctx0, ggml_add(ctx0,
                                                          ggml_mul_mat(ctx0, layer.time_mix_v2,
                                                                       ggml_mul_mat(ctx0, layer.time_mix_v1, xv)),
                                                          layer.time_mix_v0))));
    }

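    // Optional output gate g (sigmoid of a low-rank projection) and the per-channel
    // in-context learning rate a (sigmoid of a low-rank projection around time_mix_a0).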
    ggml_tensor * g = nullptr;
    if (layer.time_mix_g1 && layer.time_mix_g2) {
        g = ggml_mul_mat(ctx0, layer.time_mix_g2, ggml_sigmoid(ctx0, ggml_mul_mat(ctx0, layer.time_mix_g1, xg)));
    }

    ggml_tensor * a = ggml_sigmoid(
        ctx0, ggml_add(ctx0, ggml_mul_mat(ctx0, layer.time_mix_a2, ggml_mul_mat(ctx0, layer.time_mix_a1, xa)),
        layer.time_mix_a0));

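    // kk is the per-head L2-normalized removal key; k itself is blended with its
    // a-scaled copy through time_mix_k_a (k += k*k_a*(a - 1)) before the wkv kernel.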
    ggml_tensor * kk = ggml_reshape_3d(ctx0, ggml_mul(ctx0, k, layer.time_mix_k_k), head_size, head_count, n_tokens);
    kk = ggml_l2_norm(ctx0, kk, 1e-12);

    ggml_tensor * ka = ggml_mul(ctx0, k, layer.time_mix_k_a);
    k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka));

    r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
    w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
    k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
    v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
    a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);

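    // Load the recurrent wkv state for this slot, run the fused WKV7 kernel
    // (with -kk and kk*a as the state-update terms), then copy the updated state
    // back into the recurrent memory at kv_head.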
    ggml_tensor * wkv_state  = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs);
    ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);

    cur       = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
    wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));

    ggml_build_forward_expand(
        gf, ggml_cpy(ctx0, wkv_state,
                     ggml_view_1d(ctx0, mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs,
                                  hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il)))));

    if (layer.time_mix_ln && layer.time_mix_ln_b) {
        // group norm with head_count groups
        cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens);
        cur = ggml_norm(ctx0, cur, 64e-5f);

        // Convert back to regular vectors.
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
    } else {
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
    }

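    // Bonus term: per head, add v scaled by sum(r * k * r_k), then apply the optional
    // gate and the output projection.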
    ggml_tensor * rk = ggml_sum_rows(
        ctx0, ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count)));
    cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens));

    if (has_gating) {
        cur = ggml_mul(ctx0, cur, g);
    }
    cur = build_lora_mm(layer.time_mix_output, cur);

    return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
}