llama-graph.cpp 57 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702
  1. #include "llama-graph.h"
  2. #include "llama-impl.h"
  3. #include "llama-batch.h"
  4. #include "llama-cparams.h"
  5. #include "llama-kv-cache.h"
  6. #include <cassert>
  7. #include <cmath>
  8. #include <cstring>
  9. static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
  10. // TODO move to hparams if a T5 variant appears that uses a different value
  11. const int64_t max_distance = 128;
  12. if (bidirectional) {
  13. n_buckets >>= 1;
  14. }
  15. const int64_t max_exact = n_buckets >> 1;
  16. int32_t relative_position = x - y;
  17. int32_t relative_bucket = 0;
  18. if (bidirectional) {
  19. relative_bucket += (relative_position > 0) * n_buckets;
  20. relative_position = abs(relative_position);
  21. } else {
  22. relative_position = -std::min<int32_t>(relative_position, 0);
  23. }
  24. int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
  25. relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
  26. relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
  27. return relative_bucket;
  28. }
  29. void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
  30. if (ubatch->token) {
  31. const int64_t n_tokens = ubatch->n_tokens;
  32. ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
  33. }
  34. if (ubatch->embd) {
  35. const int64_t n_embd = embd->ne[0];
  36. const int64_t n_tokens = ubatch->n_tokens;
  37. ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
  38. }
  39. }
  40. void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
  41. if (ubatch->pos && pos) {
  42. const int64_t n_tokens = ubatch->n_tokens;
  43. if (ubatch->token && n_pos_per_embd == 4) {
  44. // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
  45. // the 3 first dims are the same, and 4th dim is all 0
  46. std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
  47. // copy the first dimension
  48. for (int i = 0; i < n_tokens; ++i) {
  49. pos_data[ i] = ubatch->pos[i];
  50. pos_data[ n_tokens + i] = ubatch->pos[i];
  51. pos_data[2 * n_tokens + i] = ubatch->pos[i];
  52. pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
  53. }
  54. ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
  55. } else {
  56. ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
  57. }
  58. }
  59. }
  60. void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
  61. if (ubatch->pos && attn_scale) {
  62. const int64_t n_tokens = ubatch->n_tokens;
  63. std::vector<float> attn_scale_data(n_tokens, 0.0f);
  64. for (int i = 0; i < n_tokens; ++i) {
  65. const float pos = ubatch->pos[i];
  66. attn_scale_data[i] = std::log(
  67. std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
  68. ) * f_attn_temp_scale + 1.0;
  69. }
  70. ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
  71. }
  72. }
  73. void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
  74. if (pos_bucket) {
  75. const int64_t n_tokens = ubatch->n_tokens;
  76. GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
  77. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  78. int32_t * data = (int32_t *) pos_bucket->data;
  79. for (int h = 0; h < 1; ++h) {
  80. for (int j = 0; j < n_tokens; ++j) {
  81. for (int i = 0; i < n_tokens; ++i) {
  82. data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
  83. }
  84. }
  85. }
  86. }
  87. }
  88. void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
  89. if (pos_bucket) {
  90. const int64_t n_tokens = ubatch->n_tokens;
  91. GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
  92. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  93. int32_t * data = (int32_t *) pos_bucket->data;
  94. const int64_t n_kv = kv_self->n;
  95. for (int h = 0; h < 1; ++h) {
  96. for (int j = 0; j < n_tokens; ++j) {
  97. for (int i = 0; i < n_kv; ++i) {
  98. data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
  99. }
  100. }
  101. }
  102. }
  103. }
  104. void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
  105. if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
  106. //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
  107. if (!out_ids) {
  108. LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
  109. } else {
  110. const int64_t n_tokens = ubatch->n_tokens;
  111. GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
  112. int32_t * data = (int32_t *) out_ids->data;
  113. if (n_outputs == n_tokens) {
  114. for (int i = 0; i < n_tokens; ++i) {
  115. data[i] = i;
  116. }
  117. } else if (ubatch->output) {
  118. int32_t n_outputs = 0;
  119. for (int i = 0; i < n_tokens; ++i) {
  120. if (ubatch->output[i]) {
  121. data[n_outputs++] = i;
  122. }
  123. }
  124. // the graph needs to have been passed the correct number of outputs
  125. GGML_ASSERT(n_outputs == n_outputs);
  126. } else if (n_outputs == 1) {
  127. // only keep last output
  128. data[0] = n_tokens - 1;
  129. } else {
  130. GGML_ASSERT(n_outputs == 0);
  131. }
  132. }
  133. }
  134. }
  135. void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
  136. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
  137. const int64_t n_tokens = ubatch->n_tokens;
  138. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  139. const int64_t n_seqs = ubatch->n_seqs;
  140. GGML_ASSERT(mean);
  141. GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
  142. float * data = (float *) mean->data;
  143. memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
  144. std::vector<uint64_t> sum(n_tokens, 0);
  145. for (int s = 0; s < n_seqs; ++s) {
  146. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  147. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  148. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
  149. sum[seq_id] += ubatch->n_seq_tokens;
  150. }
  151. std::vector<float> div(n_tokens, 0.0f);
  152. for (int i = 0; i < n_tokens; ++i) {
  153. const uint64_t s = sum[i];
  154. if (s > 0) {
  155. div[i] = 1.0f/float(s);
  156. }
  157. }
  158. for (int s = 0; s < n_seqs; ++s) {
  159. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  160. for (int i = 0; i < n_seq_tokens; ++i) {
  161. data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
  162. }
  163. }
  164. }
  165. }
  166. void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
  167. if (cparams.embeddings && (
  168. cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
  169. cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
  170. const int64_t n_tokens = ubatch->n_tokens;
  171. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  172. const int64_t n_seqs = ubatch->n_seqs;
  173. GGML_ASSERT(cls);
  174. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  175. uint32_t * data = (uint32_t *) cls->data;
  176. memset(cls->data, 0, n_tokens * ggml_element_size(cls));
  177. for (int s = 0; s < n_seqs; ++s) {
  178. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  179. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  180. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
  181. for (int i = 0; i < n_seq_tokens; ++i) {
  182. const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
  183. if (pos == 0) {
  184. data[seq_id] = s*n_seq_tokens + i;
  185. }
  186. }
  187. }
  188. }
  189. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
  190. const int64_t n_tokens = ubatch->n_tokens;
  191. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  192. const int64_t n_seqs = ubatch->n_seqs;
  193. GGML_ASSERT(cls);
  194. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  195. uint32_t * data = (uint32_t *) cls->data;
  196. memset(cls->data, 0, n_tokens * ggml_element_size(cls));
  197. std::vector<int> last_pos(n_tokens, -1);
  198. std::vector<int> last_row(n_tokens, -1);
  199. for (int s = 0; s < n_seqs; ++s) {
  200. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  201. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  202. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
  203. for (int i = 0; i < n_seq_tokens; ++i) {
  204. const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
  205. if (pos >= last_pos[seq_id]) {
  206. last_pos[seq_id] = pos;
  207. last_row[seq_id] = s*n_seq_tokens + i;
  208. }
  209. }
  210. }
  211. for (int i = 0; i < n_tokens; ++i) {
  212. if (last_row[i] >= 0) {
  213. data[i] = last_row[i];
  214. }
  215. }
  216. }
  217. }
  218. void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
  219. GGML_UNUSED(ubatch);
  220. const int64_t n_kv = kv_self->n;
  221. if (s_copy) {
  222. GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
  223. int32_t * data = (int32_t *) s_copy->data;
  224. // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
  225. for (uint32_t i = 0; i < n_kv; ++i) {
  226. data[i] = kv_self->s_copy(i);
  227. }
  228. }
  229. }
  230. void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
  231. GGML_UNUSED(ubatch);
  232. const int64_t n_kv = kv_self->n;
  233. if (s_mask) {
  234. GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
  235. float * data = (float *) s_mask->data;
  236. // clear unused states
  237. for (int i = 0; i < n_kv; ++i) {
  238. data[i] = kv_self->s_mask(i);
  239. }
  240. }
  241. }
  242. void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  243. GGML_UNUSED(ubatch);
  244. if (cross_embd && !cross->v_embd.empty()) {
  245. assert(cross_embd->type == GGML_TYPE_F32);
  246. ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
  247. }
  248. }
  249. void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  250. if (kq_mask) {
  251. if (cparams.causal_attn) {
  252. const int64_t n_kv = ubatch->n_tokens;
  253. const int64_t n_tokens = ubatch->n_tokens;
  254. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  255. const int64_t n_seqs = ubatch->n_seqs;
  256. GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
  257. float * data = (float *) kq_mask->data;
  258. for (int h = 0; h < 1; ++h) {
  259. for (int s1 = 0; s1 < n_seqs; ++s1) {
  260. const llama_seq_id seq_id = ubatch->seq_id[s1][0];
  261. for (int j = 0; j < n_seq_tokens; ++j) {
  262. const int32_t tj = s1*n_seq_tokens + j;
  263. for (int s0 = 0; s0 < n_seqs; ++s0) {
  264. for (int i = 0; i < n_seq_tokens; ++i) {
  265. const int32_t ti = s0*n_seq_tokens + i;
  266. float f = -INFINITY;
  267. for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
  268. if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
  269. if (hparams.use_alibi) {
  270. f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
  271. } else {
  272. f = 0.0f;
  273. }
  274. break;
  275. }
  276. }
  277. data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
  278. }
  279. }
  280. }
  281. }
  282. }
  283. } else {
  284. const int64_t n_tokens = ubatch->n_tokens;
  285. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  286. const int64_t n_seqs = ubatch->n_seqs;
  287. const int64_t n_stride = ubatch->n_tokens;
  288. GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
  289. float * data = (float *) kq_mask->data;
  290. for (int h = 0; h < 1; ++h) {
  291. for (int s1 = 0; s1 < n_seqs; ++s1) {
  292. const llama_seq_id seq_id = ubatch->seq_id[s1][0];
  293. for (int j = 0; j < n_seq_tokens; ++j) {
  294. const int32_t tj = s1*n_seq_tokens + j;
  295. for (int s0 = 0; s0 < n_seqs; ++s0) {
  296. for (int i = 0; i < n_seq_tokens; ++i) {
  297. const int32_t ti = s0*n_seq_tokens + i;
  298. float f = -INFINITY;
  299. for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
  300. if (ubatch->seq_id[s0][s] == seq_id) {
  301. if (hparams.use_alibi) {
  302. f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
  303. } else {
  304. f = 0.0f;
  305. }
  306. break;
  307. }
  308. }
  309. data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
  310. }
  311. }
  312. for (int i = n_tokens; i < n_stride; ++i) {
  313. data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
  314. }
  315. }
  316. }
  317. }
  318. }
  319. }
  320. }
  321. void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
  322. if (self_kq_mask || self_kq_mask_swa) {
  323. const int64_t n_kv = kv_self->n;
  324. const int64_t n_tokens = ubatch->n_tokens;
  325. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  326. const int64_t n_seqs = ubatch->n_seqs;
  327. float * data = nullptr;
  328. float * data_swa = nullptr;
  329. if (self_kq_mask) {
  330. GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
  331. data = (float *) self_kq_mask->data;
  332. }
  333. if (self_kq_mask_swa) {
  334. GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
  335. data_swa = (float *) self_kq_mask_swa->data;
  336. }
  337. // Use only the previous KV cells of the correct sequence for each token of the ubatch.
  338. // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
  339. // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
  340. // Causal mask:
  341. // xxx-------
  342. // xxxx------
  343. // xxxxx-----
  344. // Non-causal mask:
  345. // xxxxx-----
  346. // xxxxx-----
  347. // xxxxx-----
  348. // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
  349. for (int h = 0; h < 1; ++h) {
  350. for (int s = 0; s < n_seqs; ++s) {
  351. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  352. for (int j = 0; j < n_seq_tokens; ++j) {
  353. const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
  354. for (int i = 0; i < n_kv; ++i) {
  355. float f;
  356. // mask the token if:
  357. if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
  358. || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
  359. ) {
  360. f = -INFINITY;
  361. } else {
  362. if (hparams.use_alibi) {
  363. f = -std::abs(kv_self->cells[i].pos - pos);
  364. } else {
  365. f = 0.0f;
  366. }
  367. }
  368. if (data) {
  369. data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
  370. }
  371. // may need to cut off old tokens for sliding window
  372. // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
  373. if (data_swa) {
  374. if (hparams.n_attn_chunk) {
  375. llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
  376. if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
  377. f = -INFINITY;
  378. }
  379. } else {
  380. if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
  381. f = -INFINITY;
  382. }
  383. }
  384. data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
  385. }
  386. }
  387. }
  388. }
  389. // mask padded tokens
  390. if (data) {
  391. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  392. for (int j = 0; j < n_kv; ++j) {
  393. data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
  394. }
  395. }
  396. }
  397. // mask padded tokens
  398. if (data_swa) {
  399. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  400. for (int j = 0; j < n_kv; ++j) {
  401. data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
  402. }
  403. }
  404. }
  405. }
  406. }
  407. }
  408. void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  409. if (cross_kq_mask) {
  410. const int64_t n_enc = cross_kq_mask->ne[0];
  411. const int64_t n_tokens = ubatch->n_tokens;
  412. GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
  413. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  414. float * data = (float *) cross_kq_mask->data;
  415. for (int h = 0; h < 1; ++h) {
  416. for (int j = 0; j < n_tokens; ++j) {
  417. for (int i = 0; i < n_enc; ++i) {
  418. float f = -INFINITY;
  419. for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
  420. const llama_seq_id seq_id = ubatch->seq_id[j][s];
  421. if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
  422. f = 0.0f;
  423. }
  424. }
  425. data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
  426. }
  427. }
  428. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  429. for (int j = 0; j < n_enc; ++j) {
  430. data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
  431. }
  432. }
  433. }
  434. }
  435. }
  436. //
  437. // llm_graph_context
  438. //
  439. llm_graph_context::llm_graph_context(const llm_graph_params & params) :
  440. arch (params.arch),
  441. hparams (params.hparams),
  442. cparams (params.cparams),
  443. ubatch (params.ubatch),
  444. n_embd (hparams.n_embd),
  445. n_layer (hparams.n_layer),
  446. n_rot (hparams.n_rot),
  447. n_ctx (cparams.n_ctx),
  448. n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
  449. n_head (hparams.n_head()),
  450. n_head_kv (hparams.n_head_kv()),
  451. n_embd_head_k (hparams.n_embd_head_k),
  452. n_embd_k_gqa (hparams.n_embd_k_gqa()),
  453. n_embd_head_v (hparams.n_embd_head_v),
  454. n_embd_v_gqa (hparams.n_embd_v_gqa()),
  455. n_expert (hparams.n_expert),
  456. n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
  457. freq_base (cparams.rope_freq_base),
  458. freq_scale (cparams.rope_freq_scale),
  459. ext_factor (cparams.yarn_ext_factor),
  460. attn_factor (cparams.yarn_attn_factor),
  461. beta_fast (cparams.yarn_beta_fast),
  462. beta_slow (cparams.yarn_beta_slow),
  463. norm_eps (hparams.f_norm_eps),
  464. norm_rms_eps (hparams.f_norm_rms_eps),
  465. n_tokens (ubatch.n_tokens),
  466. n_outputs (params.n_outputs),
  467. n_ctx_orig (cparams.n_ctx_orig_yarn),
  468. pooling_type (cparams.pooling_type),
  469. rope_type (hparams.rope_type),
  470. ctx0 (params.ctx),
  471. sched (params.sched),
  472. backend_cpu (params.backend_cpu),
  473. cvec (params.cvec),
  474. loras (params.loras),
  475. memory (params.memory),
  476. cross (params.cross),
  477. cb_func (params.cb),
  478. res (std::make_unique<llm_graph_result>()) {
  479. }
  480. int64_t llm_graph_context::n_pos_per_embd() const {
  481. return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
  482. }
  483. void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
  484. if (cb_func) {
  485. cb_func(ubatch, cur, name, il);
  486. }
  487. }
  488. ggml_tensor * llm_graph_context::build_cvec(
  489. ggml_tensor * cur,
  490. int il) const {
  491. return cvec->apply_to(ctx0, cur, il);
  492. }
  493. ggml_tensor * llm_graph_context::build_lora_mm(
  494. ggml_tensor * w,
  495. ggml_tensor * cur) const {
  496. ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
  497. for (const auto & lora : *loras) {
  498. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  499. if (lw == nullptr) {
  500. continue;
  501. }
  502. const float adapter_scale = lora.second;
  503. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  504. ggml_tensor * ab_cur = ggml_mul_mat(
  505. ctx0, lw->b,
  506. ggml_mul_mat(ctx0, lw->a, cur)
  507. );
  508. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  509. res = ggml_add(ctx0, res, ab_cur);
  510. }
  511. return res;
  512. }
  513. ggml_tensor * llm_graph_context::build_lora_mm_id(
  514. ggml_tensor * w, // ggml_tensor * as
  515. ggml_tensor * cur, // ggml_tensor * b
  516. ggml_tensor * ids) const {
  517. ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
  518. for (const auto & lora : *loras) {
  519. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  520. if (lw == nullptr) {
  521. continue;
  522. }
  523. const float alpha = lora.first->alpha;
  524. const float rank = (float) lw->b->ne[0];
  525. const float scale = alpha ? lora.second * alpha / rank : lora.second;
  526. ggml_tensor * ab_cur = ggml_mul_mat_id(
  527. ctx0, lw->b,
  528. ggml_mul_mat_id(ctx0, lw->a, cur, ids),
  529. ids
  530. );
  531. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  532. res = ggml_add(ctx0, res, ab_cur);
  533. }
  534. return res;
  535. }
  536. ggml_tensor * llm_graph_context::build_norm(
  537. ggml_tensor * cur,
  538. ggml_tensor * mw,
  539. ggml_tensor * mb,
  540. llm_norm_type type,
  541. int il) const {
  542. switch (type) {
  543. case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
  544. case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
  545. case LLM_NORM_GROUP:
  546. {
  547. cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
  548. cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
  549. cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
  550. } break;
  551. }
  552. if (mw || mb) {
  553. cb(cur, "norm", il);
  554. }
  555. if (mw) {
  556. cur = ggml_mul(ctx0, cur, mw);
  557. if (mb) {
  558. cb(cur, "norm_w", il);
  559. }
  560. }
  561. if (mb) {
  562. cur = ggml_add(ctx0, cur, mb);
  563. }
  564. return cur;
  565. }
  566. ggml_tensor * llm_graph_context::build_ffn(
  567. ggml_tensor * cur,
  568. ggml_tensor * up,
  569. ggml_tensor * up_b,
  570. ggml_tensor * up_s,
  571. ggml_tensor * gate,
  572. ggml_tensor * gate_b,
  573. ggml_tensor * gate_s,
  574. ggml_tensor * down,
  575. ggml_tensor * down_b,
  576. ggml_tensor * down_s,
  577. ggml_tensor * act_scales,
  578. llm_ffn_op_type type_op,
  579. llm_ffn_gate_type type_gate,
  580. int il) const {
  581. ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
  582. cb(tmp, "ffn_up", il);
  583. if (up_b) {
  584. tmp = ggml_add(ctx0, tmp, up_b);
  585. cb(tmp, "ffn_up_b", il);
  586. }
  587. if (up_s) {
  588. tmp = ggml_mul(ctx0, tmp, up_s);
  589. cb(tmp, "ffn_up_s", il);
  590. }
  591. if (gate) {
  592. switch (type_gate) {
  593. case LLM_FFN_SEQ:
  594. {
  595. cur = build_lora_mm(gate, tmp);
  596. cb(cur, "ffn_gate", il);
  597. } break;
  598. case LLM_FFN_PAR:
  599. {
  600. cur = build_lora_mm(gate, cur);
  601. cb(cur, "ffn_gate", il);
  602. } break;
  603. }
  604. if (gate_b) {
  605. cur = ggml_add(ctx0, cur, gate_b);
  606. cb(cur, "ffn_gate_b", il);
  607. }
  608. if (gate_s) {
  609. cur = ggml_mul(ctx0, cur, gate_s);
  610. cb(cur, "ffn_gate_s", il);
  611. }
  612. } else {
  613. cur = tmp;
  614. }
  615. switch (type_op) {
  616. case LLM_FFN_SILU:
  617. {
  618. cur = ggml_silu(ctx0, cur);
  619. cb(cur, "ffn_silu", il);
  620. } break;
  621. case LLM_FFN_GELU:
  622. {
  623. cur = ggml_gelu(ctx0, cur);
  624. cb(cur, "ffn_gelu", il);
  625. if (act_scales != NULL) {
  626. cur = ggml_div(ctx0, cur, act_scales);
  627. cb(cur, "ffn_act", il);
  628. }
  629. } break;
  630. case LLM_FFN_RELU:
  631. {
  632. cur = ggml_relu(ctx0, cur);
  633. cb(cur, "ffn_relu", il);
  634. } break;
  635. case LLM_FFN_RELU_SQR:
  636. {
  637. cur = ggml_relu(ctx0, cur);
  638. cb(cur, "ffn_relu", il);
  639. cur = ggml_sqr(ctx0, cur);
  640. cb(cur, "ffn_sqr(relu)", il);
  641. } break;
  642. case LLM_FFN_SWIGLU:
  643. {
  644. // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
  645. int64_t split_point = cur->ne[0] / 2;
  646. ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
  647. ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
  648. x0 = ggml_silu(ctx0, x0);
  649. cb(cur, "ffn_silu", il);
  650. cur = ggml_mul(ctx0, x0, x1);
  651. cb(cur, "ffn_mul", il);
  652. } break;
  653. }
  654. if (type_gate == LLM_FFN_PAR) {
  655. cur = ggml_mul(ctx0, cur, tmp);
  656. cb(cur, "ffn_gate_par", il);
  657. }
  658. if (down) {
  659. cur = build_lora_mm(down, cur);
  660. if (arch == LLM_ARCH_GLM4) {
  661. // GLM4 seems to have numerical issues with half-precision accumulators
  662. ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
  663. }
  664. }
  665. if (down_b) {
  666. cb(cur, "ffn_down", il);
  667. }
  668. if (down_b) {
  669. cur = ggml_add(ctx0, cur, down_b);
  670. }
  671. if (down_s) {
  672. cur = ggml_mul(ctx0, cur, down_s);
  673. cb(cur, "ffn_down_s", il);
  674. }
  675. return cur;
  676. }
  677. ggml_tensor * llm_graph_context::build_moe_ffn(
  678. ggml_tensor * cur,
  679. ggml_tensor * gate_inp,
  680. ggml_tensor * up_exps,
  681. ggml_tensor * gate_exps,
  682. ggml_tensor * down_exps,
  683. ggml_tensor * exp_probs_b,
  684. int64_t n_expert,
  685. int64_t n_expert_used,
  686. llm_ffn_op_type type_op,
  687. bool norm_w,
  688. bool scale_w,
  689. float w_scale,
  690. llama_expert_gating_func_type gating_op,
  691. int il) const {
  692. const int64_t n_embd = cur->ne[0];
  693. const int64_t n_tokens = cur->ne[1];
  694. const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
  695. ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
  696. cb(logits, "ffn_moe_logits", il);
  697. ggml_tensor * probs = nullptr;
  698. switch (gating_op) {
  699. case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
  700. {
  701. probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
  702. } break;
  703. case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
  704. {
  705. probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
  706. } break;
  707. default:
  708. GGML_ABORT("fatal error");
  709. }
  710. cb(probs, "ffn_moe_probs", il);
  711. // add experts selection bias - introduced in DeepSeek V3
  712. // leave probs unbiased as it's later used to get expert weights
  713. ggml_tensor * selection_probs = probs;
  714. if (exp_probs_b != nullptr) {
  715. selection_probs = ggml_add(ctx0, probs, exp_probs_b);
  716. cb(selection_probs, "ffn_moe_probs_biased", il);
  717. }
  718. // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
  719. // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
  720. if (arch == LLM_ARCH_LLAMA4) {
  721. selection_probs = logits;
  722. }
  723. // select experts
  724. ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
  725. cb(selected_experts->src[0], "ffn_moe_argsort", il);
  726. cb(selected_experts, "ffn_moe_topk", il);
  727. ggml_tensor * weights = ggml_get_rows(ctx0,
  728. ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
  729. cb(weights, "ffn_moe_weights", il);
  730. if (norm_w) {
  731. weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
  732. ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
  733. cb(weights_sum, "ffn_moe_weights_sum", il);
  734. weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
  735. cb(weights, "ffn_moe_weights_norm", il);
  736. weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
  737. }
  738. if (scale_w) {
  739. weights = ggml_scale(ctx0, weights, w_scale);
  740. cb(weights, "ffn_moe_weights_scaled", il);
  741. }
  742. cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
  743. if (weight_before_ffn) {
  744. // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
  745. ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
  746. repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
  747. cur = ggml_mul(ctx0, repeated, weights);
  748. cb(cur, "ffn_moe_weighted", il);
  749. }
  750. ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  751. cb(up, "ffn_moe_up", il);
  752. ggml_tensor * experts = nullptr;
  753. if (gate_exps) {
  754. cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  755. cb(cur, "ffn_moe_gate", il);
  756. } else {
  757. cur = up;
  758. }
  759. switch (type_op) {
  760. case LLM_FFN_SILU:
  761. {
  762. cur = ggml_silu(ctx0, cur);
  763. cb(cur, "ffn_moe_silu", il);
  764. } break;
  765. case LLM_FFN_GELU:
  766. {
  767. cur = ggml_gelu(ctx0, cur);
  768. cb(cur, "ffn_moe_gelu", il);
  769. } break;
  770. default:
  771. GGML_ABORT("fatal error");
  772. }
  773. if (gate_exps) {
  774. cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
  775. cb(cur, "ffn_moe_gate_par", il);
  776. }
  777. experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
  778. cb(experts, "ffn_moe_down", il);
  779. if (!weight_before_ffn) {
  780. experts = ggml_mul(ctx0, experts, weights);
  781. cb(cur, "ffn_moe_weighted", il);
  782. }
  783. // aggregate experts
  784. ggml_tensor * moe_out = nullptr;
  785. for (int i = 0; i < n_expert_used; ++i) {
  786. ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
  787. experts->nb[2], i*experts->nb[1]);
  788. if (i == 0) {
  789. moe_out = cur_expert;
  790. } else {
  791. moe_out = ggml_add(ctx0, moe_out, cur_expert);
  792. }
  793. }
  794. if (n_expert_used == 1) {
  795. // avoid returning a non-contiguous tensor
  796. moe_out = ggml_cont(ctx0, moe_out);
  797. }
  798. cb(moe_out, "ffn_moe_out", il);
  799. return moe_out;
  800. }
  801. // input embeddings with optional lora
  802. ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
  803. const int64_t n_embd = hparams.n_embd;
  804. auto inp = std::make_unique<llm_graph_input_embd>();
  805. ggml_tensor * cur = nullptr;
  806. if (ubatch.token) {
  807. inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
  808. //cb(inp->tokens, "inp_tokens", -1);
  809. ggml_set_input(inp->tokens);
  810. cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
  811. // apply lora for embedding tokens if needed
  812. for (const auto & lora : *loras) {
  813. llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
  814. if (lw == nullptr) {
  815. continue;
  816. }
  817. const float adapter_scale = lora.second;
  818. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  819. ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
  820. ctx0, lw->b, // non-transposed lora_b
  821. ggml_get_rows(ctx0, lw->a, inp->tokens)
  822. ), scale);
  823. cur = ggml_add(ctx0, cur, inpL_delta);
  824. }
  825. } else {
  826. inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
  827. ggml_set_input(inp->embd);
  828. cur = inp->embd;
  829. }
  830. // For Granite architecture
  831. if (hparams.f_embedding_scale != 0.0f) {
  832. cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
  833. }
  834. cb(cur, "inp_embd", -1);
  835. res->add_input(std::move(inp));
  836. return cur;
  837. }
  838. ggml_tensor * llm_graph_context::build_inp_pos() const {
  839. auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
  840. auto & cur = inp->pos;
  841. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
  842. ggml_set_input(cur);
  843. res->add_input(std::move(inp));
  844. return cur;
  845. }
  846. ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
  847. auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
  848. auto & cur = inp->attn_scale;
  849. // this need to be 1x1xN for broadcasting
  850. cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
  851. ggml_set_input(cur);
  852. res->add_input(std::move(inp));
  853. return cur;
  854. }
  855. ggml_tensor * llm_graph_context::build_inp_out_ids() const {
  856. auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
  857. auto & cur = inp->out_ids;
  858. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
  859. ggml_set_input(cur);
  860. res->add_input(std::move(inp));
  861. return cur;
  862. }
  863. ggml_tensor * llm_graph_context::build_inp_mean() const {
  864. auto inp = std::make_unique<llm_graph_input_mean>(cparams);
  865. auto & cur = inp->mean;
  866. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
  867. ggml_set_input(cur);
  868. res->add_input(std::move(inp));
  869. return cur;
  870. }
  871. ggml_tensor * llm_graph_context::build_inp_cls() const {
  872. auto inp = std::make_unique<llm_graph_input_cls>(cparams);
  873. auto & cur = inp->cls;
  874. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
  875. ggml_set_input(cur);
  876. res->add_input(std::move(inp));
  877. return cur;
  878. }
  879. ggml_tensor * llm_graph_context::build_inp_s_copy() const {
  880. const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
  881. auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
  882. const auto n_kv = kv_self->n;
  883. auto & cur = inp->s_copy;
  884. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
  885. ggml_set_input(cur);
  886. res->add_input(std::move(inp));
  887. return cur;
  888. }
  889. ggml_tensor * llm_graph_context::build_inp_s_mask() const {
  890. const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
  891. auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
  892. const auto n_kv = kv_self->n;
  893. auto & cur = inp->s_mask;
  894. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
  895. ggml_set_input(cur);
  896. res->add_input(std::move(inp));
  897. return cur;
  898. }
  899. ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
  900. auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
  901. auto & cur = inp->cross_embd;
  902. // if we have the output embeddings from the encoder, use them directly
  903. // TODO: needs more work to be correct, for now just use the tensor shape
  904. //if (cross->t_embd) {
  905. // cur = ggml_view_tensor(ctx0, cross->t_embd);
  906. // return cur;
  907. //}
  908. const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
  909. const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  910. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
  911. ggml_set_input(cur);
  912. res->add_input(std::move(inp));
  913. return cur;
  914. }
  915. ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
  916. auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
  917. auto & cur = inp->pos_bucket;
  918. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
  919. ggml_set_input(cur);
  920. res->add_input(std::move(inp));
  921. return cur;
  922. }
  923. ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
  924. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  925. auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
  926. const auto n_kv = kv_self->n;
  927. auto & cur = inp->pos_bucket;
  928. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
  929. ggml_set_input(cur);
  930. res->add_input(std::move(inp));
  931. return cur;
  932. }
  933. ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
  934. ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
  935. cb(pos_bucket_1d, "pos_bucket_1d", -1);
  936. ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
  937. pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
  938. pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
  939. pos_bias = ggml_cont (ctx0, pos_bias);
  940. cb(pos_bias, "pos_bias", -1);
  941. return pos_bias;
  942. }
  943. ggml_tensor * llm_graph_context::build_attn_mha(
  944. ggml_cgraph * gf,
  945. ggml_tensor * q,
  946. ggml_tensor * k,
  947. ggml_tensor * v,
  948. ggml_tensor * kq_b,
  949. ggml_tensor * kq_mask,
  950. ggml_tensor * v_mla,
  951. bool v_trans,
  952. float kq_scale) const {
  953. //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  954. //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  955. //const int64_t n_head = hparams.n_head(il);
  956. //const int64_t n_head_kv = hparams.n_head_kv(il);
  957. //const auto & n_embd_head_k = hparams.n_embd_head_k;
  958. //const auto & n_embd_head_v = hparams.n_embd_head_v;
  959. const auto n_tokens = q->ne[1];
  960. const auto n_head = q->ne[2];
  961. const auto n_kv = k->ne[1];
  962. ggml_tensor * cur;
  963. // TODO: replace hardcoded padding with ggml-provided padding
  964. if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
  965. GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
  966. if (v_trans) {
  967. v = ggml_transpose(ctx0, v);
  968. }
  969. // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
  970. if (k->type == GGML_TYPE_F32) {
  971. k = ggml_cast(ctx0, k, GGML_TYPE_F16);
  972. }
  973. if (v->type == GGML_TYPE_F32) {
  974. v = ggml_cast(ctx0, v, GGML_TYPE_F16);
  975. }
  976. cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
  977. hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
  978. ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
  979. if (v_mla) {
  980. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
  981. cur = ggml_mul_mat(ctx0, v_mla, cur);
  982. }
  983. cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
  984. } else {
  985. ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
  986. // note: this op tends to require high floating point range
  987. // while for some models F16 is enough, for others it is not, so we default to F32 here
  988. ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
  989. if (arch == LLM_ARCH_GROK) {
  990. // need to do the following:
  991. // multiply by attn_output_multiplyer of 0.08838834764831845
  992. // and then :
  993. // kq = 30 * tanh(kq / 30)
  994. // before the softmax below
  995. kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
  996. kq = ggml_scale(ctx0, kq, 30);
  997. }
  998. if (hparams.attn_soft_cap) {
  999. kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
  1000. kq = ggml_tanh (ctx0, kq);
  1001. kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
  1002. }
  1003. if (kq_b) {
  1004. kq = ggml_add(ctx0, kq, kq_b);
  1005. }
  1006. kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
  1007. if (!v_trans) {
  1008. // note: avoid this branch
  1009. v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
  1010. }
  1011. ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
  1012. // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
  1013. if (v_mla) {
  1014. kqv = ggml_mul_mat(ctx0, v_mla, kqv);
  1015. }
  1016. cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
  1017. cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
  1018. if (!cparams.offload_kqv) {
  1019. // all nodes between the KV store and the attention output are run on the CPU
  1020. ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
  1021. }
  1022. }
  1023. ggml_build_forward_expand(gf, cur);
  1024. return cur;
  1025. }
  1026. llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
  1027. auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
  1028. // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
  1029. inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1030. //cb(inp_kq_mask, "KQ_mask", -1);
  1031. ggml_set_input(inp->kq_mask);
  1032. inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
  1033. return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
  1034. }
  1035. ggml_tensor * llm_graph_context::build_attn(
  1036. llm_graph_input_attn_no_cache * inp,
  1037. ggml_cgraph * gf,
  1038. ggml_tensor * wo,
  1039. ggml_tensor * wo_b,
  1040. ggml_tensor * q_cur,
  1041. ggml_tensor * k_cur,
  1042. ggml_tensor * v_cur,
  1043. ggml_tensor * kq_b,
  1044. ggml_tensor * v_mla,
  1045. float kq_scale,
  1046. int il) const {
  1047. GGML_UNUSED(n_tokens);
  1048. // these nodes are added to the graph together so that they are not reordered
  1049. // by doing so, the number of splits in the graph is reduced
  1050. ggml_build_forward_expand(gf, q_cur);
  1051. ggml_build_forward_expand(gf, k_cur);
  1052. ggml_build_forward_expand(gf, v_cur);
  1053. const auto & kq_mask = inp->get_kq_mask();
  1054. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1055. //cb(q, "q", il);
  1056. ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
  1057. //cb(k, "k", il);
  1058. ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
  1059. //cb(k, "v", il);
  1060. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
  1061. cb(cur, "kqv_out", il);
  1062. if (wo) {
  1063. cur = build_lora_mm(wo, cur);
  1064. }
  1065. if (wo_b) {
  1066. //cb(cur, "kqv_wo", il);
  1067. }
  1068. if (wo_b) {
  1069. cur = ggml_add(ctx0, cur, wo_b);
  1070. }
  1071. return cur;
  1072. }
  1073. llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
  1074. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1075. auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
  1076. const auto n_kv = kv_self->n;
  1077. inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1078. //cb(inp->self_kq_mask, "KQ_mask", -1);
  1079. ggml_set_input(inp->self_kq_mask);
  1080. inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
  1081. if (hparams.n_swa_pattern > 1) {
  1082. GGML_ASSERT(hparams.n_swa > 0);
  1083. inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1084. //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
  1085. ggml_set_input(inp->self_kq_mask_swa);
  1086. inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
  1087. }
  1088. return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
  1089. }
  1090. ggml_tensor * llm_graph_context::build_attn(
  1091. llm_graph_input_attn_kv_unified * inp,
  1092. ggml_cgraph * gf,
  1093. ggml_tensor * wo,
  1094. ggml_tensor * wo_b,
  1095. ggml_tensor * q_cur,
  1096. ggml_tensor * k_cur,
  1097. ggml_tensor * v_cur,
  1098. ggml_tensor * kq_b,
  1099. ggml_tensor * v_mla,
  1100. float kq_scale,
  1101. int il) const {
  1102. // these nodes are added to the graph together so that they are not reordered
  1103. // by doing so, the number of splits in the graph is reduced
  1104. ggml_build_forward_expand(gf, q_cur);
  1105. ggml_build_forward_expand(gf, k_cur);
  1106. ggml_build_forward_expand(gf, v_cur);
  1107. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1108. const auto & n_ctx = cparams.n_ctx;
  1109. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1110. const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1111. const auto n_tokens = q_cur->ne[2];
  1112. const bool v_trans = !cparams.flash_attn;
  1113. // store to KV cache
  1114. {
  1115. const auto kv_head = kv_self->head;
  1116. GGML_ASSERT(kv_self->size == n_ctx);
  1117. ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
  1118. //cb(k_cache_view, "k_cache_view", il);
  1119. // note: storing RoPE-ed version of K in the KV cache
  1120. ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
  1121. v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
  1122. ggml_tensor * v_cache_view = nullptr;
  1123. if (!v_trans) {
  1124. v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
  1125. } else {
  1126. // note: the V cache is transposed when not using flash attention
  1127. v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
  1128. ( n_ctx)*ggml_element_size(kv_self->v_l[il]),
  1129. (kv_head)*ggml_element_size(kv_self->v_l[il]));
  1130. v_cur = ggml_transpose(ctx0, v_cur);
  1131. }
  1132. //cb(v_cache_view, "v_cache_view", il);
  1133. ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
  1134. }
  1135. const bool is_swa = hparams.is_swa(il);
  1136. const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
  1137. const auto n_kv = kv_self->n;
  1138. const int64_t n_head_kv = hparams.n_head_kv(il);
  1139. const auto & n_embd_head_k = hparams.n_embd_head_k;
  1140. const auto & n_embd_head_v = hparams.n_embd_head_v;
  1141. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1142. //cb(q, "q", il);
  1143. ggml_tensor * k =
  1144. ggml_view_3d(ctx0, kv_self->k_l[il],
  1145. n_embd_head_k, n_kv, n_head_kv,
  1146. ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
  1147. ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
  1148. 0);
  1149. //cb(k, "k", il);
  1150. ggml_tensor * v = !v_trans ?
  1151. ggml_view_3d(ctx0, kv_self->v_l[il],
  1152. n_embd_head_v, n_kv, n_head_kv,
  1153. ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
  1154. ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
  1155. 0) :
  1156. ggml_view_3d(ctx0, kv_self->v_l[il],
  1157. n_kv, n_embd_head_v, n_head_kv,
  1158. ggml_element_size(kv_self->v_l[il])*n_ctx,
  1159. ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
  1160. 0);
  1161. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
  1162. cb(cur, "kqv_out", il);
  1163. if (wo) {
  1164. cur = build_lora_mm(wo, cur);
  1165. }
  1166. if (wo_b) {
  1167. //cb(cur, "kqv_wo", il);
  1168. }
  1169. if (wo_b) {
  1170. cur = ggml_add(ctx0, cur, wo_b);
  1171. }
  1172. return cur;
  1173. }
  1174. llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
  1175. auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
  1176. const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  1177. inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1178. ggml_set_input(inp->cross_kq_mask);
  1179. inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
  1180. return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
  1181. }
  1182. ggml_tensor * llm_graph_context::build_attn(
  1183. llm_graph_input_attn_cross * inp,
  1184. ggml_cgraph * gf,
  1185. ggml_tensor * wo,
  1186. ggml_tensor * wo_b,
  1187. ggml_tensor * q_cur,
  1188. ggml_tensor * k_cur,
  1189. ggml_tensor * v_cur,
  1190. ggml_tensor * kq_b,
  1191. ggml_tensor * v_mla,
  1192. float kq_scale,
  1193. int il) const {
  1194. // these nodes are added to the graph together so that they are not reordered
  1195. // by doing so, the number of splits in the graph is reduced
  1196. ggml_build_forward_expand(gf, q_cur);
  1197. ggml_build_forward_expand(gf, k_cur);
  1198. ggml_build_forward_expand(gf, v_cur);
  1199. const auto & kq_mask = inp->get_kq_mask_cross();
  1200. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1201. //cb(q, "q", il);
  1202. ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
  1203. //cb(k, "k", il);
  1204. ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
  1205. //cb(k, "v", il);
  1206. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
  1207. cb(cur, "kqv_out", il);
  1208. if (wo) {
  1209. cur = build_lora_mm(wo, cur);
  1210. }
  1211. if (wo_b) {
  1212. //cb(cur, "kqv_wo", il);
  1213. }
  1214. if (wo_b) {
  1215. cur = ggml_add(ctx0, cur, wo_b);
  1216. }
  1217. return cur;
  1218. }
  1219. ggml_tensor * llm_graph_context::build_copy_mask_state(
  1220. ggml_cgraph * gf,
  1221. ggml_tensor * s,
  1222. ggml_tensor * state_copy,
  1223. ggml_tensor * state_mask,
  1224. int32_t n_state,
  1225. int32_t n_seqs) const {
  1226. const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
  1227. const auto n_kv = kv_self->n;
  1228. const auto kv_head = kv_self->head;
  1229. ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
  1230. // copy states
  1231. // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
  1232. // this shrinks the tensors's ne[1] to n_kv
  1233. states = ggml_get_rows(ctx0, states, state_copy);
  1234. // clear states of sequences which are starting at the beginning of this batch
  1235. // FIXME: zero-out NANs?
  1236. states = ggml_mul(ctx0, states, state_mask);
  1237. // copy states which won't be changed further (between n_seqs and n_kv)
  1238. ggml_build_forward_expand(gf,
  1239. ggml_cpy(ctx0,
  1240. ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
  1241. ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
  1242. // the part of the states that will be used and modified
  1243. return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
  1244. }
  1245. ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
  1246. ggml_cgraph * gf,
  1247. ggml_tensor * state_copy,
  1248. ggml_tensor * state_mask,
  1249. const llama_ubatch & ubatch,
  1250. int il) const {
  1251. const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
  1252. const auto token_shift_count = hparams.token_shift_count;
  1253. const int64_t n_seqs = ubatch.n_seqs;
  1254. ggml_tensor * token_shift_all = kv_self->k_l[il];
  1255. ggml_tensor * token_shift = build_copy_mask_state(
  1256. gf, token_shift_all, state_copy, state_mask,
  1257. hparams.n_embd_k_s(), n_seqs);
  1258. token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
  1259. return token_shift;
  1260. }
  1261. ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
  1262. ggml_tensor * token_shift,
  1263. const llama_ubatch & ubatch,
  1264. int il) const {
  1265. const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
  1266. const auto token_shift_count = hparams.token_shift_count;
  1267. const auto n_embd = hparams.n_embd;
  1268. const int64_t n_seqs = ubatch.n_seqs;
  1269. const auto kv_head = kv_self->head;
  1270. return ggml_cpy(
  1271. ctx0,
  1272. ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
  1273. ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
  1274. );
  1275. }
  1276. void llm_graph_context::build_pooling(
  1277. ggml_cgraph * gf,
  1278. ggml_tensor * cls,
  1279. ggml_tensor * cls_b,
  1280. ggml_tensor * cls_out,
  1281. ggml_tensor * cls_out_b) const {
  1282. if (!cparams.embeddings) {
  1283. return;
  1284. }
  1285. ggml_tensor * inp = res->t_embd;
  1286. //// find result_norm tensor for input
  1287. //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
  1288. // inp = ggml_graph_node(gf, i);
  1289. // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
  1290. // break;
  1291. // }
  1292. // inp = nullptr;
  1293. //}
  1294. GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
  1295. ggml_tensor * cur;
  1296. switch (pooling_type) {
  1297. case LLAMA_POOLING_TYPE_NONE:
  1298. {
  1299. cur = inp;
  1300. } break;
  1301. case LLAMA_POOLING_TYPE_MEAN:
  1302. {
  1303. ggml_tensor * inp_mean = build_inp_mean();
  1304. cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
  1305. } break;
  1306. case LLAMA_POOLING_TYPE_CLS:
  1307. case LLAMA_POOLING_TYPE_LAST:
  1308. {
  1309. ggml_tensor * inp_cls = build_inp_cls();
  1310. cur = ggml_get_rows(ctx0, inp, inp_cls);
  1311. } break;
  1312. case LLAMA_POOLING_TYPE_RANK:
  1313. {
  1314. ggml_tensor * inp_cls = build_inp_cls();
  1315. inp = ggml_get_rows(ctx0, inp, inp_cls);
  1316. // classification head
  1317. // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
  1318. GGML_ASSERT(cls != nullptr);
  1319. GGML_ASSERT(cls_b != nullptr);
  1320. cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
  1321. cur = ggml_tanh(ctx0, cur);
  1322. // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
  1323. // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
  1324. if (cls_out) {
  1325. GGML_ASSERT(cls_out_b != nullptr);
  1326. cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
  1327. }
  1328. } break;
  1329. default:
  1330. {
  1331. GGML_ABORT("unknown pooling type");
  1332. }
  1333. }
  1334. cb(cur, "result_embd_pooled", -1);
  1335. res->t_embd_pooled = cur;
  1336. ggml_build_forward_expand(gf, cur);
  1337. }