llama-graph.cpp 58 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732
  1. #include "llama-graph.h"
  2. #include "llama-impl.h"
  3. #include "llama-batch.h"
  4. #include "llama-cparams.h"
  5. #include "llama-kv-cache.h"
  6. #include <cassert>
  7. #include <cmath>
  8. #include <cstring>
  9. static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
  10. // TODO move to hparams if a T5 variant appears that uses a different value
  11. const int64_t max_distance = 128;
  12. if (bidirectional) {
  13. n_buckets >>= 1;
  14. }
  15. const int64_t max_exact = n_buckets >> 1;
  16. int32_t relative_position = x - y;
  17. int32_t relative_bucket = 0;
  18. if (bidirectional) {
  19. relative_bucket += (relative_position > 0) * n_buckets;
  20. relative_position = abs(relative_position);
  21. } else {
  22. relative_position = -std::min<int32_t>(relative_position, 0);
  23. }
  24. int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
  25. relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
  26. relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
  27. return relative_bucket;
  28. }
  29. void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
  30. if (ubatch->token) {
  31. const int64_t n_tokens = ubatch->n_tokens;
  32. ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
  33. }
  34. if (ubatch->embd) {
  35. const int64_t n_embd = embd->ne[0];
  36. const int64_t n_tokens = ubatch->n_tokens;
  37. ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
  38. }
  39. }
  40. void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
  41. if (ubatch->pos && pos) {
  42. const int64_t n_tokens = ubatch->n_tokens;
  43. if (ubatch->token && n_pos_per_embd == 4) {
  44. // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
  45. // the 3 first dims are the same, and 4th dim is all 0
  46. std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
  47. // copy the first dimension
  48. for (int i = 0; i < n_tokens; ++i) {
  49. pos_data[ i] = ubatch->pos[i];
  50. pos_data[ n_tokens + i] = ubatch->pos[i];
  51. pos_data[2 * n_tokens + i] = ubatch->pos[i];
  52. pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
  53. }
  54. ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
  55. } else {
  56. ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
  57. }
  58. }
  59. }
  60. void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
  61. if (ubatch->pos && attn_scale) {
  62. const int64_t n_tokens = ubatch->n_tokens;
  63. std::vector<float> attn_scale_data(n_tokens, 0.0f);
  64. for (int i = 0; i < n_tokens; ++i) {
  65. const float pos = ubatch->pos[i];
  66. attn_scale_data[i] = std::log(
  67. std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
  68. ) * f_attn_temp_scale + 1.0;
  69. }
  70. ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
  71. }
  72. }
  73. void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
  74. if (pos_bucket) {
  75. const int64_t n_tokens = ubatch->n_tokens;
  76. GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
  77. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  78. int32_t * data = (int32_t *) pos_bucket->data;
  79. for (int h = 0; h < 1; ++h) {
  80. for (int j = 0; j < n_tokens; ++j) {
  81. for (int i = 0; i < n_tokens; ++i) {
  82. data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
  83. }
  84. }
  85. }
  86. }
  87. }
  88. void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
  89. if (pos_bucket) {
  90. const int64_t n_tokens = ubatch->n_tokens;
  91. GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
  92. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  93. int32_t * data = (int32_t *) pos_bucket->data;
  94. const int64_t n_kv = kv_self->n;
  95. for (int h = 0; h < 1; ++h) {
  96. for (int j = 0; j < n_tokens; ++j) {
  97. for (int i = 0; i < n_kv; ++i) {
  98. data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
  99. }
  100. }
  101. }
  102. }
  103. }
  104. void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
  105. if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
  106. //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
  107. if (!out_ids) {
  108. LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
  109. } else {
  110. const int64_t n_tokens = ubatch->n_tokens;
  111. GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
  112. int32_t * data = (int32_t *) out_ids->data;
  113. if (n_outputs == n_tokens) {
  114. for (int i = 0; i < n_tokens; ++i) {
  115. data[i] = i;
  116. }
  117. } else if (ubatch->output) {
  118. int32_t n_outputs = 0;
  119. for (int i = 0; i < n_tokens; ++i) {
  120. if (ubatch->output[i]) {
  121. data[n_outputs++] = i;
  122. }
  123. }
  124. // the graph needs to have been passed the correct number of outputs
  125. GGML_ASSERT(n_outputs == n_outputs);
  126. } else if (n_outputs == 1) {
  127. // only keep last output
  128. data[0] = n_tokens - 1;
  129. } else {
  130. GGML_ASSERT(n_outputs == 0);
  131. }
  132. }
  133. }
  134. }
  135. void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
  136. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
  137. const int64_t n_tokens = ubatch->n_tokens;
  138. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  139. const int64_t n_seqs = ubatch->n_seqs;
  140. GGML_ASSERT(mean);
  141. GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
  142. float * data = (float *) mean->data;
  143. memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
  144. std::vector<uint64_t> sum(n_tokens, 0);
  145. for (int s = 0; s < n_seqs; ++s) {
  146. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  147. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  148. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
  149. sum[seq_id] += ubatch->n_seq_tokens;
  150. }
  151. std::vector<float> div(n_tokens, 0.0f);
  152. for (int i = 0; i < n_tokens; ++i) {
  153. const uint64_t s = sum[i];
  154. if (s > 0) {
  155. div[i] = 1.0f/float(s);
  156. }
  157. }
  158. for (int s = 0; s < n_seqs; ++s) {
  159. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  160. for (int i = 0; i < n_seq_tokens; ++i) {
  161. data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
  162. }
  163. }
  164. }
  165. }
  166. void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
  167. if (cparams.embeddings && (
  168. cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
  169. cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
  170. const int64_t n_tokens = ubatch->n_tokens;
  171. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  172. const int64_t n_seqs = ubatch->n_seqs;
  173. GGML_ASSERT(cls);
  174. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  175. uint32_t * data = (uint32_t *) cls->data;
  176. memset(cls->data, 0, n_tokens * ggml_element_size(cls));
  177. for (int s = 0; s < n_seqs; ++s) {
  178. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  179. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  180. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
  181. for (int i = 0; i < n_seq_tokens; ++i) {
  182. const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
  183. if (pos == 0) {
  184. data[seq_id] = s*n_seq_tokens + i;
  185. }
  186. }
  187. }
  188. }
  189. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
  190. const int64_t n_tokens = ubatch->n_tokens;
  191. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  192. const int64_t n_seqs = ubatch->n_seqs;
  193. GGML_ASSERT(cls);
  194. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  195. uint32_t * data = (uint32_t *) cls->data;
  196. memset(cls->data, 0, n_tokens * ggml_element_size(cls));
  197. std::vector<int> last_pos(n_tokens, -1);
  198. std::vector<int> last_row(n_tokens, -1);
  199. for (int s = 0; s < n_seqs; ++s) {
  200. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  201. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  202. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
  203. for (int i = 0; i < n_seq_tokens; ++i) {
  204. const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
  205. if (pos >= last_pos[seq_id]) {
  206. last_pos[seq_id] = pos;
  207. last_row[seq_id] = s*n_seq_tokens + i;
  208. }
  209. }
  210. }
  211. for (int i = 0; i < n_tokens; ++i) {
  212. if (last_row[i] >= 0) {
  213. data[i] = last_row[i];
  214. }
  215. }
  216. }
  217. }
  218. void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
  219. GGML_UNUSED(ubatch);
  220. const int64_t n_kv = kv_self->n;
  221. if (s_copy) {
  222. GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
  223. int32_t * data = (int32_t *) s_copy->data;
  224. // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
  225. for (uint32_t i = 0; i < n_kv; ++i) {
  226. const uint32_t cell_id = i + kv_self->head;
  227. //////////////////////////////////////////////
  228. // TODO: this should not mutate the KV cache !
  229. llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
  230. // prevent out-of-bound sources
  231. if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
  232. kv_cell.src = cell_id;
  233. }
  234. data[i] = kv_cell.src;
  235. // TODO: do not mutate the KV cache
  236. // ensure copy only happens once
  237. if (kv_cell.src != (int32_t) cell_id) {
  238. kv_cell.src = cell_id;
  239. }
  240. }
  241. }
  242. }
  243. void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
  244. GGML_UNUSED(ubatch);
  245. const int64_t n_kv = kv_self->n;
  246. if (s_mask) {
  247. GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
  248. float * data = (float *) s_mask->data;
  249. // clear unused states
  250. for (int i = 0; i < n_kv; ++i) {
  251. const uint32_t cell_id = i + kv_self->head;
  252. //////////////////////////////////////////////
  253. // TODO: this should not mutate the KV cache !
  254. llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
  255. data[i] = (float) (kv_cell.src >= 0);
  256. // only clear once
  257. if (kv_cell.src < 0) {
  258. kv_cell.src = cell_id;
  259. }
  260. }
  261. }
  262. }
  263. void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  264. GGML_UNUSED(ubatch);
  265. if (cross_embd && !cross->v_embd.empty()) {
  266. assert(cross_embd->type == GGML_TYPE_F32);
  267. ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
  268. }
  269. }
  270. void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  271. if (kq_mask) {
  272. if (cparams.causal_attn) {
  273. const int64_t n_kv = ubatch->n_tokens;
  274. const int64_t n_tokens = ubatch->n_tokens;
  275. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  276. const int64_t n_seqs = ubatch->n_seqs;
  277. GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
  278. float * data = (float *) kq_mask->data;
  279. for (int h = 0; h < 1; ++h) {
  280. for (int s1 = 0; s1 < n_seqs; ++s1) {
  281. const llama_seq_id seq_id = ubatch->seq_id[s1][0];
  282. for (int j = 0; j < n_seq_tokens; ++j) {
  283. const int32_t tj = s1*n_seq_tokens + j;
  284. for (int s0 = 0; s0 < n_seqs; ++s0) {
  285. for (int i = 0; i < n_seq_tokens; ++i) {
  286. const int32_t ti = s0*n_seq_tokens + i;
  287. float f = -INFINITY;
  288. for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
  289. if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
  290. if (hparams.use_alibi) {
  291. f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
  292. } else {
  293. f = 0.0f;
  294. }
  295. break;
  296. }
  297. }
  298. data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
  299. }
  300. }
  301. }
  302. }
  303. }
  304. } else {
  305. const int64_t n_tokens = ubatch->n_tokens;
  306. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  307. const int64_t n_seqs = ubatch->n_seqs;
  308. const int64_t n_stride = ubatch->n_tokens;
  309. GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
  310. float * data = (float *) kq_mask->data;
  311. for (int h = 0; h < 1; ++h) {
  312. for (int s1 = 0; s1 < n_seqs; ++s1) {
  313. const llama_seq_id seq_id = ubatch->seq_id[s1][0];
  314. for (int j = 0; j < n_seq_tokens; ++j) {
  315. const int32_t tj = s1*n_seq_tokens + j;
  316. for (int s0 = 0; s0 < n_seqs; ++s0) {
  317. for (int i = 0; i < n_seq_tokens; ++i) {
  318. const int32_t ti = s0*n_seq_tokens + i;
  319. float f = -INFINITY;
  320. for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
  321. if (ubatch->seq_id[s0][s] == seq_id) {
  322. if (hparams.use_alibi) {
  323. f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
  324. } else {
  325. f = 0.0f;
  326. }
  327. break;
  328. }
  329. }
  330. data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
  331. }
  332. }
  333. for (int i = n_tokens; i < n_stride; ++i) {
  334. data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
  335. }
  336. }
  337. }
  338. }
  339. }
  340. }
  341. }
  342. void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
  343. if (self_kq_mask || self_kq_mask_swa) {
  344. const int64_t n_kv = kv_self->n;
  345. const int64_t n_tokens = ubatch->n_tokens;
  346. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  347. const int64_t n_seqs = ubatch->n_seqs;
  348. float * data = nullptr;
  349. float * data_swa = nullptr;
  350. if (self_kq_mask) {
  351. GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
  352. data = (float *) self_kq_mask->data;
  353. }
  354. if (self_kq_mask_swa) {
  355. GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
  356. data_swa = (float *) self_kq_mask_swa->data;
  357. }
  358. // Use only the previous KV cells of the correct sequence for each token of the ubatch.
  359. // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
  360. // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
  361. // Causal mask:
  362. // xxx-------
  363. // xxxx------
  364. // xxxxx-----
  365. // Non-causal mask:
  366. // xxxxx-----
  367. // xxxxx-----
  368. // xxxxx-----
  369. // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
  370. for (int h = 0; h < 1; ++h) {
  371. for (int s = 0; s < n_seqs; ++s) {
  372. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  373. for (int j = 0; j < n_seq_tokens; ++j) {
  374. const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
  375. for (int i = 0; i < n_kv; ++i) {
  376. float f;
  377. // mask the token if:
  378. if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
  379. || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
  380. ) {
  381. f = -INFINITY;
  382. } else {
  383. if (hparams.use_alibi) {
  384. f = -std::abs(kv_self->cells[i].pos - pos);
  385. } else {
  386. f = 0.0f;
  387. }
  388. }
  389. if (data) {
  390. data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
  391. }
  392. // may need to cut off old tokens for sliding window
  393. // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
  394. if (data_swa) {
  395. if (hparams.n_attn_chunk) {
  396. llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
  397. if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
  398. f = -INFINITY;
  399. }
  400. } else {
  401. if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
  402. f = -INFINITY;
  403. }
  404. }
  405. data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
  406. }
  407. }
  408. }
  409. }
  410. // mask padded tokens
  411. if (data) {
  412. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  413. for (int j = 0; j < n_kv; ++j) {
  414. data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
  415. }
  416. }
  417. }
  418. // mask padded tokens
  419. if (data_swa) {
  420. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  421. for (int j = 0; j < n_kv; ++j) {
  422. data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
  423. }
  424. }
  425. }
  426. }
  427. }
  428. }
  429. void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  430. if (cross_kq_mask) {
  431. const int64_t n_enc = cross_kq_mask->ne[0];
  432. const int64_t n_tokens = ubatch->n_tokens;
  433. GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
  434. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  435. float * data = (float *) cross_kq_mask->data;
  436. for (int h = 0; h < 1; ++h) {
  437. for (int j = 0; j < n_tokens; ++j) {
  438. for (int i = 0; i < n_enc; ++i) {
  439. float f = -INFINITY;
  440. for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
  441. const llama_seq_id seq_id = ubatch->seq_id[j][s];
  442. if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
  443. f = 0.0f;
  444. }
  445. }
  446. data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
  447. }
  448. }
  449. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  450. for (int j = 0; j < n_enc; ++j) {
  451. data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
  452. }
  453. }
  454. }
  455. }
  456. }
  457. //
  458. // llm_graph_context
  459. //
  460. llm_graph_context::llm_graph_context(const llm_graph_params & params) :
  461. arch (params.arch),
  462. hparams (params.hparams),
  463. cparams (params.cparams),
  464. ubatch (params.ubatch),
  465. n_embd (hparams.n_embd),
  466. n_layer (hparams.n_layer),
  467. n_rot (hparams.n_rot),
  468. n_ctx (cparams.n_ctx),
  469. n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
  470. n_head (hparams.n_head()),
  471. n_head_kv (hparams.n_head_kv()),
  472. n_embd_head_k (hparams.n_embd_head_k),
  473. n_embd_k_gqa (hparams.n_embd_k_gqa()),
  474. n_embd_head_v (hparams.n_embd_head_v),
  475. n_embd_v_gqa (hparams.n_embd_v_gqa()),
  476. n_expert (hparams.n_expert),
  477. n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
  478. freq_base (cparams.rope_freq_base),
  479. freq_scale (cparams.rope_freq_scale),
  480. ext_factor (cparams.yarn_ext_factor),
  481. attn_factor (cparams.yarn_attn_factor),
  482. beta_fast (cparams.yarn_beta_fast),
  483. beta_slow (cparams.yarn_beta_slow),
  484. norm_eps (hparams.f_norm_eps),
  485. norm_rms_eps (hparams.f_norm_rms_eps),
  486. n_tokens (ubatch.n_tokens),
  487. n_outputs (params.n_outputs),
  488. n_ctx_orig (cparams.n_ctx_orig_yarn),
  489. pooling_type (cparams.pooling_type),
  490. rope_type (hparams.rope_type),
  491. ctx0 (params.ctx),
  492. sched (params.sched),
  493. backend_cpu (params.backend_cpu),
  494. cvec (params.cvec),
  495. loras (params.loras),
  496. memory (params.memory),
  497. cross (params.cross),
  498. cb_func (params.cb),
  499. res (std::make_unique<llm_graph_result>()) {
  500. }
  501. int64_t llm_graph_context::n_pos_per_embd() const {
  502. return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
  503. }
  504. void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
  505. if (cb_func) {
  506. cb_func(ubatch, cur, name, il);
  507. }
  508. }
  509. ggml_tensor * llm_graph_context::build_cvec(
  510. ggml_tensor * cur,
  511. int il) const {
  512. return cvec->apply_to(ctx0, cur, il);
  513. }
  514. ggml_tensor * llm_graph_context::build_lora_mm(
  515. ggml_tensor * w,
  516. ggml_tensor * cur) const {
  517. ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
  518. for (const auto & lora : *loras) {
  519. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  520. if (lw == nullptr) {
  521. continue;
  522. }
  523. const float adapter_scale = lora.second;
  524. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  525. ggml_tensor * ab_cur = ggml_mul_mat(
  526. ctx0, lw->b,
  527. ggml_mul_mat(ctx0, lw->a, cur)
  528. );
  529. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  530. res = ggml_add(ctx0, res, ab_cur);
  531. }
  532. return res;
  533. }
  534. ggml_tensor * llm_graph_context::build_lora_mm_id(
  535. ggml_tensor * w, // ggml_tensor * as
  536. ggml_tensor * cur, // ggml_tensor * b
  537. ggml_tensor * ids) const {
  538. ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
  539. for (const auto & lora : *loras) {
  540. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  541. if (lw == nullptr) {
  542. continue;
  543. }
  544. const float alpha = lora.first->alpha;
  545. const float rank = (float) lw->b->ne[0];
  546. const float scale = alpha ? lora.second * alpha / rank : lora.second;
  547. ggml_tensor * ab_cur = ggml_mul_mat_id(
  548. ctx0, lw->b,
  549. ggml_mul_mat_id(ctx0, lw->a, cur, ids),
  550. ids
  551. );
  552. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  553. res = ggml_add(ctx0, res, ab_cur);
  554. }
  555. return res;
  556. }
  557. ggml_tensor * llm_graph_context::build_norm(
  558. ggml_tensor * cur,
  559. ggml_tensor * mw,
  560. ggml_tensor * mb,
  561. llm_norm_type type,
  562. int il) const {
  563. switch (type) {
  564. case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
  565. case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
  566. case LLM_NORM_GROUP:
  567. {
  568. cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
  569. cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
  570. cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
  571. } break;
  572. }
  573. if (mw || mb) {
  574. cb(cur, "norm", il);
  575. }
  576. if (mw) {
  577. cur = ggml_mul(ctx0, cur, mw);
  578. if (mb) {
  579. cb(cur, "norm_w", il);
  580. }
  581. }
  582. if (mb) {
  583. cur = ggml_add(ctx0, cur, mb);
  584. }
  585. return cur;
  586. }
  587. ggml_tensor * llm_graph_context::build_ffn(
  588. ggml_tensor * cur,
  589. ggml_tensor * up,
  590. ggml_tensor * up_b,
  591. ggml_tensor * up_s,
  592. ggml_tensor * gate,
  593. ggml_tensor * gate_b,
  594. ggml_tensor * gate_s,
  595. ggml_tensor * down,
  596. ggml_tensor * down_b,
  597. ggml_tensor * down_s,
  598. ggml_tensor * act_scales,
  599. llm_ffn_op_type type_op,
  600. llm_ffn_gate_type type_gate,
  601. int il) const {
  602. ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
  603. cb(tmp, "ffn_up", il);
  604. if (up_b) {
  605. tmp = ggml_add(ctx0, tmp, up_b);
  606. cb(tmp, "ffn_up_b", il);
  607. }
  608. if (up_s) {
  609. tmp = ggml_mul(ctx0, tmp, up_s);
  610. cb(tmp, "ffn_up_s", il);
  611. }
  612. if (gate) {
  613. switch (type_gate) {
  614. case LLM_FFN_SEQ:
  615. {
  616. cur = build_lora_mm(gate, tmp);
  617. cb(cur, "ffn_gate", il);
  618. } break;
  619. case LLM_FFN_PAR:
  620. {
  621. cur = build_lora_mm(gate, cur);
  622. cb(cur, "ffn_gate", il);
  623. } break;
  624. }
  625. if (gate_b) {
  626. cur = ggml_add(ctx0, cur, gate_b);
  627. cb(cur, "ffn_gate_b", il);
  628. }
  629. if (gate_s) {
  630. cur = ggml_mul(ctx0, cur, gate_s);
  631. cb(cur, "ffn_gate_s", il);
  632. }
  633. } else {
  634. cur = tmp;
  635. }
  636. switch (type_op) {
  637. case LLM_FFN_SILU:
  638. {
  639. cur = ggml_silu(ctx0, cur);
  640. cb(cur, "ffn_silu", il);
  641. } break;
  642. case LLM_FFN_GELU:
  643. {
  644. cur = ggml_gelu(ctx0, cur);
  645. cb(cur, "ffn_gelu", il);
  646. if (act_scales != NULL) {
  647. cur = ggml_div(ctx0, cur, act_scales);
  648. cb(cur, "ffn_act", il);
  649. }
  650. } break;
  651. case LLM_FFN_RELU:
  652. {
  653. cur = ggml_relu(ctx0, cur);
  654. cb(cur, "ffn_relu", il);
  655. } break;
  656. case LLM_FFN_RELU_SQR:
  657. {
  658. cur = ggml_relu(ctx0, cur);
  659. cb(cur, "ffn_relu", il);
  660. cur = ggml_sqr(ctx0, cur);
  661. cb(cur, "ffn_sqr(relu)", il);
  662. } break;
  663. case LLM_FFN_SWIGLU:
  664. {
  665. // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
  666. int64_t split_point = cur->ne[0] / 2;
  667. ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
  668. ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
  669. x0 = ggml_silu(ctx0, x0);
  670. cb(cur, "ffn_silu", il);
  671. cur = ggml_mul(ctx0, x0, x1);
  672. cb(cur, "ffn_mul", il);
  673. } break;
  674. }
  675. if (type_gate == LLM_FFN_PAR) {
  676. cur = ggml_mul(ctx0, cur, tmp);
  677. cb(cur, "ffn_gate_par", il);
  678. }
  679. if (down) {
  680. cur = build_lora_mm(down, cur);
  681. if (arch == LLM_ARCH_GLM4) {
  682. // GLM4 seems to have numerical issues with half-precision accumulators
  683. ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
  684. }
  685. }
  686. if (down_b) {
  687. cb(cur, "ffn_down", il);
  688. }
  689. if (down_b) {
  690. cur = ggml_add(ctx0, cur, down_b);
  691. }
  692. if (down_s) {
  693. cur = ggml_mul(ctx0, cur, down_s);
  694. cb(cur, "ffn_down_s", il);
  695. }
  696. return cur;
  697. }
  698. ggml_tensor * llm_graph_context::build_moe_ffn(
  699. ggml_tensor * cur,
  700. ggml_tensor * gate_inp,
  701. ggml_tensor * up_exps,
  702. ggml_tensor * gate_exps,
  703. ggml_tensor * down_exps,
  704. ggml_tensor * exp_probs_b,
  705. int64_t n_expert,
  706. int64_t n_expert_used,
  707. llm_ffn_op_type type_op,
  708. bool norm_w,
  709. bool scale_w,
  710. float w_scale,
  711. llama_expert_gating_func_type gating_op,
  712. int il) const {
  713. const int64_t n_embd = cur->ne[0];
  714. const int64_t n_tokens = cur->ne[1];
  715. const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
  716. ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
  717. cb(logits, "ffn_moe_logits", il);
  718. ggml_tensor * probs = nullptr;
  719. switch (gating_op) {
  720. case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
  721. {
  722. probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
  723. } break;
  724. case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
  725. {
  726. probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
  727. } break;
  728. default:
  729. GGML_ABORT("fatal error");
  730. }
  731. cb(probs, "ffn_moe_probs", il);
  732. // add experts selection bias - introduced in DeepSeek V3
  733. // leave probs unbiased as it's later used to get expert weights
  734. ggml_tensor * selection_probs = probs;
  735. if (exp_probs_b != nullptr) {
  736. selection_probs = ggml_add(ctx0, probs, exp_probs_b);
  737. cb(selection_probs, "ffn_moe_probs_biased", il);
  738. }
  739. // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
  740. // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
  741. if (arch == LLM_ARCH_LLAMA4) {
  742. selection_probs = logits;
  743. }
  744. // select experts
  745. ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
  746. cb(selected_experts->src[0], "ffn_moe_argsort", il);
  747. cb(selected_experts, "ffn_moe_topk", il);
  748. ggml_tensor * weights = ggml_get_rows(ctx0,
  749. ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
  750. cb(weights, "ffn_moe_weights", il);
  751. if (norm_w) {
  752. weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
  753. ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
  754. cb(weights_sum, "ffn_moe_weights_sum", il);
  755. weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
  756. cb(weights, "ffn_moe_weights_norm", il);
  757. weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
  758. }
  759. if (scale_w) {
  760. weights = ggml_scale(ctx0, weights, w_scale);
  761. cb(weights, "ffn_moe_weights_scaled", il);
  762. }
  763. cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
  764. if (weight_before_ffn) {
  765. // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
  766. ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
  767. repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
  768. cur = ggml_mul(ctx0, repeated, weights);
  769. cb(cur, "ffn_moe_weighted", il);
  770. }
  771. ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  772. cb(up, "ffn_moe_up", il);
  773. ggml_tensor * experts = nullptr;
  774. if (gate_exps) {
  775. cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  776. cb(cur, "ffn_moe_gate", il);
  777. } else {
  778. cur = up;
  779. }
  780. switch (type_op) {
  781. case LLM_FFN_SILU:
  782. {
  783. cur = ggml_silu(ctx0, cur);
  784. cb(cur, "ffn_moe_silu", il);
  785. } break;
  786. case LLM_FFN_GELU:
  787. {
  788. cur = ggml_gelu(ctx0, cur);
  789. cb(cur, "ffn_moe_gelu", il);
  790. } break;
  791. default:
  792. GGML_ABORT("fatal error");
  793. }
  794. if (gate_exps) {
  795. cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
  796. cb(cur, "ffn_moe_gate_par", il);
  797. }
  798. experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
  799. cb(experts, "ffn_moe_down", il);
  800. if (!weight_before_ffn) {
  801. experts = ggml_mul(ctx0, experts, weights);
  802. cb(cur, "ffn_moe_weighted", il);
  803. }
  804. // aggregate experts
  805. ggml_tensor * moe_out = nullptr;
  806. for (int i = 0; i < n_expert_used; ++i) {
  807. ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
  808. experts->nb[2], i*experts->nb[1]);
  809. if (i == 0) {
  810. moe_out = cur_expert;
  811. } else {
  812. moe_out = ggml_add(ctx0, moe_out, cur_expert);
  813. }
  814. }
  815. if (n_expert_used == 1) {
  816. // avoid returning a non-contiguous tensor
  817. moe_out = ggml_cont(ctx0, moe_out);
  818. }
  819. cb(moe_out, "ffn_moe_out", il);
  820. return moe_out;
  821. }
  822. // input embeddings with optional lora
  823. ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
  824. const int64_t n_embd = hparams.n_embd;
  825. auto inp = std::make_unique<llm_graph_input_embd>();
  826. ggml_tensor * cur = nullptr;
  827. if (ubatch.token) {
  828. inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
  829. //cb(inp->tokens, "inp_tokens", -1);
  830. ggml_set_input(inp->tokens);
  831. cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
  832. // apply lora for embedding tokens if needed
  833. for (const auto & lora : *loras) {
  834. llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
  835. if (lw == nullptr) {
  836. continue;
  837. }
  838. const float adapter_scale = lora.second;
  839. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  840. ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
  841. ctx0, lw->b, // non-transposed lora_b
  842. ggml_get_rows(ctx0, lw->a, inp->tokens)
  843. ), scale);
  844. cur = ggml_add(ctx0, cur, inpL_delta);
  845. }
  846. } else {
  847. inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
  848. ggml_set_input(inp->embd);
  849. cur = inp->embd;
  850. }
  851. // For Granite architecture
  852. if (hparams.f_embedding_scale != 0.0f) {
  853. cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
  854. }
  855. cb(cur, "inp_embd", -1);
  856. res->add_input(std::move(inp));
  857. return cur;
  858. }
  859. ggml_tensor * llm_graph_context::build_inp_pos() const {
  860. auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
  861. auto & cur = inp->pos;
  862. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
  863. ggml_set_input(cur);
  864. res->add_input(std::move(inp));
  865. return cur;
  866. }
  867. ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
  868. auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
  869. auto & cur = inp->attn_scale;
  870. // this need to be 1x1xN for broadcasting
  871. cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
  872. ggml_set_input(cur);
  873. res->add_input(std::move(inp));
  874. return cur;
  875. }
  876. ggml_tensor * llm_graph_context::build_inp_out_ids() const {
  877. auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
  878. auto & cur = inp->out_ids;
  879. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
  880. ggml_set_input(cur);
  881. res->add_input(std::move(inp));
  882. return cur;
  883. }
  884. ggml_tensor * llm_graph_context::build_inp_mean() const {
  885. auto inp = std::make_unique<llm_graph_input_mean>(cparams);
  886. auto & cur = inp->mean;
  887. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
  888. ggml_set_input(cur);
  889. res->add_input(std::move(inp));
  890. return cur;
  891. }
  892. ggml_tensor * llm_graph_context::build_inp_cls() const {
  893. auto inp = std::make_unique<llm_graph_input_cls>(cparams);
  894. auto & cur = inp->cls;
  895. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
  896. ggml_set_input(cur);
  897. res->add_input(std::move(inp));
  898. return cur;
  899. }
  900. ggml_tensor * llm_graph_context::build_inp_s_copy() const {
  901. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  902. auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
  903. const auto n_kv = kv_self->n;
  904. auto & cur = inp->s_copy;
  905. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
  906. ggml_set_input(cur);
  907. res->add_input(std::move(inp));
  908. return cur;
  909. }
  910. ggml_tensor * llm_graph_context::build_inp_s_mask() const {
  911. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  912. auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
  913. const auto n_kv = kv_self->n;
  914. auto & cur = inp->s_mask;
  915. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
  916. ggml_set_input(cur);
  917. res->add_input(std::move(inp));
  918. return cur;
  919. }
  920. ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
  921. auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
  922. auto & cur = inp->cross_embd;
  923. // if we have the output embeddings from the encoder, use them directly
  924. // TODO: needs more work to be correct, for now just use the tensor shape
  925. //if (cross->t_embd) {
  926. // cur = ggml_view_tensor(ctx0, cross->t_embd);
  927. // return cur;
  928. //}
  929. const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
  930. const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  931. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
  932. ggml_set_input(cur);
  933. res->add_input(std::move(inp));
  934. return cur;
  935. }
  936. ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
  937. auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
  938. auto & cur = inp->pos_bucket;
  939. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
  940. ggml_set_input(cur);
  941. res->add_input(std::move(inp));
  942. return cur;
  943. }
  944. ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
  945. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  946. auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
  947. const auto n_kv = kv_self->n;
  948. auto & cur = inp->pos_bucket;
  949. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
  950. ggml_set_input(cur);
  951. res->add_input(std::move(inp));
  952. return cur;
  953. }
  954. ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
  955. ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
  956. cb(pos_bucket_1d, "pos_bucket_1d", -1);
  957. ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
  958. pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
  959. pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
  960. pos_bias = ggml_cont (ctx0, pos_bias);
  961. cb(pos_bias, "pos_bias", -1);
  962. return pos_bias;
  963. }
  964. ggml_tensor * llm_graph_context::build_attn_mha(
  965. ggml_cgraph * gf,
  966. ggml_tensor * q,
  967. ggml_tensor * k,
  968. ggml_tensor * v,
  969. ggml_tensor * kq_b,
  970. ggml_tensor * kq_mask,
  971. ggml_tensor * v_mla,
  972. bool v_trans,
  973. float kq_scale) const {
  974. //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  975. //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  976. //const int64_t n_head = hparams.n_head(il);
  977. //const int64_t n_head_kv = hparams.n_head_kv(il);
  978. //const auto & n_embd_head_k = hparams.n_embd_head_k;
  979. //const auto & n_embd_head_v = hparams.n_embd_head_v;
  980. const auto n_tokens = q->ne[1];
  981. const auto n_head = q->ne[2];
  982. const auto n_kv = k->ne[1];
  983. ggml_tensor * cur;
  984. // TODO: replace hardcoded padding with ggml-provided padding
  985. if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
  986. GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
  987. if (v_trans) {
  988. v = ggml_transpose(ctx0, v);
  989. }
  990. // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
  991. if (k->type == GGML_TYPE_F32) {
  992. k = ggml_cast(ctx0, k, GGML_TYPE_F16);
  993. }
  994. if (v->type == GGML_TYPE_F32) {
  995. v = ggml_cast(ctx0, v, GGML_TYPE_F16);
  996. }
  997. cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
  998. hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
  999. ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
  1000. if (v_mla) {
  1001. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
  1002. cur = ggml_mul_mat(ctx0, v_mla, cur);
  1003. }
  1004. cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
  1005. } else {
  1006. ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
  1007. // note: this op tends to require high floating point range
  1008. // while for some models F16 is enough, for others it is not, so we default to F32 here
  1009. ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
  1010. if (arch == LLM_ARCH_GROK) {
  1011. // need to do the following:
  1012. // multiply by attn_output_multiplyer of 0.08838834764831845
  1013. // and then :
  1014. // kq = 30 * tanh(kq / 30)
  1015. // before the softmax below
  1016. kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
  1017. kq = ggml_scale(ctx0, kq, 30);
  1018. }
  1019. if (hparams.attn_soft_cap) {
  1020. kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
  1021. kq = ggml_tanh (ctx0, kq);
  1022. kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
  1023. }
  1024. if (kq_b) {
  1025. kq = ggml_add(ctx0, kq, kq_b);
  1026. }
  1027. kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
  1028. if (!v_trans) {
  1029. // note: avoid this branch
  1030. v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
  1031. }
  1032. ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
  1033. // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
  1034. if (v_mla) {
  1035. kqv = ggml_mul_mat(ctx0, v_mla, kqv);
  1036. }
  1037. cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
  1038. cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
  1039. if (!cparams.offload_kqv) {
  1040. // all nodes between the KV store and the attention output are run on the CPU
  1041. ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
  1042. }
  1043. }
  1044. ggml_build_forward_expand(gf, cur);
  1045. return cur;
  1046. }
  1047. llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
  1048. auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
  1049. // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
  1050. inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1051. //cb(inp_kq_mask, "KQ_mask", -1);
  1052. ggml_set_input(inp->kq_mask);
  1053. inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
  1054. return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
  1055. }
  1056. ggml_tensor * llm_graph_context::build_attn(
  1057. llm_graph_input_attn_no_cache * inp,
  1058. ggml_cgraph * gf,
  1059. ggml_tensor * wo,
  1060. ggml_tensor * wo_b,
  1061. ggml_tensor * q_cur,
  1062. ggml_tensor * k_cur,
  1063. ggml_tensor * v_cur,
  1064. ggml_tensor * kq_b,
  1065. ggml_tensor * v_mla,
  1066. float kq_scale,
  1067. int il) const {
  1068. GGML_UNUSED(n_tokens);
  1069. // these nodes are added to the graph together so that they are not reordered
  1070. // by doing so, the number of splits in the graph is reduced
  1071. ggml_build_forward_expand(gf, q_cur);
  1072. ggml_build_forward_expand(gf, k_cur);
  1073. ggml_build_forward_expand(gf, v_cur);
  1074. const auto & kq_mask = inp->get_kq_mask();
  1075. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1076. //cb(q, "q", il);
  1077. ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
  1078. //cb(k, "k", il);
  1079. ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
  1080. //cb(k, "v", il);
  1081. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
  1082. cb(cur, "kqv_out", il);
  1083. if (wo) {
  1084. cur = build_lora_mm(wo, cur);
  1085. }
  1086. if (wo_b) {
  1087. //cb(cur, "kqv_wo", il);
  1088. }
  1089. if (wo_b) {
  1090. cur = ggml_add(ctx0, cur, wo_b);
  1091. }
  1092. return cur;
  1093. }
  1094. llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
  1095. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1096. auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
  1097. const auto n_kv = kv_self->n;
  1098. inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1099. //cb(inp->self_kq_mask, "KQ_mask", -1);
  1100. ggml_set_input(inp->self_kq_mask);
  1101. inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
  1102. if (hparams.n_swa_pattern > 1) {
  1103. GGML_ASSERT(hparams.n_swa > 0);
  1104. inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1105. //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
  1106. ggml_set_input(inp->self_kq_mask_swa);
  1107. inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
  1108. }
  1109. return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
  1110. }
  1111. ggml_tensor * llm_graph_context::build_attn(
  1112. llm_graph_input_attn_kv_unified * inp,
  1113. ggml_cgraph * gf,
  1114. ggml_tensor * wo,
  1115. ggml_tensor * wo_b,
  1116. ggml_tensor * q_cur,
  1117. ggml_tensor * k_cur,
  1118. ggml_tensor * v_cur,
  1119. ggml_tensor * kq_b,
  1120. ggml_tensor * v_mla,
  1121. float kq_scale,
  1122. int il) const {
  1123. // these nodes are added to the graph together so that they are not reordered
  1124. // by doing so, the number of splits in the graph is reduced
  1125. ggml_build_forward_expand(gf, q_cur);
  1126. ggml_build_forward_expand(gf, k_cur);
  1127. ggml_build_forward_expand(gf, v_cur);
  1128. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1129. const auto & n_ctx = cparams.n_ctx;
  1130. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1131. const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1132. const auto n_tokens = q_cur->ne[2];
  1133. const bool v_trans = !cparams.flash_attn;
  1134. // store to KV cache
  1135. {
  1136. GGML_ASSERT(!kv_self->recurrent);
  1137. const auto kv_head = kv_self->head;
  1138. GGML_ASSERT(kv_self->size == n_ctx);
  1139. ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
  1140. //cb(k_cache_view, "k_cache_view", il);
  1141. // note: storing RoPE-ed version of K in the KV cache
  1142. ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
  1143. v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
  1144. ggml_tensor * v_cache_view = nullptr;
  1145. if (!v_trans) {
  1146. v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
  1147. } else {
  1148. // note: the V cache is transposed when not using flash attention
  1149. v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
  1150. ( n_ctx)*ggml_element_size(kv_self->v_l[il]),
  1151. (kv_head)*ggml_element_size(kv_self->v_l[il]));
  1152. v_cur = ggml_transpose(ctx0, v_cur);
  1153. }
  1154. //cb(v_cache_view, "v_cache_view", il);
  1155. ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
  1156. }
  1157. const bool is_swa = hparams.is_swa(il);
  1158. const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
  1159. const auto n_kv = kv_self->n;
  1160. const int64_t n_head_kv = hparams.n_head_kv(il);
  1161. const auto & n_embd_head_k = hparams.n_embd_head_k;
  1162. const auto & n_embd_head_v = hparams.n_embd_head_v;
  1163. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1164. //cb(q, "q", il);
  1165. ggml_tensor * k =
  1166. ggml_view_3d(ctx0, kv_self->k_l[il],
  1167. n_embd_head_k, n_kv, n_head_kv,
  1168. ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
  1169. ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
  1170. 0);
  1171. //cb(k, "k", il);
  1172. ggml_tensor * v = !v_trans ?
  1173. ggml_view_3d(ctx0, kv_self->v_l[il],
  1174. n_embd_head_v, n_kv, n_head_kv,
  1175. ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
  1176. ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
  1177. 0) :
  1178. ggml_view_3d(ctx0, kv_self->v_l[il],
  1179. n_kv, n_embd_head_v, n_head_kv,
  1180. ggml_element_size(kv_self->v_l[il])*n_ctx,
  1181. ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
  1182. 0);
  1183. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
  1184. cb(cur, "kqv_out", il);
  1185. if (wo) {
  1186. cur = build_lora_mm(wo, cur);
  1187. }
  1188. if (wo_b) {
  1189. //cb(cur, "kqv_wo", il);
  1190. }
  1191. if (wo_b) {
  1192. cur = ggml_add(ctx0, cur, wo_b);
  1193. }
  1194. return cur;
  1195. }
  1196. llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
  1197. auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
  1198. const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  1199. inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1200. ggml_set_input(inp->cross_kq_mask);
  1201. inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
  1202. return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
  1203. }
  1204. ggml_tensor * llm_graph_context::build_attn(
  1205. llm_graph_input_attn_cross * inp,
  1206. ggml_cgraph * gf,
  1207. ggml_tensor * wo,
  1208. ggml_tensor * wo_b,
  1209. ggml_tensor * q_cur,
  1210. ggml_tensor * k_cur,
  1211. ggml_tensor * v_cur,
  1212. ggml_tensor * kq_b,
  1213. ggml_tensor * v_mla,
  1214. float kq_scale,
  1215. int il) const {
  1216. // these nodes are added to the graph together so that they are not reordered
  1217. // by doing so, the number of splits in the graph is reduced
  1218. ggml_build_forward_expand(gf, q_cur);
  1219. ggml_build_forward_expand(gf, k_cur);
  1220. ggml_build_forward_expand(gf, v_cur);
  1221. const auto & kq_mask = inp->get_kq_mask_cross();
  1222. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1223. //cb(q, "q", il);
  1224. ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
  1225. //cb(k, "k", il);
  1226. ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
  1227. //cb(k, "v", il);
  1228. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
  1229. cb(cur, "kqv_out", il);
  1230. if (wo) {
  1231. cur = build_lora_mm(wo, cur);
  1232. }
  1233. if (wo_b) {
  1234. //cb(cur, "kqv_wo", il);
  1235. }
  1236. if (wo_b) {
  1237. cur = ggml_add(ctx0, cur, wo_b);
  1238. }
  1239. return cur;
  1240. }
  1241. ggml_tensor * llm_graph_context::build_copy_mask_state(
  1242. ggml_cgraph * gf,
  1243. ggml_tensor * s,
  1244. ggml_tensor * state_copy,
  1245. ggml_tensor * state_mask,
  1246. int32_t n_state,
  1247. int32_t n_seqs) const {
  1248. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1249. const auto n_kv = kv_self->n;
  1250. const auto kv_head = kv_self->head;
  1251. ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
  1252. // copy states
  1253. // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
  1254. // this shrinks the tensors's ne[1] to n_kv
  1255. states = ggml_get_rows(ctx0, states, state_copy);
  1256. // clear states of sequences which are starting at the beginning of this batch
  1257. // FIXME: zero-out NANs?
  1258. states = ggml_mul(ctx0, states, state_mask);
  1259. // copy states which won't be changed further (between n_seqs and n_kv)
  1260. ggml_build_forward_expand(gf,
  1261. ggml_cpy(ctx0,
  1262. ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
  1263. ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
  1264. // the part of the states that will be used and modified
  1265. return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
  1266. }
  1267. ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
  1268. ggml_cgraph * gf,
  1269. ggml_tensor * state_copy,
  1270. ggml_tensor * state_mask,
  1271. const llama_ubatch & ubatch,
  1272. int il) const {
  1273. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1274. const auto token_shift_count = hparams.token_shift_count;
  1275. const int64_t n_seqs = ubatch.n_seqs;
  1276. ggml_tensor * token_shift_all = kv_self->k_l[il];
  1277. ggml_tensor * token_shift = build_copy_mask_state(
  1278. gf, token_shift_all, state_copy, state_mask,
  1279. hparams.n_embd_k_s(), n_seqs);
  1280. token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
  1281. return token_shift;
  1282. }
  1283. ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
  1284. ggml_tensor * token_shift,
  1285. const llama_ubatch & ubatch,
  1286. int il) const {
  1287. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1288. const auto token_shift_count = hparams.token_shift_count;
  1289. const auto n_embd = hparams.n_embd;
  1290. const int64_t n_seqs = ubatch.n_seqs;
  1291. const auto kv_head = kv_self->head;
  1292. return ggml_cpy(
  1293. ctx0,
  1294. ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
  1295. ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
  1296. );
  1297. }
  1298. void llm_graph_context::build_pooling(
  1299. ggml_cgraph * gf,
  1300. ggml_tensor * cls,
  1301. ggml_tensor * cls_b,
  1302. ggml_tensor * cls_out,
  1303. ggml_tensor * cls_out_b) const {
  1304. if (!cparams.embeddings) {
  1305. return;
  1306. }
  1307. ggml_tensor * inp = res->t_embd;
  1308. //// find result_norm tensor for input
  1309. //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
  1310. // inp = ggml_graph_node(gf, i);
  1311. // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
  1312. // break;
  1313. // }
  1314. // inp = nullptr;
  1315. //}
  1316. GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
  1317. ggml_tensor * cur;
  1318. switch (pooling_type) {
  1319. case LLAMA_POOLING_TYPE_NONE:
  1320. {
  1321. cur = inp;
  1322. } break;
  1323. case LLAMA_POOLING_TYPE_MEAN:
  1324. {
  1325. ggml_tensor * inp_mean = build_inp_mean();
  1326. cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
  1327. } break;
  1328. case LLAMA_POOLING_TYPE_CLS:
  1329. case LLAMA_POOLING_TYPE_LAST:
  1330. {
  1331. ggml_tensor * inp_cls = build_inp_cls();
  1332. cur = ggml_get_rows(ctx0, inp, inp_cls);
  1333. } break;
  1334. case LLAMA_POOLING_TYPE_RANK:
  1335. {
  1336. ggml_tensor * inp_cls = build_inp_cls();
  1337. inp = ggml_get_rows(ctx0, inp, inp_cls);
  1338. // classification head
  1339. // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
  1340. GGML_ASSERT(cls != nullptr);
  1341. GGML_ASSERT(cls_b != nullptr);
  1342. cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
  1343. cur = ggml_tanh(ctx0, cur);
  1344. // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
  1345. // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
  1346. if (cls_out) {
  1347. GGML_ASSERT(cls_out_b != nullptr);
  1348. cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
  1349. }
  1350. } break;
  1351. default:
  1352. {
  1353. GGML_ABORT("unknown pooling type");
  1354. }
  1355. }
  1356. cb(cur, "result_embd_pooled", -1);
  1357. res->t_embd_pooled = cur;
  1358. ggml_build_forward_expand(gf, cur);
  1359. }