llama-graph.cpp

#include "llama-graph.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-recurrent.h"

#include <cassert>
#include <cmath>
#include <cstring>

void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
    if (ubatch->token) {
        const int64_t n_tokens = ubatch->n_tokens;

        ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
    }

    if (ubatch->embd) {
        const int64_t n_embd   = embd->ne[0];
        const int64_t n_tokens = ubatch->n_tokens;

        ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
    }
}

bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
    bool res = true;

    res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
    res &= (!embd   && !params.ubatch.embd)  || (embd   &&   embd->ne[1] == params.ubatch.n_tokens); // embd is [n_embd, n_tokens]

    return res;
}

void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
    if (ubatch->pos && pos) {
        const int64_t n_tokens = ubatch->n_tokens;

        if (ubatch->token && n_pos_per_embd == 4) {
            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
            // the first 3 dims are the same, and the 4th dim is all 0
            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
            // copy the first dimension
            for (int i = 0; i < n_tokens; ++i) {
                pos_data[               i] = ubatch->pos[i];
                pos_data[    n_tokens + i] = ubatch->pos[i];
                pos_data[2 * n_tokens + i] = ubatch->pos[i];
                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
            }
            ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
        } else {
            ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
        }
    }
}

bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
    bool res = true;

    res &= pos->ne[0] == params.ubatch.n_tokens;

    return res;
}
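// precompute the per-token attention temperature scale used by models with
// attention temperature tuning (e.g. Llama 4):
//
//   scale(pos) = log(floor((pos + 1) / n_attn_temp_floor_scale) + 1) * f_attn_temp_scale + 1
//
// the scale grows slowly with the token position, counteracting attention
// entropy growth at long context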
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
    if (ubatch->pos && attn_scale) {
        const int64_t n_tokens = ubatch->n_tokens;

        std::vector<float> attn_scale_data(n_tokens, 0.0f);
        for (int i = 0; i < n_tokens; ++i) {
            const float pos = ubatch->pos[i];
            attn_scale_data[i] = std::log(
                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
            ) * f_attn_temp_scale + 1.0;
        }

        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
    }
}

void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
    if (pos_bucket) {
        const int64_t n_tokens = ubatch->n_tokens;

        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
        GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

        int32_t * data = (int32_t *) pos_bucket->data;

        for (int h = 0; h < 1; ++h) {
            for (int j = 0; j < n_tokens; ++j) {
                for (int i = 0; i < n_tokens; ++i) {
                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
                }
            }
        }
    }
}

void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
    if (pos_bucket) {
        mctx->set_input_pos_bucket(pos_bucket, ubatch);
    }
}

void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
    GGML_ASSERT(out_ids);

    const int64_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
    int32_t * data = (int32_t *) out_ids->data;

    if (n_outputs == n_tokens) {
        for (int i = 0; i < n_tokens; ++i) {
            data[i] = i;
        }

        return;
    }

    GGML_ASSERT(ubatch->output);

    int n_outputs = 0;

    for (int i = 0; i < n_tokens; ++i) {
        if (ubatch->output[i]) {
            data[n_outputs++] = i;
        }
    }
}

bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
    bool res = true;

    res &= n_outputs == params.n_outputs;

    return res;
}
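// build the [n_tokens, n_seqs_unq] mean-pooling matrix: row s holds 1/len(s)
// at the columns of the tokens that belong to sequence s, so multiplying it
// with the token embeddings yields the per-sequence mean embedding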
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
        const int64_t n_tokens     = ubatch->n_tokens;
        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
        const int64_t n_seqs_unq   = ubatch->n_seqs_unq;

        GGML_ASSERT(mean);
        GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));

        float * data = (float *) mean->data;
        memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));

        std::vector<uint64_t> sums(n_seqs_unq, 0);
        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];

                sums[seq_idx] += ubatch->n_seq_tokens;
            }
        }

        std::vector<float> div(n_seqs_unq, 0.0f);
        for (int s = 0; s < n_seqs_unq; ++s) {
            const uint64_t sum = sums[s];
            if (sum > 0) {
                div[s] = 1.0f/float(sum);
            }
        }

        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];

                for (int j = 0; j < n_seq_tokens; ++j) {
                    data[seq_idx*n_tokens + i + j] = div[seq_idx];
                }
            }
        }
    }
}
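// select one representative token per sequence: the token with the smallest
// position for CLS-style pooling, or the one with the largest position for
// LAST pooling (and for rerankers that pool on the last token)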
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
    const int64_t n_tokens   = ubatch->n_tokens;
    const int64_t n_seqs_unq = ubatch->n_seqs_unq;

    if (cparams.embeddings && (
                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
                cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
                )) {
        GGML_ASSERT(cls);
        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

        uint32_t * data = (uint32_t *) cls->data;
        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));

        std::vector<int> target_pos(n_seqs_unq, -1);
        std::vector<int> target_row(n_seqs_unq, -1);

        const bool last = (
            cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
            (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use the last token
        );

        for (int i = 0; i < n_tokens; ++i) {
            const llama_pos pos = ubatch->pos[i];

            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];

                if (
                    (target_pos[seq_idx] == -1) ||
                    ( last && pos >= target_pos[seq_idx]) ||
                    (!last && pos <  target_pos[seq_idx])
                ) {
                    target_pos[seq_idx] = pos;
                    target_row[seq_idx] = i;
                }
            }
        }

        for (int s = 0; s < n_seqs_unq; ++s) {
            if (target_row[s] >= 0) {
                data[s] = target_row[s];
            }
        }
    }
}

void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    const int64_t n_rs = mctx->get_n_rs();

    if (s_copy) {
        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
        int32_t * data = (int32_t *) s_copy->data;

        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
        for (uint32_t i = 0; i < n_rs; ++i) {
            data[i] = mctx->s_copy(i);
        }
    }
}

void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    if (cross_embd && !cross->v_embd.empty()) {
        assert(cross_embd->type == GGML_TYPE_F32);

        ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
    }
}

static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);

    const char * swa_type_str = "unknown";
    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
    }

    LLAMA_LOG_DEBUG("%s: n_swa: %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);

    LLAMA_LOG_DEBUG("    ");
    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
        LLAMA_LOG_DEBUG("%2d", j);
    }
    LLAMA_LOG_DEBUG("\n");

    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
        LLAMA_LOG_DEBUG(" %2d ", i);
        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
            float val = data[i * n_kv + j];
            if (val == -INFINITY) {
                LLAMA_LOG_DEBUG(" ∞");
            } else {
                LLAMA_LOG_DEBUG(" 0");
            }
        }
        LLAMA_LOG_DEBUG("\n");
    }
}
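// build the full [n_tokens, n_tokens] attention mask for the no-cache path:
// everything starts masked (-INF) and a pair (i0, i1) is unmasked only when
// both tokens belong to the same sequence, causality is respected (when
// enabled), and the pair falls inside the sliding window (when SWA is enabled)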
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
    const int64_t n_kv     = ubatch->n_tokens;
    const int64_t n_tokens = ubatch->n_tokens;

    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
        for (int h = 0; h < 1; ++h) {
            for (int i1 = 0; i1 < n_tokens; ++i1) {
                const llama_seq_id s1 = ubatch->seq_id[i1][0];
                const llama_pos    p1 = ubatch->pos[i1];

                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;

                for (int i0 = 0; i0 < n_tokens; ++i0) {
                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
                    const llama_pos    p0 = ubatch->pos[i0];

                    // mask different sequences
                    if (s0 != s1) {
                        continue;
                    }

                    // mask future tokens
                    if (cparams.causal_attn && p0 > p1) {
                        continue;
                    }

                    // apply SWA if any
                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
                        continue;
                    }

                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                }
            }
        }
    };

    {
        GGML_ASSERT(self_kq_mask);
        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));

        float * data = (float *) self_kq_mask->data;

        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);

        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);

        if (debug) {
            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
        }
    }

    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
        GGML_ASSERT(self_kq_mask_swa);
        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));

        float * data = (float *) self_kq_mask_swa->data;

        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);

        fill_mask(data, hparams.n_swa, hparams.swa_type);

        if (debug) {
            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
        }
    }
}
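// the KV-cache-backed attention inputs delegate to the memory context, which
// owns the cache cell layout: it fills the K/V row indices used to scatter the
// new keys/values into the cache, and the KQ mask over the cached positions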
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
    mctx->set_input_k_idxs(self_k_idxs, ubatch);
    mctx->set_input_v_idxs(self_v_idxs, ubatch);

    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);

    return res;
}

void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);

    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);

    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);

    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
  //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);

    res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
    res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);

    return res;
}

void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
    GGML_ASSERT(cross_kq_mask);

    const int64_t n_enc    = cross_kq_mask->ne[0];
    const int64_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

    float * data = (float *) cross_kq_mask->data;

    for (int h = 0; h < 1; ++h) {
        for (int i = 0; i < n_tokens; ++i) {
            for (int j = 0; j < n_enc; ++j) {
                float f = -INFINITY;

                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                    const llama_seq_id seq_id = ubatch->seq_id[i][s];

                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
                        f = 0.0f;
                    }
                }

                data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
            }
        }

        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
            for (int j = 0; j < n_enc; ++j) {
                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
            }
        }
    }
}

void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
    inp_attn->set_input(ubatch);
    inp_rs->set_input(ubatch);
}

//
// llm_graph_result
//

llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
    reset();

    const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
    debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
}

int64_t llm_graph_result::get_max_nodes() const {
    return max_nodes;
}

void llm_graph_result::reset() {
    t_tokens      = nullptr;
    t_logits      = nullptr;
    t_embd        = nullptr;
    t_embd_pooled = nullptr;

    params = {};

    inputs.clear();

    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));

    ggml_init_params params = {
        /*.mem_size   =*/ buf_compute_meta.size(),
        /*.mem_buffer =*/ buf_compute_meta.data(),
        /*.no_alloc   =*/ true,
    };

    ctx_compute.reset(ggml_init(params));

    gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
}

void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
    for (auto & input : inputs) {
        input->set_input(ubatch);
    }
}

bool llm_graph_result::can_reuse(const llm_graph_params & params) {
    if (!this->params.allow_reuse(params)) {
        if (debug > 1) {
            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
        }

        return false;
    }

    if (debug > 1) {
        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
    }

    bool res = true;

    for (auto & input : inputs) {
        const bool cur = input->can_reuse(params);
        if (debug > 1) {
            LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", __func__, cur);
        }
        res = res && cur;
    }

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
    }

    return res;
}

llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
    inputs.emplace_back(std::move(input));
    return inputs.back().get();
}

void llm_graph_result::set_params(const llm_graph_params & params) {
    this->params = params;
}

//
// llm_graph_context
//

llm_graph_context::llm_graph_context(const llm_graph_params & params) :
    arch          (params.arch),
    hparams       (params.hparams),
    cparams       (params.cparams),
    ubatch        (params.ubatch),
    n_embd        (hparams.n_embd),
    n_layer       (hparams.n_layer),
    n_rot         (hparams.n_rot),
    n_ctx         (cparams.n_ctx),
    n_head        (hparams.n_head()),
    n_head_kv     (hparams.n_head_kv()),
    n_embd_head_k (hparams.n_embd_head_k),
    n_embd_k_gqa  (hparams.n_embd_k_gqa()),
    n_embd_head_v (hparams.n_embd_head_v),
    n_embd_v_gqa  (hparams.n_embd_v_gqa()),
    n_expert      (hparams.n_expert),
    n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
    freq_base     (cparams.rope_freq_base),
    freq_scale    (cparams.rope_freq_scale),
    ext_factor    (cparams.yarn_ext_factor),
    attn_factor   (cparams.yarn_attn_factor),
    beta_fast     (cparams.yarn_beta_fast),
    beta_slow     (cparams.yarn_beta_slow),
    norm_eps      (hparams.f_norm_eps),
    norm_rms_eps  (hparams.f_norm_rms_eps),
    n_tokens      (ubatch.n_tokens),
    n_outputs     (params.n_outputs),
    n_ctx_orig    (cparams.n_ctx_orig_yarn),
    pooling_type  (cparams.pooling_type),
    rope_type     (hparams.rope_type),
    sched         (params.sched),
    backend_cpu   (params.backend_cpu),
    cvec          (params.cvec),
    loras         (params.loras),
    mctx          (params.mctx),
    cross         (params.cross),
    cb_func       (params.cb),
    res           (params.res),
    ctx0          (res->get_ctx()),
    gf            (res->get_gf()) {
    res->set_params(params);
}

void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
    if (cb_func) {
        cb_func(ubatch, cur, name, il);
    }
}

ggml_tensor * llm_graph_context::build_cvec(
         ggml_tensor * cur,
                 int   il) const {
    return cvec->apply_to(ctx0, cur, il);
}
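// dense matmul with optional LoRA deltas: for each active adapter that has a
// weight for w, this computes res = W*x + scale * (B*(A*x)), where scale folds
// in alpha/rank and the user-provided adapter strength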
ggml_tensor * llm_graph_context::build_lora_mm(
          ggml_tensor * w,
          ggml_tensor * cur) const {
    ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);

    for (const auto & lora : *loras) {
        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
        if (lw == nullptr) {
            continue;
        }

        const float adapter_scale = lora.second;
        const float scale = lw->get_scale(lora.first->alpha, adapter_scale);

        ggml_tensor * ab_cur = ggml_mul_mat(
                ctx0, lw->b,
                ggml_mul_mat(ctx0, lw->a, cur)
                );

        ab_cur = ggml_scale(ctx0, ab_cur, scale);
        res = ggml_add(ctx0, res, ab_cur);
    }

    return res;
}

ggml_tensor * llm_graph_context::build_lora_mm_id(
          ggml_tensor * w,   // ggml_tensor * as
          ggml_tensor * cur, // ggml_tensor * b
          ggml_tensor * ids) const {
    ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);

    for (const auto & lora : *loras) {
        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
        if (lw == nullptr) {
            continue;
        }

        const float alpha = lora.first->alpha;
        const float rank  = (float) lw->b->ne[0];
        const float scale = alpha ? lora.second * alpha / rank : lora.second;

        ggml_tensor * ab_cur = ggml_mul_mat_id(
                ctx0, lw->b,
                ggml_mul_mat_id(ctx0, lw->a, cur, ids),
                ids
                );

        ab_cur = ggml_scale(ctx0, ab_cur, scale);
        res = ggml_add(ctx0, res, ab_cur);
    }

    return res;
}

ggml_tensor * llm_graph_context::build_norm(
         ggml_tensor * cur,
         ggml_tensor * mw,
         ggml_tensor * mb,
       llm_norm_type   type,
                 int   il) const {
    switch (type) {
        case LLM_NORM:     cur = ggml_norm    (ctx0, cur, hparams.f_norm_eps);     break;
        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
        case LLM_NORM_GROUP:
            {
                cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
                cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
                cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
            } break;
    }

    if (mw || mb) {
        cb(cur, "norm", il);
    }

    if (mw) {
        cur = ggml_mul(ctx0, cur, mw);
        if (mb) {
            cb(cur, "norm_w", il);
        }
    }

    if (mb) {
        cur = ggml_add(ctx0, cur, mb);
    }

    return cur;
}
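// generic FFN block: up projection (+ optional bias/scale), optional gate
// (parallel or sequential), activation, then down projection (+ optional
// bias/scale); gated variants use the fused *_split ops so the gate
// multiplication can be fused with the activation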
ggml_tensor * llm_graph_context::build_ffn(
         ggml_tensor * cur,
         ggml_tensor * up,
         ggml_tensor * up_b,
         ggml_tensor * up_s,
         ggml_tensor * gate,
         ggml_tensor * gate_b,
         ggml_tensor * gate_s,
         ggml_tensor * down,
         ggml_tensor * down_b,
         ggml_tensor * down_s,
         ggml_tensor * act_scales,
     llm_ffn_op_type   type_op,
   llm_ffn_gate_type   type_gate,
                 int   il) const {
    ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
    cb(tmp, "ffn_up", il);

    if (up_b) {
        tmp = ggml_add(ctx0, tmp, up_b);
        cb(tmp, "ffn_up_b", il);
    }

    if (up_s) {
        tmp = ggml_mul(ctx0, tmp, up_s);
        cb(tmp, "ffn_up_s", il);
    }

    if (gate) {
        switch (type_gate) {
            case LLM_FFN_SEQ:
                {
                    cur = build_lora_mm(gate, tmp);
                    cb(cur, "ffn_gate", il);
                } break;
            case LLM_FFN_PAR:
                {
                    cur = build_lora_mm(gate, cur);
                    cb(cur, "ffn_gate", il);
                } break;
        }

        if (gate_b) {
            cur = ggml_add(ctx0, cur, gate_b);
            cb(cur, "ffn_gate_b", il);
        }

        if (gate_s) {
            cur = ggml_mul(ctx0, cur, gate_s);
            cb(cur, "ffn_gate_s", il);
        }
    } else {
        cur = tmp;
    }

    switch (type_op) {
        case LLM_FFN_SILU:
            if (gate && type_gate == LLM_FFN_PAR) {
                cur = ggml_swiglu_split(ctx0, cur, tmp);
                cb(cur, "ffn_swiglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_silu(ctx0, cur);
                cb(cur, "ffn_silu", il);
            } break;
        case LLM_FFN_GELU:
            if (gate && type_gate == LLM_FFN_PAR) {
                cur = ggml_geglu_split(ctx0, cur, tmp);
                cb(cur, "ffn_geglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_gelu(ctx0, cur);
                cb(cur, "ffn_gelu", il);
                if (act_scales != NULL) {
                    cur = ggml_div(ctx0, cur, act_scales);
                    cb(cur, "ffn_act", il);
                }
            } break;
        case LLM_FFN_RELU:
            if (gate && type_gate == LLM_FFN_PAR) {
                cur = ggml_reglu_split(ctx0, cur, tmp);
                cb(cur, "ffn_reglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);
            } break;
        case LLM_FFN_RELU_SQR:
            {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);

                cur = ggml_sqr(ctx0, cur);
                cb(cur, "ffn_sqr(relu)", il);
            } break;
        case LLM_FFN_SWIGLU:
            {
                cur = ggml_swiglu(ctx0, cur);
                cb(cur, "ffn_swiglu", il);
            } break;
        case LLM_FFN_GEGLU:
            {
                cur = ggml_geglu(ctx0, cur);
                cb(cur, "ffn_geglu", il);
            } break;
        case LLM_FFN_REGLU:
            {
                cur = ggml_reglu(ctx0, cur);
                cb(cur, "ffn_reglu", il);
            } break;
        default:
            GGML_ABORT("fatal error");
    }

    // expand here so that we can fuse the ffn gate
    ggml_build_forward_expand(gf, cur);

    if (gate && type_gate == LLM_FFN_PAR) {
        cur = ggml_mul(ctx0, cur, tmp);
        cb(cur, "ffn_gate_par", il);
    }

    if (down) {
        cur = build_lora_mm(down, cur);
        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
        }
    }

    if (down_b) {
        cb(cur, "ffn_down", il);
        cur = ggml_add(ctx0, cur, down_b);
    }

    if (down_s) {
        cur = ggml_mul(ctx0, cur, down_s);
        cb(cur, "ffn_down_s", il);
    }

    return cur;
}
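// MoE FFN: compute the router logits, turn them into selection probabilities
// (softmax or sigmoid, with per-architecture exceptions below), pick the top-k
// experts per token, run the expert FFNs via the *_id matmuls, and combine the
// expert outputs weighted by the (optionally normalized/scaled) router weights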
ggml_tensor * llm_graph_context::build_moe_ffn(
         ggml_tensor * cur,
         ggml_tensor * gate_inp,
         ggml_tensor * up_exps,
         ggml_tensor * gate_exps,
         ggml_tensor * down_exps,
         ggml_tensor * exp_probs_b,
             int64_t   n_expert,
             int64_t   n_expert_used,
     llm_ffn_op_type   type_op,
                bool   norm_w,
                bool   scale_w,
               float   w_scale,
llama_expert_gating_func_type gating_op,
                 int   il,
         ggml_tensor * probs_in) const {
    return build_moe_ffn(
            cur,
            gate_inp,  /* gate_inp_b  */ nullptr,
            up_exps,   /* up_exps_b   */ nullptr,
            gate_exps, /* gate_exps_b */ nullptr,
            down_exps, /* down_exps_b */ nullptr,
            exp_probs_b,
            n_expert,
            n_expert_used,
            type_op,
            norm_w,
            scale_w,
            w_scale,
            gating_op,
            il,
            probs_in
            );
}

ggml_tensor * llm_graph_context::build_moe_ffn(
         ggml_tensor * cur,
         ggml_tensor * gate_inp,
         ggml_tensor * gate_inp_b,
         ggml_tensor * up_exps,
         ggml_tensor * up_exps_b,
         ggml_tensor * gate_exps,
         ggml_tensor * gate_exps_b,
         ggml_tensor * down_exps,
         ggml_tensor * down_exps_b,
         ggml_tensor * exp_probs_b,
             int64_t   n_expert,
             int64_t   n_expert_used,
     llm_ffn_op_type   type_op,
                bool   norm_w,
                bool   scale_w,
               float   w_scale,
llama_expert_gating_func_type gating_op,
                 int   il,
         ggml_tensor * probs_in) const {
    const int64_t n_embd   = cur->ne[0];
    const int64_t n_tokens = cur->ne[1];

    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN

    ggml_tensor * logits = nullptr;

    if (probs_in == nullptr) {
        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
        cb(logits, "ffn_moe_logits", il);
    } else {
        logits = probs_in;
    }

    if (gate_inp_b) {
        logits = ggml_add(ctx0, logits, gate_inp_b);
        cb(logits, "ffn_moe_logits_biased", il);
    }

    ggml_tensor * probs = nullptr;
    switch (gating_op) {
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
            {
                probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
            } break;
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
            {
                probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
            } break;
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
            {
                probs = logits; // [n_expert, n_tokens]
            } break;
        default:
            GGML_ABORT("fatal error");
    }
    cb(probs, "ffn_moe_probs", il);

    // add experts selection bias - introduced in DeepSeek V3
    // leave probs unbiased as it's later used to get expert weights
    ggml_tensor * selection_probs = probs;
    if (exp_probs_b != nullptr) {
        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
        cb(selection_probs, "ffn_moe_probs_biased", il);
    }

    // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
    // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
    if (arch == LLM_ARCH_LLAMA4) {
        selection_probs = logits;
    }

    if (arch == LLM_ARCH_GROVEMOE) {
        selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
        cb(selection_probs, "ffn_moe_probs_biased", il);
    }

    // select top n_group_used expert groups
    // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
    if (hparams.n_expert_groups > 1 && n_tokens > 0) {
        const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;

        // organize experts into n_expert_groups
        ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]

        ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
        group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]

        // get top n_group_used expert groups
        group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
        group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]

        ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
        cb(expert_groups, "ffn_moe_group_topk", il);

        // mask out the other groups
        selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
        selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
        selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
        cb(selection_probs, "ffn_moe_probs_masked", il);
    }

    // select experts
    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
    cb(selected_experts->src[0], "ffn_moe_argsort", il);
    cb(selected_experts, "ffn_moe_topk", il);

    if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
        // TODO: Use scalar div instead when/if implemented
        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);

        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
    } else {
        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
    }

    ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
    cb(weights, "ffn_moe_weights", il);

    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
        weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
        cb(weights, "ffn_moe_weights_softmax", il);
    }

    if (norm_w) {
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);

        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
        cb(weights_sum, "ffn_moe_weights_sum", il);

        // Avoid division by zero, clamp to smallest number representable by F16
        weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
        cb(weights_sum, "ffn_moe_weights_sum_clamped", il);

        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
        cb(weights, "ffn_moe_weights_norm", il);

        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
    }

    if (scale_w) {
        weights = ggml_scale(ctx0, weights, w_scale);
        cb(weights, "ffn_moe_weights_scaled", il);
    }

    // call early so that topk-moe can be used
    ggml_build_forward_expand(gf, weights);

    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

    if (weight_before_ffn) {
        // repeat cur to [n_embd, n_expert_used, n_tokens]
        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
        cur = ggml_mul(ctx0, repeated, weights);
        cb(cur, "ffn_moe_weighted", il);
    }

    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
    cb(up, "ffn_moe_up", il);

    if (up_exps_b) {
        up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
        cb(up, "ffn_moe_up_biased", il);
    }

    ggml_tensor * experts = nullptr;
    if (gate_exps) {
        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
        cb(cur, "ffn_moe_gate", il);
    } else {
        cur = up;
    }

    if (gate_exps_b) {
        cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
        cb(cur, "ffn_moe_gate_biased", il);
    }

    switch (type_op) {
        case LLM_FFN_SILU:
            if (gate_exps) {
                cur = ggml_swiglu_split(ctx0, cur, up);
                cb(cur, "ffn_moe_swiglu", il);
            } else {
                cur = ggml_silu(ctx0, cur);
                cb(cur, "ffn_moe_silu", il);
            } break;
        case LLM_FFN_GELU:
            if (gate_exps) {
                cur = ggml_geglu_split(ctx0, cur, up);
                cb(cur, "ffn_moe_geglu", il);
            } else {
                cur = ggml_gelu(ctx0, cur);
                cb(cur, "ffn_moe_gelu", il);
            } break;
        case LLM_FFN_SWIGLU_OAI_MOE:
            {
                // TODO: move to hparams?
                constexpr float alpha = 1.702f;
                constexpr float limit = 7.0f;
                cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
                cb(cur, "ffn_moe_swiglu_oai", il);
            } break;
        case LLM_FFN_RELU:
            if (gate_exps) {
                cur = ggml_reglu_split(ctx0, cur, up);
                cb(cur, "ffn_moe_reglu", il);
            } else {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_moe_relu", il);
            } break;
        default:
            GGML_ABORT("fatal error");
    }

    // expand here so that we can fuse the ffn gate
    ggml_build_forward_expand(gf, cur);

    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
    cb(experts, "ffn_moe_down", il);

    if (down_exps_b) {
        experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
        cb(experts, "ffn_moe_down_biased", il);
    }

    if (!weight_before_ffn) {
        experts = ggml_mul(ctx0, experts, weights);
        cb(experts, "ffn_moe_weighted", il);
    }

    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };

    assert(n_expert_used > 0);

    // order the views before the adds
    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);

        ggml_build_forward_expand(gf, cur_experts[i]);
    }

    // aggregate experts
    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
    //       to avoid potentially a large number of add nodes during warmup
    //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
    ggml_tensor * moe_out = cur_experts[0];

    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
    }

    if (hparams.n_expert_used == 1) {
        // avoid returning a non-contiguous tensor
        moe_out = ggml_cont(ctx0, moe_out);
    }

    cb(moe_out, "ffn_moe_out", il);

    return moe_out;
}
// input embeddings with optional lora
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
    const int64_t n_embd = hparams.n_embd_inp();

    auto inp = std::make_unique<llm_graph_input_embd>();

    ggml_tensor * cur = nullptr;

    if (ubatch.token) {
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        //cb(inp->tokens, "inp_tokens", -1);
        ggml_set_input(inp->tokens);
        res->t_tokens = inp->tokens;

        cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);

        // apply lora for embedding tokens if needed
        for (const auto & lora : *loras) {
            llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
            if (lw == nullptr) {
                continue;
            }

            const float adapter_scale = lora.second;
            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);

            ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
                        ctx0, lw->b, // non-transposed lora_b
                        ggml_get_rows(ctx0, lw->a, inp->tokens)
                        ), scale);

            cur = ggml_add(ctx0, cur, inpL_delta);
        }
    } else {
        inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
        ggml_set_input(inp->embd);

        cur = inp->embd;
    }

    // For Granite architecture
    if (hparams.f_embedding_scale != 0.0f) {
        cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
    }

    cb(cur, "inp_embd", -1);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_pos() const {
    auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());

    auto & cur = inp->pos;

    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);

    auto & cur = inp->attn_scale;

    // this needs to be 1x1xN for broadcasting
    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_out_ids() const {
    // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
    //       but this would make the graph topology depend on the number of output tokens, which can interfere with
    //       features that require constant topology such as pipeline parallelism
    //       ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
    //if (n_outputs < n_tokens) {
    //    return nullptr;
    //}

    auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);

    auto & cur = inp->out_ids;

    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_mean() const {
    auto inp = std::make_unique<llm_graph_input_mean>(cparams);

    auto & cur = inp->mean;

    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_cls() const {
    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);

    auto & cur = inp->cls;

    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
    auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);

    auto & cur = inp->cross_embd;

    // if we have the output embeddings from the encoder, use them directly
    // TODO: needs more work to be correct, for now just use the tensor shape
    //if (cross->t_embd) {
    //    cur = ggml_view_tensor(ctx0, cross->t_embd);
    //    return cur;
    //}

    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
    const auto n_enc  = !cross->v_embd.empty() ? cross->n_enc  : hparams.n_ctx_train;

    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
    auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);

    auto & cur = inp->pos_bucket;

    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);

    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);

    const auto n_kv = mctx_cur->get_n_kv();

    auto & cur = inp->pos_bucket;

    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
    ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
    cb(pos_bucket_1d, "pos_bucket_1d", -1);

    ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);

    pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
    pos_bias = ggml_permute   (ctx0, pos_bias, 2, 0, 1, 3);
    pos_bias = ggml_cont      (ctx0, pos_bias);

    cb(pos_bias, "pos_bias", -1);

    return pos_bias;
}
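// core attention: either the fused flash-attention path or the explicit
// softmax(Q*K^T * scale + mask)*V path; the batch may be split into streams
// along dim 3, v_mla optionally "decompresses" the MLA latent back to full
// heads, and attention sinks are attached to the softmax when present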
ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * sinks,
         ggml_tensor * v_mla,
               float   kq_scale,
                 int   il) const {
    const bool v_trans = v->nb[1] > v->nb[2];

    // split the batch into streams if needed
    const auto n_stream = k->ne[3];

    q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);

    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
    v = ggml_permute(ctx0, v, 0, 2, 1, 3);

    ggml_tensor * cur;

    if (cparams.flash_attn && kq_b == nullptr) {
        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

        if (v_trans) {
            v = ggml_transpose(ctx0, v);
        }

        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
        if (k->type == GGML_TYPE_F32) {
            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
        }

        if (v->type == GGML_TYPE_F32) {
            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
        }

        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
        cb(cur, LLAMA_TENSOR_NAME_FATTN, il);

        ggml_flash_attn_ext_add_sinks(cur, sinks);
        ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);

        if (v_mla) {
#if 0
            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
            // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
            cur = ggml_mul_mat(ctx0, v_mla, cur);
#else
            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
            // The permutations are noops and only change how the tensor data is interpreted.
            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
            cur = ggml_mul_mat(ctx0, v_mla, cur);
            cb(cur, "fattn_mla", il);
            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
            cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
#endif
        }

        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
    } else {
        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
        cb(kq, "kq", il);

        // note: this op tends to require high floating point range
        //       while for some models F16 is enough, for others it is not, so we default to F32 here
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

        if (arch == LLM_ARCH_GROK) {
            // need to do the following:
            // multiply by attn_output_multiplier
            // and then:
            // kq = 30 * tanh(kq / 30)
            // before the softmax below
            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
            cb(kq, "kq_tanh", il);

            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
            cb(kq, "kq_scaled", il);
        }

        if (hparams.attn_soft_cap) {
            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
            cb(kq, "kq_scaled_1", il);

            kq = ggml_tanh (ctx0, kq);
            cb(kq, "kq_tanh", il);

            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
            cb(kq, "kq_scaled_2", il);
        }

        if (kq_b) {
            kq = ggml_add(ctx0, kq, kq_b);
            cb(kq, "kq_plus_kq_b", il);
        }

        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
        ggml_soft_max_add_sinks(kq, sinks);
        cb(kq, "kq_soft_max", il);

        if (!v_trans) {
            // note: avoid this branch
            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
            cb(v, "v_cont", il);
        }

        ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
        cb(kqv, "kqv", il);

        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
        if (v_mla) {
            kqv = ggml_mul_mat(ctx0, v_mla, kqv);
            cb(kqv, "kqv_mla", il);
        }

        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

        // recombine streams
        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);

        if (!cparams.offload_kqv) {
            // all nodes between the KV store and the attention output are run on the CPU
            ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
        }
    }

    ggml_build_forward_expand(gf, cur);

    return cur;
}
  1163. llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
  1164. auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
  1165. // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
  1166. inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  1167. ggml_set_input(inp->self_kq_mask);
  1168. inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
  1169. if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
  1170. inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  1171. ggml_set_input(inp->self_kq_mask_swa);
  1172. inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
  1173. } else {
  1174. inp->self_kq_mask_swa = nullptr;
  1175. inp->self_kq_mask_swa_cnv = nullptr;
  1176. }
  1177. return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
  1178. }
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_no_cache * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
        float kq_scale,
        int il) const {
    GGML_UNUSED(n_tokens);

    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
    ggml_build_forward_expand(gf, q_cur);
    ggml_build_forward_expand(gf, k_cur);
    ggml_build_forward_expand(gf, v_cur);

    const bool is_swa = hparams.is_swa(il);

    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

    // [TAG_NO_CACHE_PAD]
    // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
    //       but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
    //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));

    ggml_tensor * q = q_cur;
    ggml_tensor * k = k_cur;
    ggml_tensor * v = v_cur;

    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
    cb(cur, "kqv_out", il);

    if (wo) {
        cur = build_lora_mm(wo, cur);
    }

    if (wo_b) {
        //cb(cur, "kqv_wo", il);
        cur = ggml_add(ctx0, cur, wo_b);
    }

    return cur;
}
static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
        ggml_context * ctx0,
        const llama_ubatch & ubatch,
        const llama_hparams & hparams,
        const llama_cparams & cparams,
        const llama_kv_cache_context * mctx_cur) {
    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);

    {
        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");

        const auto n_kv     = mctx_cur->get_n_kv();
        const auto n_tokens = ubatch.n_tokens;
        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
        inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
        ggml_set_input(inp->self_kq_mask);

        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
    }

    return inp;
}

llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);

    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);

    return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
}
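// note on the mask shape: self_kq_mask is [n_kv, pad(n_tokens/n_stream), 1, n_stream].
// with a unified KV cache there is a single stream covering the whole ubatch;
// otherwise each unique sequence gets its own stream. e.g. (hypothetical numbers)
// n_tokens = 8 spread over n_seqs_unq = 2 sequences, non-unified -> 2 streams of
// 4 tokens each, masked independently against the n_kv cache cells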
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_kv * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
        float kq_scale,
        int il) const {
    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
    ggml_build_forward_expand(gf, q_cur);
    ggml_build_forward_expand(gf, k_cur);
    ggml_build_forward_expand(gf, v_cur);

    const auto * mctx_cur = inp->mctx;

    // store to KV cache
    {
        const auto & k_idxs = inp->get_k_idxs();
        const auto & v_idxs = inp->get_v_idxs();

        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
    }

    const auto & kq_mask = inp->get_kq_mask();

    ggml_tensor * q = q_cur;
    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
    ggml_tensor * v = mctx_cur->get_v(ctx0, il);

    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
    cb(cur, "kqv_out", il);

    if (wo) {
        cur = build_lora_mm(wo, cur);
        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
        }
    }

    if (wo_b) {
        cur = ggml_add(ctx0, cur, wo_b);
    }

    return cur;
}
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_kv_iswa * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
        float kq_scale,
        int il) const {
    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
    ggml_build_forward_expand(gf, q_cur);

    if (k_cur) {
        ggml_build_forward_expand(gf, k_cur);
    }

    if (v_cur) {
        ggml_build_forward_expand(gf, v_cur);
    }

    const auto * mctx_iswa = inp->mctx;

    const bool is_swa = hparams.is_swa(il);

    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();

    // optionally store to KV cache
    if (k_cur) {
        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();

        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
    }

    if (v_cur) {
        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();

        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
    }

    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

    ggml_tensor * q = q_cur;
    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
    ggml_tensor * v = mctx_cur->get_v(ctx0, il);

    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
    cb(cur, "kqv_out", il);

    if (wo) {
        cur = build_lora_mm(wo, cur);
    }

    if (wo_b) {
        //cb(cur, "kqv_wo", il);
        cur = ggml_add(ctx0, cur, wo_b);
    }

    return cur;
}
llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
    auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);

    const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
    ggml_set_input(inp->cross_kq_mask);

    inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;

    return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_cross * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
        float kq_scale,
        int il) const {
    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
    ggml_build_forward_expand(gf, q_cur);
    ggml_build_forward_expand(gf, k_cur);
    ggml_build_forward_expand(gf, v_cur);

    const auto & kq_mask = inp->get_kq_mask_cross();

    ggml_tensor * q = q_cur;
    ggml_tensor * k = k_cur;
    ggml_tensor * v = v_cur;

    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
    cb(cur, "kqv_out", il);

    if (wo) {
        cur = build_lora_mm(wo, cur);
    }

    if (wo_b) {
        //cb(cur, "kqv_wo", il);
        cur = ggml_add(ctx0, cur, wo_b);
    }

    return cur;
}
// TODO: maybe separate the inner implementation into a separate function,
//       like with the non-sliding-window equivalent,
//       once sliding-window hybrid caches are a thing
llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);

    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);

    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

    {
        const auto n_kv = mctx_cur->get_base()->get_n_kv();

        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
        ggml_set_input(inp->self_kq_mask);

        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
    }

    {
        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");

        const auto n_kv = mctx_cur->get_swa()->get_n_kv();

        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
        ggml_set_input(inp->self_kq_mask_swa);

        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
    }

    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}
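// note: iSWA (interleaved sliding-window attention) models keep two KV caches:
// a full-context "base" cache for the non-SWA layers and a smaller sliding-window
// cache for the SWA layers, hence the two mask/index sets prepared above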
ggml_tensor * llm_graph_context::build_rs(
        ggml_tensor * s,
        ggml_tensor * state_copy_main,
        ggml_tensor * state_copy_extra,
        int32_t state_size,
        int32_t n_seqs,
        uint32_t n_rs,
        uint32_t rs_head,
        uint32_t rs_size,
        int32_t rs_zero,
        const llm_graph_get_rows_fn & get_state_rows) const {
    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);

    // clear a single state which will then be copied to the other cleared states
    // note that this is a no-op when the view is zero-sized
    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

    // copy states
    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
    // {state_size, rs_size} -> {state_size, n_seqs}
    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
    ggml_build_forward_expand(gf, output_states);

    // copy extra states which won't be changed further (between n_seqs and n_rs)
    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
    ggml_build_forward_expand(gf,
        ggml_cpy(ctx0,
            states_extra,
            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));

    return output_states;
}
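// note on the rs_zero trick above: when rs_zero < 0 there is no state to clear,
// and both the view size (state_size*(rs_zero >= 0)) and the byte offset evaluate
// to 0, so the in-place scale degenerates to a no-op without needing a branch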
static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
        ggml_context * ctx0,
        const llama_ubatch & ubatch,
        const llama_memory_recurrent_context * mctx_cur) {
    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);

    const int64_t n_rs   = mctx_cur->get_n_rs();
    const int64_t n_seqs = ubatch.n_seqs;

    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
    ggml_set_input(inp->s_copy);

    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);

    return inp;
}

llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);

    return (llm_graph_input_rs *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_rs(
        llm_graph_input_rs * inp,
        ggml_tensor * s,
        int32_t state_size,
        int32_t n_seqs,
        const llm_graph_get_rows_fn & get_state_rows) const {
    const auto * kv_state = inp->mctx;

    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
            kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
            get_state_rows);
}
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
        llm_graph_input_rs * inp,
        const llama_ubatch & ubatch,
        int il) const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    const auto token_shift_count = hparams.token_shift_count;

    const int64_t n_seqs = ubatch.n_seqs;

    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);

    ggml_tensor * token_shift = build_rs(
            inp, token_shift_all,
            hparams.n_embd_r(), n_seqs);

    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

    return token_shift;
}

ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
        ggml_tensor * token_shift,
        const llama_ubatch & ubatch,
        int il) const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    const auto token_shift_count = hparams.token_shift_count;
    const auto n_embd = hparams.n_embd;

    const int64_t n_seqs = ubatch.n_seqs;

    const auto kv_head = mctx_cur->get_head();

    return ggml_cpy(
        ctx0,
        ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
    );
}
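// note: load/store are symmetric - the RWKV token-shift states live flattened in
// the recurrent cache (n_embd_r() elements per sequence, presumably
// n_embd * token_shift_count), are unflattened to [n_embd, token_shift_count, n_seqs]
// for the layer to use, and written back at the same kv_head offset afterwards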
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);

    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());

    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);

    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}

void llm_graph_context::build_dense_out(
        ggml_tensor * dense_2,
        ggml_tensor * dense_3) const {
    if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
        return;
    }

    ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
    GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");

    cur = ggml_mul_mat(ctx0, dense_2, cur);
    cur = ggml_mul_mat(ctx0, dense_3, cur);
    cb(cur, "result_embd_pooled", -1);
    res->t_embd_pooled = cur;

    ggml_build_forward_expand(gf, cur);
}
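// note: dense_2/dense_3 apply two stacked matrix multiplications to the pooled
// embedding, with no activation in between - a simple linear projection head
// on top of the pooled output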
void llm_graph_context::build_pooling(
        ggml_tensor * cls,
        ggml_tensor * cls_b,
        ggml_tensor * cls_out,
        ggml_tensor * cls_out_b) const {
    if (!cparams.embeddings) {
        return;
    }

    ggml_tensor * inp = res->t_embd;

    //// find result_norm tensor for input
    //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
    //    inp = ggml_graph_node(gf, i);
    //    if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
    //        break;
    //    }
    //
    //    inp = nullptr;
    //}

    GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");

    ggml_tensor * cur;

    switch (pooling_type) {
        case LLAMA_POOLING_TYPE_NONE:
            {
                cur = inp;
            } break;
        case LLAMA_POOLING_TYPE_MEAN:
            {
                ggml_tensor * inp_mean = build_inp_mean();
                cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
            } break;
        case LLAMA_POOLING_TYPE_CLS:
        case LLAMA_POOLING_TYPE_LAST:
            {
                ggml_tensor * inp_cls = build_inp_cls();
                cur = ggml_get_rows(ctx0, inp, inp_cls);
            } break;
        case LLAMA_POOLING_TYPE_RANK:
            {
                ggml_tensor * inp_cls = build_inp_cls();
                cur = ggml_get_rows(ctx0, inp, inp_cls);

                // classification head
                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
                if (cls) {
                    cur = ggml_mul_mat(ctx0, cls, cur);
                    if (cls_b) {
                        cur = ggml_add(ctx0, cur, cls_b);
                    }
                    cur = ggml_tanh(ctx0, cur);
                }

                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
                // single-layer classification head (direct projection)
                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
                if (cls_out) {
                    cur = ggml_mul_mat(ctx0, cls_out, cur);
                    if (cls_out_b) {
                        cur = ggml_add(ctx0, cur, cls_out_b);
                    }
                }

                // softmax for the Qwen3 reranker
                if (arch == LLM_ARCH_QWEN3) {
                    cur = ggml_soft_max(ctx0, cur);
                }
            } break;
        default:
            {
                GGML_ABORT("unknown pooling type");
            }
    }

    cb(cur, "result_embd_pooled", -1);
    res->t_embd_pooled = cur;

    ggml_build_forward_expand(gf, cur);
}
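// note on MEAN pooling above: transposing inp to token-major layout and
// multiplying by inp_mean reduces over the token dimension; inp_mean is
// presumably filled with per-sequence averaging weights (1/n for each of a
// sequence's n tokens), so each output column is that sequence's mean embedding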
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
    // TODO: move to hparams if a T5 variant appears that uses a different value
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1;
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;

    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = std::abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);

    return relative_bucket;
}
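// worked example (T5-style, hypothetical numbers): n_buckets = 32, bidirectional.
// halving gives 16 buckets per direction, so max_exact = 8.
//   x - y =  3:  3 <  8, exact bucket      -> 16 + 3  = 19 (positive direction)
//   x - y = 20: 20 >= 8, log-spaced bucket -> floor(8 + ln(20/8)/ln(128/8) * (16 - 8)) = 10
//                                          -> 16 + 10 = 26
// small offsets get one bucket each; larger offsets share logarithmically wider
// buckets up to max_distance = 128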