llama-graph.cpp

  1. #include "llama-graph.h"
  2. #include "llama-impl.h"
  3. #include "llama-batch.h"
  4. #include "llama-cparams.h"
  5. #include "llama-kv-cache.h"
  6. #include <cassert>
  7. #include <cmath>
  8. #include <cstring>
  9. static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
  10. // TODO move to hparams if a T5 variant appears that uses a different value
  11. const int64_t max_distance = 128;
  12. if (bidirectional) {
  13. n_buckets >>= 1;
  14. }
  15. const int64_t max_exact = n_buckets >> 1;
  16. int32_t relative_position = x - y;
  17. int32_t relative_bucket = 0;
  18. if (bidirectional) {
  19. relative_bucket += (relative_position > 0) * n_buckets;
  20. relative_position = abs(relative_position);
  21. } else {
  22. relative_position = -std::min<int32_t>(relative_position, 0);
  23. }
  24. int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
  25. relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
  26. relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
  27. return relative_bucket;
  28. }
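// Worked example (illustrative values, not from this file): T5-style bidirectional bucketing with
// n_buckets = 32 first halves n_buckets to 16, so max_exact = 8. For x = 3, y = 10 the relative
// position is -7; its absolute value 7 is below max_exact, so the bucket is simply 7. For x = 100,
// y = 0 the positive offset selects the upper half (+16) and the log-spaced branch gives
// floor(8 + log(100/8) * 8 / log(128/8)) = 15, i.e. bucket 31. Distances at or beyond
// max_distance = 128 all clamp into the last bucket of their half.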
  29. void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
  30. if (ubatch->token) {
  31. const int64_t n_tokens = ubatch->n_tokens;
  32. ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
  33. }
  34. if (ubatch->embd) {
  35. const int64_t n_embd = embd->ne[0];
  36. const int64_t n_tokens = ubatch->n_tokens;
  37. ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
  38. }
  39. }
  40. void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
  41. if (ubatch->pos && pos) {
  42. const int64_t n_tokens = ubatch->n_tokens;
  43. ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
  44. }
  45. }
  46. void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
  47. if (ubatch->pos && attn_scale) {
  48. const int64_t n_tokens = ubatch->n_tokens;
  49. std::vector<float> attn_scale_data(n_tokens, 0.0f);
  50. for (int i = 0; i < n_tokens; ++i) {
  51. const float pos = ubatch->pos[i];
  52. attn_scale_data[i] = std::log(
  53. std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
  54. ) * f_attn_temp_scale + 1.0;
  55. }
  56. ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale));
  57. }
  58. }
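// The loop implements attn_scale[i] = 1 + f_attn_temp_scale * log(1 + floor((pos_i + 1) / n_attn_temp_floor_scale)),
// i.e. a slowly growing attention-temperature factor for long contexts. Illustrative numbers
// (assumed, Llama 4-style): with n_attn_temp_floor_scale = 8192 and f_attn_temp_scale = 0.1, a token
// at pos 0 gets scale 1.0, while a token at pos 81920 gets roughly 1 + 0.1 * ln(11) ~= 1.24.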
  59. void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
  60. if (pos_bucket) {
  61. const int64_t n_tokens = ubatch->n_tokens;
  62. GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
  63. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  64. int32_t * data = (int32_t *) pos_bucket->data;
  65. for (int h = 0; h < 1; ++h) {
  66. for (int j = 0; j < n_tokens; ++j) {
  67. for (int i = 0; i < n_tokens; ++i) {
  68. data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
  69. }
  70. }
  71. }
  72. }
  73. }
  74. void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
  75. if (pos_bucket) {
  76. const int64_t n_tokens = ubatch->n_tokens;
  77. GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
  78. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  79. int32_t * data = (int32_t *) pos_bucket->data;
  80. const int64_t n_kv = kv_self->n;
  81. for (int h = 0; h < 1; ++h) {
  82. for (int j = 0; j < n_tokens; ++j) {
  83. for (int i = 0; i < n_kv; ++i) {
  84. data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
  85. }
  86. }
  87. }
  88. }
  89. }
  90. void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
  91. if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
  92. //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
  93. if (!out_ids) {
  94. LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
  95. } else {
  96. const int64_t n_tokens = ubatch->n_tokens;
  97. GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
  98. int32_t * data = (int32_t *) out_ids->data;
  99. if (n_outputs == n_tokens) {
  100. for (int i = 0; i < n_tokens; ++i) {
  101. data[i] = i;
  102. }
  103. } else if (ubatch->output) {
  104. int32_t n_outputs = 0;
  105. for (int i = 0; i < n_tokens; ++i) {
  106. if (ubatch->output[i]) {
  107. data[n_outputs++] = i;
  108. }
  109. }
  110. // the graph needs to have been passed the correct number of outputs
  111. GGML_ASSERT(this->n_outputs == n_outputs); // the local counter shadows the member, hence this->
  112. } else if (n_outputs == 1) {
  113. // only keep last output
  114. data[0] = n_tokens - 1;
  115. } else {
  116. GGML_ASSERT(n_outputs == 0);
  117. }
  118. }
  119. }
  120. }
  121. void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
  122. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
  123. const int64_t n_tokens = ubatch->n_tokens;
  124. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  125. const int64_t n_seqs = ubatch->n_seqs;
  126. GGML_ASSERT(mean);
  127. GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
  128. float * data = (float *) mean->data;
  129. memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
  130. std::vector<uint64_t> sum(n_tokens, 0);
  131. for (int s = 0; s < n_seqs; ++s) {
  132. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  133. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  134. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
  135. sum[seq_id] += ubatch->n_seq_tokens;
  136. }
  137. std::vector<float> div(n_tokens, 0.0f);
  138. for (int i = 0; i < n_tokens; ++i) {
  139. const uint64_t s = sum[i];
  140. if (s > 0) {
  141. div[i] = 1.0f/float(s);
  142. }
  143. }
  144. for (int s = 0; s < n_seqs; ++s) {
  145. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  146. for (int i = 0; i < n_seq_tokens; ++i) {
  147. data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
  148. }
  149. }
  150. }
  151. }
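// Sketch of the resulting tensor (assuming one ubatch with two sequences): `mean` is an
// [n_tokens x n_tokens] matrix where the row indexed by seq_id holds 1/len(seq) in the columns of
// that sequence's tokens and 0 elsewhere. E.g. seq 0 with 3 tokens (columns 0..2) and seq 1 with
// 2 tokens (columns 3..4) give rows [1/3 1/3 1/3 0 0] and [0 0 0 1/2 1/2]. build_pooling()
// multiplies the token embeddings by this matrix, so column seq_id of the result is that sequence's
// mean embedding.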
  152. void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
  153. if (cparams.embeddings && (
  154. cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
  155. cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
  156. const int64_t n_tokens = ubatch->n_tokens;
  157. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  158. const int64_t n_seqs = ubatch->n_seqs;
  159. GGML_ASSERT(cls);
  160. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  161. uint32_t * data = (uint32_t *) cls->data;
  162. memset(cls->data, 0, n_tokens * ggml_element_size(cls));
  163. for (int s = 0; s < n_seqs; ++s) {
  164. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  165. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  166. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
  167. for (int i = 0; i < n_seq_tokens; ++i) {
  168. const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
  169. if (pos == 0) {
  170. data[seq_id] = s*n_seq_tokens + i;
  171. }
  172. }
  173. }
  174. }
  175. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
  176. const int64_t n_tokens = ubatch->n_tokens;
  177. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  178. const int64_t n_seqs = ubatch->n_seqs;
  179. GGML_ASSERT(cls);
  180. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  181. uint32_t * data = (uint32_t *) cls->data;
  182. memset(cls->data, 0, n_tokens * ggml_element_size(cls));
  183. std::vector<int> last_pos(n_tokens, -1);
  184. std::vector<int> last_row(n_tokens, -1);
  185. for (int s = 0; s < n_seqs; ++s) {
  186. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  187. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  188. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
  189. for (int i = 0; i < n_seq_tokens; ++i) {
  190. const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
  191. if (pos >= last_pos[seq_id]) {
  192. last_pos[seq_id] = pos;
  193. last_row[seq_id] = s*n_seq_tokens + i;
  194. }
  195. }
  196. }
  197. for (int i = 0; i < n_tokens; ++i) {
  198. if (last_row[i] >= 0) {
  199. data[i] = last_row[i];
  200. }
  201. }
  202. }
  203. }
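// In short: for CLS/RANK pooling, cls[seq_id] ends up holding the row index of the token with
// pos == 0 of that sequence (typically the BOS/CLS token); for LAST pooling it holds the row of the
// token with the greatest position. build_pooling() later selects exactly those rows with
// ggml_get_rows().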
  204. void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
  205. GGML_UNUSED(ubatch);
  206. const int64_t n_kv = kv_self->n;
  207. if (s_copy) {
  208. GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
  209. int32_t * data = (int32_t *) s_copy->data;
  210. // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
  211. for (uint32_t i = 0; i < n_kv; ++i) {
  212. const uint32_t cell_id = i + kv_self->head;
  213. //////////////////////////////////////////////
  214. // TODO: this should not mutate the KV cache !
  215. llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
  216. // prevent out-of-bound sources
  217. if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
  218. kv_cell.src = cell_id;
  219. }
  220. data[i] = kv_cell.src;
  221. // TODO: do not mutate the KV cache
  222. // ensure copy only happens once
  223. if (kv_cell.src != (int32_t) cell_id) {
  224. kv_cell.src = cell_id;
  225. }
  226. }
  227. }
  228. }
  229. void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
  230. GGML_UNUSED(ubatch);
  231. const int64_t n_kv = kv_self->n;
  232. if (s_mask) {
  233. GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
  234. float * data = (float *) s_mask->data;
  235. // clear unused states
  236. for (int i = 0; i < n_kv; ++i) {
  237. const uint32_t cell_id = i + kv_self->head;
  238. //////////////////////////////////////////////
  239. // TODO: this should not mutate the KV cache !
  240. llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
  241. data[i] = (float) (kv_cell.src >= 0);
  242. // only clear once
  243. if (kv_cell.src < 0) {
  244. kv_cell.src = cell_id;
  245. }
  246. }
  247. }
  248. }
  249. void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  250. GGML_UNUSED(ubatch);
  251. if (cross_embd && !cross->v_embd.empty()) {
  252. assert(cross_embd->type == GGML_TYPE_F32);
  253. ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
  254. }
  255. }
  256. void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  257. if (kq_mask) {
  258. if (cparams.causal_attn) {
  259. const int64_t n_kv = ubatch->n_tokens;
  260. const int64_t n_tokens = ubatch->n_tokens;
  261. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  262. const int64_t n_seqs = ubatch->n_seqs;
  263. GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
  264. float * data = (float *) kq_mask->data;
  265. for (int h = 0; h < 1; ++h) {
  266. for (int s1 = 0; s1 < n_seqs; ++s1) {
  267. const llama_seq_id seq_id = ubatch->seq_id[s1][0];
  268. for (int j = 0; j < n_seq_tokens; ++j) {
  269. const int32_t tj = s1*n_seq_tokens + j;
  270. for (int s0 = 0; s0 < n_seqs; ++s0) {
  271. for (int i = 0; i < n_seq_tokens; ++i) {
  272. const int32_t ti = s0*n_seq_tokens + i;
  273. float f = -INFINITY;
  274. for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
  275. if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
  276. if (hparams.use_alibi) {
  277. f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
  278. } else {
  279. f = 0.0f;
  280. }
  281. break;
  282. }
  283. }
  284. data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
  285. }
  286. }
  287. }
  288. }
  289. }
  290. } else {
  291. const int64_t n_tokens = ubatch->n_tokens;
  292. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  293. const int64_t n_seqs = ubatch->n_seqs;
  294. const int64_t n_stride = ubatch->n_tokens;
  295. GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
  296. float * data = (float *) kq_mask->data;
  297. for (int h = 0; h < 1; ++h) {
  298. for (int s1 = 0; s1 < n_seqs; ++s1) {
  299. const llama_seq_id seq_id = ubatch->seq_id[s1][0];
  300. for (int j = 0; j < n_seq_tokens; ++j) {
  301. const int32_t tj = s1*n_seq_tokens + j;
  302. for (int s0 = 0; s0 < n_seqs; ++s0) {
  303. for (int i = 0; i < n_seq_tokens; ++i) {
  304. const int32_t ti = s0*n_seq_tokens + i;
  305. float f = -INFINITY;
  306. for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
  307. if (ubatch->seq_id[s0][s] == seq_id) {
  308. if (hparams.use_alibi) {
  309. f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
  310. } else {
  311. f = 0.0f;
  312. }
  313. break;
  314. }
  315. }
  316. data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
  317. }
  318. }
  319. for (int i = n_tokens; i < n_stride; ++i) {
  320. data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
  321. }
  322. }
  323. }
  324. }
  325. }
  326. }
  327. }
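// Example mask (no KV cache, causal, a single sequence with positions 5, 6, 7 in the batch):
// row j is the query token, column i the key token, and data[j*n_kv + i] is added to the logits:
//      key 5  key 6  key 7
//  q5 [  0    -inf   -inf ]
//  q6 [  0     0     -inf ]
//  q7 [  0     0      0   ]
// With hparams.use_alibi the admissible entries hold -|pos_i - pos_j| instead of 0, which is then
// scaled per head inside ggml_soft_max_ext() via f_max_alibi_bias.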
  328. void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
  329. if (self_kq_mask || self_kq_mask_swa) {
  330. const int64_t n_kv = kv_self->n;
  331. const int64_t n_tokens = ubatch->n_tokens;
  332. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  333. const int64_t n_seqs = ubatch->n_seqs;
  334. float * data = nullptr;
  335. float * data_swa = nullptr;
  336. if (self_kq_mask) {
  337. GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
  338. data = (float *) self_kq_mask->data;
  339. }
  340. if (self_kq_mask_swa) {
  341. GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
  342. data_swa = (float *) self_kq_mask_swa->data;
  343. }
  344. // Use only the previous KV cells of the correct sequence for each token of the ubatch.
  345. // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
  346. // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
  347. // Causal mask:
  348. // xxx-------
  349. // xxxx------
  350. // xxxxx-----
  351. // Non-causal mask:
  352. // xxxxx-----
  353. // xxxxx-----
  354. // xxxxx-----
  355. // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
  356. for (int h = 0; h < 1; ++h) {
  357. for (int s = 0; s < n_seqs; ++s) {
  358. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  359. for (int j = 0; j < n_seq_tokens; ++j) {
  360. const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
  361. for (int i = 0; i < n_kv; ++i) {
  362. float f;
  363. // mask the token if:
  364. if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
  365. || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
  366. ) {
  367. f = -INFINITY;
  368. } else {
  369. if (hparams.use_alibi) {
  370. f = -std::abs(kv_self->cells[i].pos - pos);
  371. } else {
  372. f = 0.0f;
  373. }
  374. }
  375. if (data) {
  376. data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
  377. }
  378. // may need to cut off old tokens for sliding window
  379. // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
  380. if (data_swa) {
  381. if (hparams.n_attn_chunk) {
  382. llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
  383. if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
  384. f = -INFINITY;
  385. }
  386. } else {
  387. if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
  388. f = -INFINITY;
  389. }
  390. }
  391. data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
  392. }
  393. }
  394. }
  395. }
  396. // mask padded tokens
  397. if (data) {
  398. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  399. for (int j = 0; j < n_kv; ++j) {
  400. data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
  401. }
  402. }
  403. }
  404. // mask padded tokens
  405. if (data_swa) {
  406. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  407. for (int j = 0; j < n_kv; ++j) {
  408. data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
  409. }
  410. }
  411. }
  412. }
  413. }
  414. }
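// Illustration of the extra SWA/chunked masking (values assumed): with hparams.n_swa = 4, a query at
// pos 10 may only attend to cache cells with pos 7..10, because cells with pos - cell_pos >= 4 are
// masked in data_swa. With chunked attention (e.g. hparams.n_attn_chunk = 8192, Llama 4-style), a
// query at pos 10000 has pos_chunk_start = 8192, so all cache cells with pos < 8192 are masked.
// The regular mask in `data` is unaffected by either rule.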
  415. void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  416. if (cross_kq_mask) {
  417. const int64_t n_enc = cross_kq_mask->ne[0];
  418. const int64_t n_tokens = ubatch->n_tokens;
  419. GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
  420. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  421. float * data = (float *) cross_kq_mask->data;
  422. for (int h = 0; h < 1; ++h) {
  423. for (int j = 0; j < n_tokens; ++j) {
  424. for (int i = 0; i < n_enc; ++i) {
  425. float f = -INFINITY;
  426. for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
  427. const llama_seq_id seq_id = ubatch->seq_id[j][s];
  428. if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
  429. f = 0.0f;
  430. }
  431. }
  432. data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
  433. }
  434. }
  435. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  436. for (int j = 0; j < n_enc; ++j) {
  437. data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
  438. }
  439. }
  440. }
  441. }
  442. }
  443. //
  444. // llm_graph_context
  445. //
  446. llm_graph_context::llm_graph_context(const llm_graph_params & params) :
  447. arch (params.arch),
  448. hparams (params.hparams),
  449. cparams (params.cparams),
  450. ubatch (params.ubatch),
  451. n_embd (hparams.n_embd),
  452. n_layer (hparams.n_layer),
  453. n_rot (hparams.n_rot),
  454. n_ctx (cparams.n_ctx),
  455. n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
  456. n_head (hparams.n_head()),
  457. n_head_kv (hparams.n_head_kv()),
  458. n_embd_head_k (hparams.n_embd_head_k),
  459. n_embd_k_gqa (hparams.n_embd_k_gqa()),
  460. n_embd_head_v (hparams.n_embd_head_v),
  461. n_embd_v_gqa (hparams.n_embd_v_gqa()),
  462. n_expert (hparams.n_expert),
  463. n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
  464. freq_base (cparams.rope_freq_base),
  465. freq_scale (cparams.rope_freq_scale),
  466. ext_factor (cparams.yarn_ext_factor),
  467. attn_factor (cparams.yarn_attn_factor),
  468. beta_fast (cparams.yarn_beta_fast),
  469. beta_slow (cparams.yarn_beta_slow),
  470. norm_eps (hparams.f_norm_eps),
  471. norm_rms_eps (hparams.f_norm_rms_eps),
  472. n_tokens (ubatch.n_tokens),
  473. n_outputs (params.n_outputs),
  474. n_ctx_orig (cparams.n_ctx_orig_yarn),
  475. pooling_type (cparams.pooling_type),
  476. rope_type (hparams.rope_type),
  477. ctx0 (params.ctx),
  478. sched (params.sched),
  479. backend_cpu (params.backend_cpu),
  480. cvec (params.cvec),
  481. loras (params.loras),
  482. memory (params.memory),
  483. cross (params.cross),
  484. cb_func (params.cb),
  485. res (std::make_unique<llm_graph_result>()) {
  486. }
  487. int64_t llm_graph_context::n_pos_per_token() const {
  488. return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
  489. }
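// Qwen2-VL uses M-RoPE, where each token carries several position components (temporal/height/width
// sections) instead of a single scalar, so its pos input is laid out with 4 values per token; all
// other architectures use a single position per token.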
  490. void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
  491. if (cb_func) {
  492. cb_func(ubatch, cur, name, il);
  493. }
  494. }
  495. ggml_tensor * llm_graph_context::build_cvec(
  496. ggml_tensor * cur,
  497. int il) const {
  498. return cvec->apply_to(ctx0, cur, il);
  499. }
  500. ggml_tensor * llm_graph_context::build_lora_mm(
  501. ggml_tensor * w,
  502. ggml_tensor * cur) const {
  503. ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
  504. for (const auto & lora : *loras) {
  505. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  506. if (lw == nullptr) {
  507. continue;
  508. }
  509. const float adapter_scale = lora.second;
  510. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  511. ggml_tensor * ab_cur = ggml_mul_mat(
  512. ctx0, lw->b,
  513. ggml_mul_mat(ctx0, lw->a, cur)
  514. );
  515. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  516. res = ggml_add(ctx0, res, ab_cur);
  517. }
  518. return res;
  519. }
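// LoRA math being applied here: res = W*x + s * B*(A*x), where A and B are the low-rank adapter
// matrices and s is the adapter scale returned by get_weight()->get_scale() (presumably
// adapter_scale * alpha / rank, mirroring the explicit computation in build_lora_mm_id() below).
// Computing A*x first keeps the intermediate at the adapter rank instead of materializing the full
// delta W = B*A.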
  520. ggml_tensor * llm_graph_context::build_lora_mm_id(
  521. ggml_tensor * w, // ggml_tensor * as
  522. ggml_tensor * cur, // ggml_tensor * b
  523. ggml_tensor * ids) const {
  524. ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
  525. for (const auto & lora : *loras) {
  526. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  527. if (lw == nullptr) {
  528. continue;
  529. }
  530. const float alpha = lora.first->alpha;
  531. const float rank = (float) lw->b->ne[0];
  532. const float scale = alpha ? lora.second * alpha / rank : lora.second;
  533. ggml_tensor * ab_cur = ggml_mul_mat_id(
  534. ctx0, lw->b,
  535. ggml_mul_mat_id(ctx0, lw->a, cur, ids),
  536. ids
  537. );
  538. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  539. res = ggml_add(ctx0, res, ab_cur);
  540. }
  541. return res;
  542. }
  543. ggml_tensor * llm_graph_context::build_norm(
  544. ggml_tensor * cur,
  545. ggml_tensor * mw,
  546. ggml_tensor * mb,
  547. llm_norm_type type,
  548. int il) const {
  549. switch (type) {
  550. case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
  551. case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
  552. case LLM_NORM_GROUP:
  553. {
  554. cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
  555. cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
  556. cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
  557. } break;
  558. }
  559. if (mw || mb) {
  560. cb(cur, "norm", il);
  561. }
  562. if (mw) {
  563. cur = ggml_mul(ctx0, cur, mw);
  564. if (mb) {
  565. cb(cur, "norm_w", il);
  566. }
  567. }
  568. if (mb) {
  569. cur = ggml_add(ctx0, cur, mb);
  570. }
  571. return cur;
  572. }
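// The three normalization variants, with mw/mb applied afterwards as the learned scale/offset:
//   LLM_NORM        y = (x - mean(x)) / sqrt(var(x) + f_norm_eps)      (LayerNorm)
//   LLM_NORM_RMS    y = x / sqrt(mean(x^2) + f_norm_rms_eps)           (RMSNorm)
//   LLM_NORM_GROUP  reshape to [ne0, 1, rows], group-normalize over n_norm_groups, reshape back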
  573. ggml_tensor * llm_graph_context::build_ffn(
  574. ggml_tensor * cur,
  575. ggml_tensor * up,
  576. ggml_tensor * up_b,
  577. ggml_tensor * up_s,
  578. ggml_tensor * gate,
  579. ggml_tensor * gate_b,
  580. ggml_tensor * gate_s,
  581. ggml_tensor * down,
  582. ggml_tensor * down_b,
  583. ggml_tensor * down_s,
  584. ggml_tensor * act_scales,
  585. llm_ffn_op_type type_op,
  586. llm_ffn_gate_type type_gate,
  587. int il) const {
  588. ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
  589. cb(tmp, "ffn_up", il);
  590. if (up_b) {
  591. tmp = ggml_add(ctx0, tmp, up_b);
  592. cb(tmp, "ffn_up_b", il);
  593. }
  594. if (up_s) {
  595. tmp = ggml_mul(ctx0, tmp, up_s);
  596. cb(tmp, "ffn_up_s", il);
  597. }
  598. if (gate) {
  599. switch (type_gate) {
  600. case LLM_FFN_SEQ:
  601. {
  602. cur = build_lora_mm(gate, tmp);
  603. cb(cur, "ffn_gate", il);
  604. } break;
  605. case LLM_FFN_PAR:
  606. {
  607. cur = build_lora_mm(gate, cur);
  608. cb(cur, "ffn_gate", il);
  609. } break;
  610. }
  611. if (gate_b) {
  612. cur = ggml_add(ctx0, cur, gate_b);
  613. cb(cur, "ffn_gate_b", il);
  614. }
  615. if (gate_s) {
  616. cur = ggml_mul(ctx0, cur, gate_s);
  617. cb(cur, "ffn_gate_s", il);
  618. }
  619. } else {
  620. cur = tmp;
  621. }
  622. switch (type_op) {
  623. case LLM_FFN_SILU:
  624. {
  625. cur = ggml_silu(ctx0, cur);
  626. cb(cur, "ffn_silu", il);
  627. } break;
  628. case LLM_FFN_GELU:
  629. {
  630. cur = ggml_gelu(ctx0, cur);
  631. cb(cur, "ffn_gelu", il);
  632. if (act_scales != NULL) {
  633. cur = ggml_div(ctx0, cur, act_scales);
  634. cb(cur, "ffn_act", il);
  635. }
  636. } break;
  637. case LLM_FFN_RELU:
  638. {
  639. cur = ggml_relu(ctx0, cur);
  640. cb(cur, "ffn_relu", il);
  641. } break;
  642. case LLM_FFN_RELU_SQR:
  643. {
  644. cur = ggml_relu(ctx0, cur);
  645. cb(cur, "ffn_relu", il);
  646. cur = ggml_sqr(ctx0, cur);
  647. cb(cur, "ffn_sqr(relu)", il);
  648. } break;
  649. case LLM_FFN_SWIGLU:
  650. {
  651. // Project to 4h. If using SwiGLU, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
  652. int64_t split_point = cur->ne[0] / 2;
  653. ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
  654. ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
  655. x0 = ggml_silu(ctx0, x0);
  656. cb(cur, "ffn_silu", il);
  657. cur = ggml_mul(ctx0, x0, x1);
  658. cb(cur, "ffn_mul", il);
  659. } break;
  660. }
  661. if (type_gate == LLM_FFN_PAR) {
  662. cur = ggml_mul(ctx0, cur, tmp);
  663. cb(cur, "ffn_gate_par", il);
  664. }
  665. if (down) {
  666. cur = build_lora_mm(down, cur);
  667. }
  668. if (down_b) {
  669. cb(cur, "ffn_down", il);
  670. }
  671. if (down_b) {
  672. cur = ggml_add(ctx0, cur, down_b);
  673. }
  674. if (down_s) {
  675. cur = ggml_mul(ctx0, cur, down_s);
  676. cb(cur, "ffn_down_s", il);
  677. }
  678. return cur;
  679. }
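// Typical call pattern (a sketch, not tied to a specific model): a LLaMA-style block passes up, gate
// and down with LLM_FFN_SILU / LLM_FFN_PAR, which computes down( silu(gate @ x) * (up @ x) ) with an
// element-wise product, i.e. the standard SwiGLU MLP (plus the optional biases/scales). LLM_FFN_SEQ
// instead feeds the up output through the gate, and LLM_FFN_SWIGLU splits a single fused up
// projection in half and computes silu(x0) * x1.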
  680. ggml_tensor * llm_graph_context::build_moe_ffn(
  681. ggml_tensor * cur,
  682. ggml_tensor * gate_inp,
  683. ggml_tensor * up_exps,
  684. ggml_tensor * gate_exps,
  685. ggml_tensor * down_exps,
  686. ggml_tensor * exp_probs_b,
  687. int64_t n_expert,
  688. int64_t n_expert_used,
  689. llm_ffn_op_type type_op,
  690. bool norm_w,
  691. bool scale_w,
  692. float w_scale,
  693. llama_expert_gating_func_type gating_op,
  694. int il) const {
  695. const int64_t n_embd = cur->ne[0];
  696. const int64_t n_tokens = cur->ne[1];
  697. const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
  698. ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
  699. cb(logits, "ffn_moe_logits", il);
  700. ggml_tensor * probs = nullptr;
  701. switch (gating_op) {
  702. case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
  703. {
  704. probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
  705. } break;
  706. case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
  707. {
  708. probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
  709. } break;
  710. default:
  711. GGML_ABORT("fatal error");
  712. }
  713. cb(probs, "ffn_moe_probs", il);
  714. // add experts selection bias - introduced in DeepSeek V3
  715. // leave probs unbiased as it's later used to get expert weights
  716. ggml_tensor * selection_probs = probs;
  717. if (exp_probs_b != nullptr) {
  718. selection_probs = ggml_add(ctx0, probs, exp_probs_b);
  719. cb(selection_probs, "ffn_moe_probs_biased", il);
  720. }
  721. // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
  722. // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
  723. if (arch == LLM_ARCH_LLAMA4) {
  724. selection_probs = logits;
  725. }
  726. // select experts
  727. ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
  728. cb(selected_experts->src[0], "ffn_moe_argsort", il);
  729. cb(selected_experts, "ffn_moe_topk", il);
  730. ggml_tensor * weights = ggml_get_rows(ctx0,
  731. ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
  732. cb(weights, "ffn_moe_weights", il);
  733. if (norm_w) {
  734. weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
  735. ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
  736. cb(weights_sum, "ffn_moe_weights_sum", il);
  737. weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
  738. cb(weights, "ffn_moe_weights_norm", il);
  739. weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
  740. }
  741. if (scale_w) {
  742. weights = ggml_scale(ctx0, weights, w_scale);
  743. cb(weights, "ffn_moe_weights_scaled", il);
  744. }
  745. cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
  746. if (weight_before_ffn) {
  747. // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
  748. ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
  749. repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
  750. cur = ggml_mul(ctx0, repeated, weights);
  751. cb(cur, "ffn_moe_weighted", il);
  752. }
  753. ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  754. cb(up, "ffn_moe_up", il);
  755. ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  756. cb(gate, "ffn_moe_gate", il);
  757. switch (type_op) {
  758. case LLM_FFN_SILU:
  759. {
  760. gate = ggml_silu(ctx0, gate);
  761. cb(gate, "ffn_moe_silu", il);
  762. } break;
  763. case LLM_FFN_GELU:
  764. {
  765. gate = ggml_gelu(ctx0, gate);
  766. cb(gate, "ffn_moe_gelu", il);
  767. } break;
  768. default:
  769. GGML_ABORT("fatal error");
  770. }
  771. ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
  772. cb(par, "ffn_moe_gate_par", il);
  773. ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
  774. cb(experts, "ffn_moe_down", il);
  775. if (!weight_before_ffn) {
  776. experts = ggml_mul(ctx0, experts, weights);
  777. cb(cur, "ffn_moe_weighted", il);
  778. }
  779. // aggregate experts
  780. ggml_tensor * moe_out = nullptr;
  781. for (int i = 0; i < n_expert_used; ++i) {
  782. ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
  783. experts->nb[2], i*experts->nb[1]);
  784. if (i == 0) {
  785. moe_out = cur_expert;
  786. } else {
  787. moe_out = ggml_add(ctx0, moe_out, cur_expert);
  788. }
  789. }
  790. if (n_expert_used == 1) {
  791. // avoid returning a non-contiguous tensor
  792. moe_out = ggml_cont(ctx0, moe_out);
  793. }
  794. cb(moe_out, "ffn_moe_out", il);
  795. return moe_out;
  796. }
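// Shape walk-through for one layer: logits/probs are [n_expert, n_tokens]; ggml_top_k() picks
// selected_experts [n_expert_used, n_tokens]; the weights gathered from probs are
// [1, n_expert_used, n_tokens], optionally renormalized to sum to 1 per token and/or scaled by
// w_scale. The activations are reshaped to [n_embd, 1, n_tokens] (for LLAMA4 also repeated and
// pre-weighted across experts), pushed through the per-expert up/gate/down projections via
// build_lora_mm_id() to [n_embd, n_expert_used, n_tokens], weighted afterwards in the default case,
// and summed over the expert dimension back to [n_embd, n_tokens].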
  797. // input embeddings with optional lora
  798. ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
  799. const int64_t n_embd = hparams.n_embd;
  800. auto inp = std::make_unique<llm_graph_input_embd>();
  801. ggml_tensor * cur = nullptr;
  802. if (ubatch.token) {
  803. inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
  804. //cb(inp->tokens, "inp_tokens", -1);
  805. ggml_set_input(inp->tokens);
  806. cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
  807. // apply lora for embedding tokens if needed
  808. for (const auto & lora : *loras) {
  809. llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
  810. if (lw == nullptr) {
  811. continue;
  812. }
  813. const float adapter_scale = lora.second;
  814. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  815. ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
  816. ctx0, lw->b, // non-transposed lora_b
  817. ggml_get_rows(ctx0, lw->a, inp->tokens)
  818. ), scale);
  819. cur = ggml_add(ctx0, cur, inpL_delta);
  820. }
  821. } else {
  822. inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
  823. ggml_set_input(inp->embd);
  824. cur = inp->embd;
  825. }
  826. // For Granite architecture
  827. if (hparams.f_embedding_scale != 0.0f) {
  828. cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
  829. }
  830. cb(cur, "inp_embd", -1);
  831. res->add_input(std::move(inp));
  832. return cur;
  833. }
  834. ggml_tensor * llm_graph_context::build_inp_pos() const {
  835. auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
  836. auto & cur = inp->pos;
  837. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
  838. ggml_set_input(cur);
  839. res->add_input(std::move(inp));
  840. return cur;
  841. }
  842. ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
  843. auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
  844. auto & cur = inp->attn_scale;
  845. cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
  846. ggml_set_input(cur);
  847. res->add_input(std::move(inp));
  848. return cur;
  849. }
  850. ggml_tensor * llm_graph_context::build_inp_out_ids() const {
  851. auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
  852. auto & cur = inp->out_ids;
  853. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
  854. ggml_set_input(cur);
  855. res->add_input(std::move(inp));
  856. return cur;
  857. }
  858. ggml_tensor * llm_graph_context::build_inp_mean() const {
  859. auto inp = std::make_unique<llm_graph_input_mean>(cparams);
  860. auto & cur = inp->mean;
  861. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
  862. ggml_set_input(cur);
  863. res->add_input(std::move(inp));
  864. return cur;
  865. }
  866. ggml_tensor * llm_graph_context::build_inp_cls() const {
  867. auto inp = std::make_unique<llm_graph_input_cls>(cparams);
  868. auto & cur = inp->cls;
  869. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
  870. ggml_set_input(cur);
  871. res->add_input(std::move(inp));
  872. return cur;
  873. }
  874. ggml_tensor * llm_graph_context::build_inp_s_copy() const {
  875. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  876. auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
  877. const auto n_kv = kv_self->n;
  878. auto & cur = inp->s_copy;
  879. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
  880. ggml_set_input(cur);
  881. res->add_input(std::move(inp));
  882. return cur;
  883. }
  884. ggml_tensor * llm_graph_context::build_inp_s_mask() const {
  885. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  886. auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
  887. const auto n_kv = kv_self->n;
  888. auto & cur = inp->s_mask;
  889. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
  890. ggml_set_input(cur);
  891. res->add_input(std::move(inp));
  892. return cur;
  893. }
  894. ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
  895. auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
  896. auto & cur = inp->cross_embd;
  897. // if we have the output embeddings from the encoder, use them directly
  898. // TODO: needs more work to be correct, for now just use the tensor shape
  899. //if (cross->t_embd) {
  900. // cur = ggml_view_tensor(ctx0, cross->t_embd);
  901. // return cur;
  902. //}
  903. const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
  904. const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  905. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
  906. ggml_set_input(cur);
  907. res->add_input(std::move(inp));
  908. return cur;
  909. }
  910. ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
  911. auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
  912. auto & cur = inp->pos_bucket;
  913. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
  914. ggml_set_input(cur);
  915. res->add_input(std::move(inp));
  916. return cur;
  917. }
  918. ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
  919. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  920. auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
  921. const auto n_kv = kv_self->n;
  922. auto & cur = inp->pos_bucket;
  923. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
  924. ggml_set_input(cur);
  925. res->add_input(std::move(inp));
  926. return cur;
  927. }
  928. ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
  929. ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
  930. cb(pos_bucket_1d, "pos_bucket_1d", -1);
  931. ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
  932. pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
  933. pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
  934. pos_bias = ggml_cont (ctx0, pos_bias);
  935. cb(pos_bias, "pos_bias", -1);
  936. return pos_bias;
  937. }
  938. ggml_tensor * llm_graph_context::build_attn_mha(
  939. ggml_cgraph * gf,
  940. ggml_tensor * q,
  941. ggml_tensor * k,
  942. ggml_tensor * v,
  943. ggml_tensor * kq_b,
  944. ggml_tensor * kq_mask,
  945. bool v_trans,
  946. float kq_scale) const {
  947. //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  948. //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  949. //const int64_t n_head = hparams.n_head(il);
  950. //const int64_t n_head_kv = hparams.n_head_kv(il);
  951. //const auto & n_embd_head_k = hparams.n_embd_head_k;
  952. //const auto & n_embd_head_v = hparams.n_embd_head_v;
  953. const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
  954. const auto n_tokens = q->ne[1];
  955. const auto n_head = q->ne[2];
  956. const auto n_kv = k->ne[1];
  957. ggml_tensor * cur;
  958. // TODO: replace hardcoded padding with ggml-provided padding
  959. if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
  960. GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
  961. if (v_trans) {
  962. v = ggml_transpose(ctx0, v);
  963. }
  964. cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
  965. hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
  966. ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
  967. cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
  968. } else {
  969. ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
  970. // note: this op tends to require high floating point range
  971. // while for some models F16 is enough, for others it is not, so we default to F32 here
  972. ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
  973. if (arch == LLM_ARCH_GROK) {
  974. // need to do the following:
  975. // multiply by attn_output_multiplier of 0.08838834764831845
  976. // and then:
  977. // kq = 30 * tanh(kq / 30)
  978. // before the softmax below
  979. kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
  980. kq = ggml_scale(ctx0, kq, 30);
  981. }
  982. if (hparams.attn_soft_cap) {
  983. kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
  984. kq = ggml_tanh (ctx0, kq);
  985. kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
  986. }
  987. if (kq_b) {
  988. kq = ggml_add(ctx0, kq, kq_b);
  989. }
  990. kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
  991. if (!v_trans) {
  992. // note: avoid this branch
  993. v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
  994. }
  995. ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
  996. ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
  997. cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
  998. if (!cparams.offload_kqv) {
  999. // all nodes between the KV store and the attention output are run on the CPU
  1000. ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
  1001. }
  1002. }
  1003. ggml_build_forward_expand(gf, cur);
  1004. return cur;
  1005. }
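// Both branches compute out = V * softmax(kq_scale * K^T Q + mask) per head, with kq_scale usually
// 1/sqrt(n_embd_head_k): the flash-attention path fuses everything in ggml_flash_attn_ext(), while
// the fallback materializes kq as [n_kv, n_tokens, n_head], optionally applies the Grok scaling,
// logit soft-capping and the kq_b bias, and then merges the heads back into
// [n_embd_head_v * n_head, n_tokens].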
  1006. llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
  1007. auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
  1008. // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
  1009. inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1010. //cb(inp_kq_mask, "KQ_mask", -1);
  1011. ggml_set_input(inp->kq_mask);
  1012. inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
  1013. return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
  1014. }
  1015. ggml_tensor * llm_graph_context::build_attn(
  1016. llm_graph_input_attn_no_cache * inp,
  1017. ggml_cgraph * gf,
  1018. ggml_tensor * wo,
  1019. ggml_tensor * wo_b,
  1020. ggml_tensor * q_cur,
  1021. ggml_tensor * k_cur,
  1022. ggml_tensor * v_cur,
  1023. ggml_tensor * kq_b,
  1024. float kq_scale,
  1025. int il) const {
  1026. GGML_UNUSED(n_tokens);
  1027. // these nodes are added to the graph together so that they are not reordered
  1028. // by doing so, the number of splits in the graph is reduced
  1029. ggml_build_forward_expand(gf, q_cur);
  1030. ggml_build_forward_expand(gf, k_cur);
  1031. ggml_build_forward_expand(gf, v_cur);
  1032. const auto & kq_mask = inp->get_kq_mask();
  1033. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1034. //cb(q, "q", il);
  1035. ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
  1036. //cb(k, "k", il);
  1037. ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
  1038. //cb(k, "v", il);
  1039. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
  1040. cb(cur, "kqv_out", il);
  1041. if (wo) {
  1042. cur = build_lora_mm(wo, cur);
  1043. }
  1044. if (wo_b) {
  1045. //cb(cur, "kqv_wo", il);
  1046. }
  1047. if (wo_b) {
  1048. cur = ggml_add(ctx0, cur, wo_b);
  1049. }
  1050. return cur;
  1051. }
  1052. llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
  1053. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1054. auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
  1055. const auto n_kv = kv_self->n;
  1056. inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1057. //cb(inp->self_kq_mask, "KQ_mask", -1);
  1058. ggml_set_input(inp->self_kq_mask);
  1059. inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
  1060. if (hparams.n_swa_pattern > 1) {
  1061. GGML_ASSERT(hparams.n_swa > 0);
  1062. inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1063. //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
  1064. ggml_set_input(inp->self_kq_mask_swa);
  1065. inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
  1066. }
  1067. return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
  1068. }
  1069. ggml_tensor * llm_graph_context::build_attn(
  1070. llm_graph_input_attn_kv_unified * inp,
  1071. ggml_cgraph * gf,
  1072. ggml_tensor * wo,
  1073. ggml_tensor * wo_b,
  1074. ggml_tensor * q_cur,
  1075. ggml_tensor * k_cur,
  1076. ggml_tensor * v_cur,
  1077. ggml_tensor * kq_b,
  1078. float kq_scale,
  1079. int il) const {
  1080. // these nodes are added to the graph together so that they are not reordered
  1081. // by doing so, the number of splits in the graph is reduced
  1082. ggml_build_forward_expand(gf, q_cur);
  1083. ggml_build_forward_expand(gf, k_cur);
  1084. ggml_build_forward_expand(gf, v_cur);
  1085. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1086. const auto & n_ctx = cparams.n_ctx;
  1087. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1088. const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1089. const auto n_tokens = q_cur->ne[2];
  1090. const bool v_trans = !cparams.flash_attn;
  1091. // store to KV cache
  1092. {
  1093. GGML_ASSERT(!kv_self->recurrent);
  1094. const auto kv_head = kv_self->head;
  1095. GGML_ASSERT(kv_self->size == n_ctx);
  1096. ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
  1097. //cb(k_cache_view, "k_cache_view", il);
  1098. // note: storing RoPE-ed version of K in the KV cache
  1099. ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
  1100. v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
  1101. ggml_tensor * v_cache_view = nullptr;
  1102. if (!v_trans) {
  1103. v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
  1104. } else {
  1105. // note: the V cache is transposed when not using flash attention
  1106. v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
  1107. ( n_ctx)*ggml_element_size(kv_self->v_l[il]),
  1108. (kv_head)*ggml_element_size(kv_self->v_l[il]));
  1109. v_cur = ggml_transpose(ctx0, v_cur);
  1110. }
  1111. //cb(v_cache_view, "v_cache_view", il);
  1112. ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
  1113. }
  1114. const bool is_swa = hparams.is_swa(il);
  1115. const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
  1116. const auto n_kv = kv_self->n;
  1117. const int64_t n_head_kv = hparams.n_head_kv(il);
  1118. const auto & n_embd_head_k = hparams.n_embd_head_k;
  1119. const auto & n_embd_head_v = hparams.n_embd_head_v;
  1120. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1121. //cb(q, "q", il);
  1122. ggml_tensor * k =
  1123. ggml_view_3d(ctx0, kv_self->k_l[il],
  1124. n_embd_head_k, n_kv, n_head_kv,
  1125. ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
  1126. ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
  1127. 0);
  1128. //cb(k, "k", il);
  1129. ggml_tensor * v = !v_trans ?
  1130. ggml_view_3d(ctx0, kv_self->v_l[il],
  1131. n_embd_head_v, n_kv, n_head_kv,
  1132. ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
  1133. ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
  1134. 0) :
  1135. ggml_view_3d(ctx0, kv_self->v_l[il],
  1136. n_kv, n_embd_head_v, n_head_kv,
  1137. ggml_element_size(kv_self->v_l[il])*n_ctx,
  1138. ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
  1139. 0);
  1140. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
  1141. cb(cur, "kqv_out", il);
  1142. if (wo) {
  1143. cur = build_lora_mm(wo, cur);
  1144. }
  1145. if (wo_b) {
  1146. //cb(cur, "kqv_wo", il);
  1147. }
  1148. if (wo_b) {
  1149. cur = ggml_add(ctx0, cur, wo_b);
  1150. }
  1151. return cur;
  1152. }
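// KV-cache layout note: the K cache is always viewed as [n_embd_head_k, n_kv, n_head_kv]. The V
// cache has two layouts: with flash attention it mirrors K ([n_embd_head_v, n_kv, n_head_kv]);
// without it V is stored transposed so each value dimension is a contiguous run of n_ctx cache
// slots, which is why the store above writes either a 1D slice at kv_head or a strided 2D view.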
  1153. llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
  1154. auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
  1155. const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  1156. inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1157. ggml_set_input(inp->cross_kq_mask);
  1158. inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
  1159. return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
  1160. }
  1161. ggml_tensor * llm_graph_context::build_attn(
  1162. llm_graph_input_attn_cross * inp,
  1163. ggml_cgraph * gf,
  1164. ggml_tensor * wo,
  1165. ggml_tensor * wo_b,
  1166. ggml_tensor * q_cur,
  1167. ggml_tensor * k_cur,
  1168. ggml_tensor * v_cur,
  1169. ggml_tensor * kq_b,
  1170. float kq_scale,
  1171. int il) const {
  1172. // these nodes are added to the graph together so that they are not reordered
  1173. // by doing so, the number of splits in the graph is reduced
  1174. ggml_build_forward_expand(gf, q_cur);
  1175. ggml_build_forward_expand(gf, k_cur);
  1176. ggml_build_forward_expand(gf, v_cur);
  1177. const auto & kq_mask = inp->get_kq_mask_cross();
  1178. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1179. //cb(q, "q", il);
  1180. ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
  1181. //cb(k, "k", il);
  1182. ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
  1183. //cb(k, "v", il);
  1184. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
  1185. cb(cur, "kqv_out", il);
  1186. if (wo) {
  1187. cur = build_lora_mm(wo, cur);
  1188. }
  1189. if (wo_b) {
  1190. //cb(cur, "kqv_wo", il);
  1191. }
  1192. if (wo_b) {
  1193. cur = ggml_add(ctx0, cur, wo_b);
  1194. }
  1195. return cur;
  1196. }
  1197. ggml_tensor * llm_graph_context::build_copy_mask_state(
  1198. ggml_cgraph * gf,
  1199. ggml_tensor * s,
  1200. ggml_tensor * state_copy,
  1201. ggml_tensor * state_mask,
  1202. int32_t n_state,
  1203. int32_t n_seqs) const {
  1204. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1205. const auto n_kv = kv_self->n;
  1206. const auto kv_head = kv_self->head;
  1207. ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
  1208. // copy states
  1209. // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
  1210. // this shrinks the tensor's ne[1] to n_kv
  1211. states = ggml_get_rows(ctx0, states, state_copy);
  1212. // clear states of sequences which are starting at the beginning of this batch
  1213. // FIXME: zero-out NANs?
  1214. states = ggml_mul(ctx0, states, state_mask);
  1215. // copy states which won't be changed further (between n_seqs and n_kv)
  1216. ggml_build_forward_expand(gf,
  1217. ggml_cpy(ctx0,
  1218. ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
  1219. ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
  1220. // the part of the states that will be used and modified
  1221. return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
  1222. }
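// Concrete example (assumed sizes): with n_kv = 4 and n_seqs = 2, state_copy gathers the 4 source
// cells into rows 0..3, state_mask zeroes the rows of sequences that start in this ubatch, rows 2..3
// are copied straight back to cells kv_head+2..kv_head+3, and the returned [n_state, 2] view covers
// rows 0..1, which the calling layer will update and store.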
  1223. ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
  1224. ggml_cgraph * gf,
  1225. ggml_tensor * state_copy,
  1226. ggml_tensor * state_mask,
  1227. const llama_ubatch & ubatch,
  1228. int il) const {
  1229. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1230. const auto token_shift_count = hparams.token_shift_count;
  1231. const int64_t n_seqs = ubatch.n_seqs;
  1232. ggml_tensor * token_shift_all = kv_self->k_l[il];
  1233. ggml_tensor * token_shift = build_copy_mask_state(
  1234. gf, token_shift_all, state_copy, state_mask,
  1235. hparams.n_embd_k_s(), n_seqs);
  1236. token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
  1237. return token_shift;
  1238. }
  1239. ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
  1240. ggml_tensor * token_shift,
  1241. const llama_ubatch & ubatch,
  1242. int il) const {
  1243. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1244. const auto token_shift_count = hparams.token_shift_count;
  1245. const auto n_embd = hparams.n_embd;
  1246. const int64_t n_seqs = ubatch.n_seqs;
  1247. const auto kv_head = kv_self->head;
  1248. return ggml_cpy(
  1249. ctx0,
  1250. ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
  1251. ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
  1252. );
  1253. }
  1254. void llm_graph_context::build_pooling(
  1255. ggml_cgraph * gf,
  1256. ggml_tensor * cls,
  1257. ggml_tensor * cls_b,
  1258. ggml_tensor * cls_out,
  1259. ggml_tensor * cls_out_b) const {
  1260. if (!cparams.embeddings) {
  1261. return;
  1262. }
  1263. ggml_tensor * inp = res->t_embd;
  1264. //// find result_norm tensor for input
  1265. //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
  1266. // inp = ggml_graph_node(gf, i);
  1267. // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
  1268. // break;
  1269. // }
  1270. // inp = nullptr;
  1271. //}
  1272. GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
  1273. ggml_tensor * cur;
  1274. switch (pooling_type) {
  1275. case LLAMA_POOLING_TYPE_NONE:
  1276. {
  1277. cur = inp;
  1278. } break;
  1279. case LLAMA_POOLING_TYPE_MEAN:
  1280. {
  1281. ggml_tensor * inp_mean = build_inp_mean();
  1282. cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
  1283. } break;
  1284. case LLAMA_POOLING_TYPE_CLS:
  1285. case LLAMA_POOLING_TYPE_LAST:
  1286. {
  1287. ggml_tensor * inp_cls = build_inp_cls();
  1288. cur = ggml_get_rows(ctx0, inp, inp_cls);
  1289. } break;
  1290. case LLAMA_POOLING_TYPE_RANK:
  1291. {
  1292. ggml_tensor * inp_cls = build_inp_cls();
  1293. inp = ggml_get_rows(ctx0, inp, inp_cls);
  1294. // classification head
  1295. // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
  1296. GGML_ASSERT(cls != nullptr);
  1297. GGML_ASSERT(cls_b != nullptr);
  1298. cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
  1299. cur = ggml_tanh(ctx0, cur);
  1300. // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
  1301. // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
  1302. if (cls_out) {
  1303. GGML_ASSERT(cls_out_b != nullptr);
  1304. cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
  1305. }
  1306. } break;
  1307. default:
  1308. {
  1309. GGML_ABORT("unknown pooling type");
  1310. }
  1311. }
  1312. cb(cur, "result_embd_pooled", -1);
  1313. res->t_embd_pooled = cur;
  1314. ggml_build_forward_expand(gf, cur);
  1315. }