llama-graph.cpp 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628
  1. #include "llama-graph.h"
  2. #include "llama-impl.h"
  3. #include "llama-batch.h"
  4. #include "llama-cparams.h"
  5. #include "llama-kv-cache-unified.h"
  6. #include "llama-kv-cache-unified-iswa.h"
  7. #include "llama-memory-hybrid.h"
  8. #include "llama-memory-recurrent.h"
  9. #include <cassert>
  10. #include <cmath>
  11. #include <cstring>
  12. void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
  13. if (ubatch->token) {
  14. const int64_t n_tokens = ubatch->n_tokens;
  15. ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
  16. }
  17. if (ubatch->embd) {
  18. const int64_t n_embd = embd->ne[0];
  19. const int64_t n_tokens = ubatch->n_tokens;
  20. ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
  21. }
  22. }
  23. void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
  24. if (ubatch->pos && pos) {
  25. const int64_t n_tokens = ubatch->n_tokens;
  26. if (ubatch->token && n_pos_per_embd == 4) {
  27. // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
  28. // the 3 first dims are the same, and 4th dim is all 0
  29. std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
  30. // copy the first dimension
  31. for (int i = 0; i < n_tokens; ++i) {
  32. pos_data[ i] = ubatch->pos[i];
  33. pos_data[ n_tokens + i] = ubatch->pos[i];
  34. pos_data[2 * n_tokens + i] = ubatch->pos[i];
  35. pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
  36. }
  37. ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
  38. } else {
  39. ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
  40. }
  41. }
  42. }
  43. void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
  44. if (ubatch->pos && attn_scale) {
  45. const int64_t n_tokens = ubatch->n_tokens;
  46. std::vector<float> attn_scale_data(n_tokens, 0.0f);
  47. for (int i = 0; i < n_tokens; ++i) {
  48. const float pos = ubatch->pos[i];
  49. attn_scale_data[i] = std::log(
  50. std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
  51. ) * f_attn_temp_scale + 1.0;
  52. }
  53. ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
  54. }
  55. }
  56. void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
  57. if (pos_bucket) {
  58. const int64_t n_tokens = ubatch->n_tokens;
  59. GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
  60. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  61. int32_t * data = (int32_t *) pos_bucket->data;
  62. for (int h = 0; h < 1; ++h) {
  63. for (int j = 0; j < n_tokens; ++j) {
  64. for (int i = 0; i < n_tokens; ++i) {
  65. data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
  66. }
  67. }
  68. }
  69. }
  70. }
  71. void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
  72. if (pos_bucket) {
  73. mctx->set_input_pos_bucket(pos_bucket, ubatch);
  74. }
  75. }
  76. void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
  77. GGML_ASSERT(out_ids);
  78. const int64_t n_tokens = ubatch->n_tokens;
  79. GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
  80. int32_t * data = (int32_t *) out_ids->data;
  81. if (n_outputs == n_tokens) {
  82. for (int i = 0; i < n_tokens; ++i) {
  83. data[i] = i;
  84. }
  85. return;
  86. }
  87. GGML_ASSERT(ubatch->output);
  88. int n_outputs = 0;
  89. for (int i = 0; i < n_tokens; ++i) {
  90. if (ubatch->output[i]) {
  91. data[n_outputs++] = i;
  92. }
  93. }
  94. }
  95. void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
  96. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
  97. const int64_t n_tokens = ubatch->n_tokens;
  98. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  99. const int64_t n_seqs_unq = ubatch->n_seqs_unq;
  100. GGML_ASSERT(mean);
  101. GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
  102. float * data = (float *) mean->data;
  103. memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
  104. std::vector<uint64_t> sums(n_seqs_unq, 0);
  105. for (int i = 0; i < n_tokens; i += n_seq_tokens) {
  106. for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
  107. const llama_seq_id seq_id = ubatch->seq_id[i][s];
  108. const int32_t seq_idx = ubatch->seq_idx[seq_id];
  109. sums[seq_idx] += ubatch->n_seq_tokens;
  110. }
  111. }
  112. std::vector<float> div(n_seqs_unq, 0.0f);
  113. for (int s = 0; s < n_seqs_unq; ++s) {
  114. const uint64_t sum = sums[s];
  115. if (sum > 0) {
  116. div[s] = 1.0f/float(sum);
  117. }
  118. }
  119. for (int i = 0; i < n_tokens; i += n_seq_tokens) {
  120. for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
  121. const llama_seq_id seq_id = ubatch->seq_id[i][s];
  122. const int32_t seq_idx = ubatch->seq_idx[seq_id];
  123. for (int j = 0; j < n_seq_tokens; ++j) {
  124. data[seq_idx*n_tokens + i + j] = div[seq_idx];
  125. }
  126. }
  127. }
  128. }
  129. }
  130. void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
  131. const int64_t n_tokens = ubatch->n_tokens;
  132. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  133. const int64_t n_seqs_unq = ubatch->n_seqs_unq;
  134. if (cparams.embeddings && (
  135. cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
  136. cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
  137. )) {
  138. GGML_ASSERT(cls);
  139. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  140. uint32_t * data = (uint32_t *) cls->data;
  141. memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
  142. for (int i = 0; i < n_tokens; i += n_seq_tokens) {
  143. for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
  144. const llama_seq_id seq_id = ubatch->seq_id[i][s];
  145. const int32_t seq_idx = ubatch->seq_idx[seq_id];
  146. data[seq_idx] = i;
  147. }
  148. }
  149. }
  150. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
  151. GGML_ASSERT(cls);
  152. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  153. uint32_t * data = (uint32_t *) cls->data;
  154. memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
  155. std::vector<int> last_pos(n_seqs_unq, -1);
  156. std::vector<int> last_row(n_seqs_unq, -1);
  157. for (int i = 0; i < n_tokens; ++i) {
  158. const llama_pos pos = ubatch->pos[i];
  159. for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
  160. const llama_seq_id seq_id = ubatch->seq_id[i][s];
  161. const int32_t seq_idx = ubatch->seq_idx[seq_id];
  162. if (pos >= last_pos[seq_idx]) {
  163. last_pos[seq_idx] = pos;
  164. last_row[seq_idx] = i;
  165. }
  166. }
  167. }
  168. for (int s = 0; s < n_seqs_unq; ++s) {
  169. if (last_row[s] >= 0) {
  170. data[s] = last_row[s];
  171. }
  172. }
  173. }
  174. }
  175. void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
  176. GGML_UNUSED(ubatch);
  177. const int64_t n_rs = mctx->get_n_rs();
  178. if (s_copy) {
  179. GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
  180. int32_t * data = (int32_t *) s_copy->data;
  181. // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
  182. for (uint32_t i = 0; i < n_rs; ++i) {
  183. data[i] = mctx->s_copy(i);
  184. }
  185. }
  186. }
  187. void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  188. GGML_UNUSED(ubatch);
  189. if (cross_embd && !cross->v_embd.empty()) {
  190. assert(cross_embd->type == GGML_TYPE_F32);
  191. ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
  192. }
  193. }
  194. void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  195. const int64_t n_kv = ubatch->n_tokens;
  196. const int64_t n_tokens = ubatch->n_tokens;
  197. GGML_ASSERT(kq_mask);
  198. GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
  199. float * data = (float *) kq_mask->data;
  200. for (int h = 0; h < 1; ++h) {
  201. for (int i1 = 0; i1 < n_tokens; ++i1) {
  202. const llama_seq_id s1 = ubatch->seq_id[i1][0];
  203. for (int i0 = 0; i0 < n_tokens; ++i0) {
  204. float f = -INFINITY;
  205. for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
  206. const llama_seq_id s0 = ubatch->seq_id[i0][0];
  207. // TODO: reimplement this like in llama_kv_cache_unified
  208. if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
  209. if (hparams.use_alibi) {
  210. f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
  211. } else {
  212. f = 0.0f;
  213. }
  214. break;
  215. }
  216. }
  217. data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
  218. }
  219. }
  220. }
  221. }
  222. void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
  223. mctx->set_input_k_idxs(self_k_idxs, ubatch);
  224. mctx->set_input_v_idxs(self_v_idxs, ubatch);
  225. mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
  226. }
  227. void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
  228. mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
  229. mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
  230. mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
  231. mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
  232. mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
  233. mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
  234. }
  235. void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  236. GGML_ASSERT(cross_kq_mask);
  237. const int64_t n_enc = cross_kq_mask->ne[0];
  238. const int64_t n_tokens = ubatch->n_tokens;
  239. GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
  240. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  241. float * data = (float *) cross_kq_mask->data;
  242. for (int h = 0; h < 1; ++h) {
  243. for (int i = 0; i < n_tokens; ++i) {
  244. for (int j = 0; j < n_enc; ++j) {
  245. float f = -INFINITY;
  246. for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
  247. const llama_seq_id seq_id = ubatch->seq_id[i][s];
  248. if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
  249. f = 0.0f;
  250. }
  251. }
  252. data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
  253. }
  254. }
  255. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  256. for (int j = 0; j < n_enc; ++j) {
  257. data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
  258. }
  259. }
  260. }
  261. }
  262. void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
  263. inp_attn->set_input(ubatch);
  264. inp_rs->set_input(ubatch);
  265. }
  266. //
  267. // llm_graph_context
  268. //
  269. llm_graph_context::llm_graph_context(const llm_graph_params & params) :
  270. arch (params.arch),
  271. hparams (params.hparams),
  272. cparams (params.cparams),
  273. ubatch (params.ubatch),
  274. n_embd (hparams.n_embd),
  275. n_layer (hparams.n_layer),
  276. n_rot (hparams.n_rot),
  277. n_ctx (cparams.n_ctx),
  278. n_head (hparams.n_head()),
  279. n_head_kv (hparams.n_head_kv()),
  280. n_embd_head_k (hparams.n_embd_head_k),
  281. n_embd_k_gqa (hparams.n_embd_k_gqa()),
  282. n_embd_head_v (hparams.n_embd_head_v),
  283. n_embd_v_gqa (hparams.n_embd_v_gqa()),
  284. n_expert (hparams.n_expert),
  285. n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
  286. freq_base (cparams.rope_freq_base),
  287. freq_scale (cparams.rope_freq_scale),
  288. ext_factor (cparams.yarn_ext_factor),
  289. attn_factor (cparams.yarn_attn_factor),
  290. beta_fast (cparams.yarn_beta_fast),
  291. beta_slow (cparams.yarn_beta_slow),
  292. norm_eps (hparams.f_norm_eps),
  293. norm_rms_eps (hparams.f_norm_rms_eps),
  294. n_tokens (ubatch.n_tokens),
  295. n_outputs (params.n_outputs),
  296. n_ctx_orig (cparams.n_ctx_orig_yarn),
  297. pooling_type (cparams.pooling_type),
  298. rope_type (hparams.rope_type),
  299. ctx0 (params.ctx),
  300. sched (params.sched),
  301. backend_cpu (params.backend_cpu),
  302. cvec (params.cvec),
  303. loras (params.loras),
  304. mctx (params.mctx),
  305. cross (params.cross),
  306. cb_func (params.cb),
  307. res (std::make_unique<llm_graph_result>()) {
  308. }
  309. void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
  310. if (cb_func) {
  311. cb_func(ubatch, cur, name, il);
  312. }
  313. }
  314. ggml_tensor * llm_graph_context::build_cvec(
  315. ggml_tensor * cur,
  316. int il) const {
  317. return cvec->apply_to(ctx0, cur, il);
  318. }
  319. ggml_tensor * llm_graph_context::build_lora_mm(
  320. ggml_tensor * w,
  321. ggml_tensor * cur) const {
  322. ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
  323. for (const auto & lora : *loras) {
  324. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  325. if (lw == nullptr) {
  326. continue;
  327. }
  328. const float adapter_scale = lora.second;
  329. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  330. ggml_tensor * ab_cur = ggml_mul_mat(
  331. ctx0, lw->b,
  332. ggml_mul_mat(ctx0, lw->a, cur)
  333. );
  334. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  335. res = ggml_add(ctx0, res, ab_cur);
  336. }
  337. return res;
  338. }
  339. ggml_tensor * llm_graph_context::build_lora_mm_id(
  340. ggml_tensor * w, // ggml_tensor * as
  341. ggml_tensor * cur, // ggml_tensor * b
  342. ggml_tensor * ids) const {
  343. ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
  344. for (const auto & lora : *loras) {
  345. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  346. if (lw == nullptr) {
  347. continue;
  348. }
  349. const float alpha = lora.first->alpha;
  350. const float rank = (float) lw->b->ne[0];
  351. const float scale = alpha ? lora.second * alpha / rank : lora.second;
  352. ggml_tensor * ab_cur = ggml_mul_mat_id(
  353. ctx0, lw->b,
  354. ggml_mul_mat_id(ctx0, lw->a, cur, ids),
  355. ids
  356. );
  357. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  358. res = ggml_add(ctx0, res, ab_cur);
  359. }
  360. return res;
  361. }
  362. ggml_tensor * llm_graph_context::build_norm(
  363. ggml_tensor * cur,
  364. ggml_tensor * mw,
  365. ggml_tensor * mb,
  366. llm_norm_type type,
  367. int il) const {
  368. switch (type) {
  369. case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
  370. case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
  371. case LLM_NORM_GROUP:
  372. {
  373. cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
  374. cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
  375. cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
  376. } break;
  377. }
  378. if (mw || mb) {
  379. cb(cur, "norm", il);
  380. }
  381. if (mw) {
  382. cur = ggml_mul(ctx0, cur, mw);
  383. if (mb) {
  384. cb(cur, "norm_w", il);
  385. }
  386. }
  387. if (mb) {
  388. cur = ggml_add(ctx0, cur, mb);
  389. }
  390. return cur;
  391. }
  392. ggml_tensor * llm_graph_context::build_ffn(
  393. ggml_tensor * cur,
  394. ggml_tensor * up,
  395. ggml_tensor * up_b,
  396. ggml_tensor * up_s,
  397. ggml_tensor * gate,
  398. ggml_tensor * gate_b,
  399. ggml_tensor * gate_s,
  400. ggml_tensor * down,
  401. ggml_tensor * down_b,
  402. ggml_tensor * down_s,
  403. ggml_tensor * act_scales,
  404. llm_ffn_op_type type_op,
  405. llm_ffn_gate_type type_gate,
  406. int il) const {
  407. ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
  408. cb(tmp, "ffn_up", il);
  409. if (up_b) {
  410. tmp = ggml_add(ctx0, tmp, up_b);
  411. cb(tmp, "ffn_up_b", il);
  412. }
  413. if (up_s) {
  414. tmp = ggml_mul(ctx0, tmp, up_s);
  415. cb(tmp, "ffn_up_s", il);
  416. }
  417. if (gate) {
  418. switch (type_gate) {
  419. case LLM_FFN_SEQ:
  420. {
  421. cur = build_lora_mm(gate, tmp);
  422. cb(cur, "ffn_gate", il);
  423. } break;
  424. case LLM_FFN_PAR:
  425. {
  426. cur = build_lora_mm(gate, cur);
  427. cb(cur, "ffn_gate", il);
  428. } break;
  429. }
  430. if (gate_b) {
  431. cur = ggml_add(ctx0, cur, gate_b);
  432. cb(cur, "ffn_gate_b", il);
  433. }
  434. if (gate_s) {
  435. cur = ggml_mul(ctx0, cur, gate_s);
  436. cb(cur, "ffn_gate_s", il);
  437. }
  438. } else {
  439. cur = tmp;
  440. }
  441. switch (type_op) {
  442. case LLM_FFN_SILU:
  443. if (gate && type_gate == LLM_FFN_PAR) {
  444. cur = ggml_swiglu_split(ctx0, cur, tmp);
  445. cb(cur, "ffn_swiglu", il);
  446. type_gate = LLM_FFN_SEQ;
  447. } else {
  448. cur = ggml_silu(ctx0, cur);
  449. cb(cur, "ffn_silu", il);
  450. } break;
  451. case LLM_FFN_GELU:
  452. if (gate && type_gate == LLM_FFN_PAR) {
  453. cur = ggml_geglu_split(ctx0, cur, tmp);
  454. cb(cur, "ffn_geglu", il);
  455. type_gate = LLM_FFN_SEQ;
  456. } else {
  457. cur = ggml_gelu(ctx0, cur);
  458. cb(cur, "ffn_gelu", il);
  459. if (act_scales != NULL) {
  460. cur = ggml_div(ctx0, cur, act_scales);
  461. cb(cur, "ffn_act", il);
  462. }
  463. } break;
  464. case LLM_FFN_RELU:
  465. if (gate && type_gate == LLM_FFN_PAR) {
  466. cur = ggml_reglu_split(ctx0, cur, tmp);
  467. cb(cur, "ffn_reglu", il);
  468. type_gate = LLM_FFN_SEQ;
  469. } else {
  470. cur = ggml_relu(ctx0, cur);
  471. cb(cur, "ffn_relu", il);
  472. } break;
  473. case LLM_FFN_RELU_SQR:
  474. {
  475. cur = ggml_relu(ctx0, cur);
  476. cb(cur, "ffn_relu", il);
  477. cur = ggml_sqr(ctx0, cur);
  478. cb(cur, "ffn_sqr(relu)", il);
  479. } break;
  480. case LLM_FFN_SWIGLU:
  481. {
  482. cur = ggml_swiglu(ctx0, cur);
  483. cb(cur, "ffn_swiglu", il);
  484. } break;
  485. case LLM_FFN_GEGLU:
  486. {
  487. cur = ggml_geglu(ctx0, cur);
  488. cb(cur, "ffn_geglu", il);
  489. } break;
  490. case LLM_FFN_REGLU:
  491. {
  492. cur = ggml_reglu(ctx0, cur);
  493. cb(cur, "ffn_reglu", il);
  494. } break;
  495. }
  496. if (gate && type_gate == LLM_FFN_PAR) {
  497. cur = ggml_mul(ctx0, cur, tmp);
  498. cb(cur, "ffn_gate_par", il);
  499. }
  500. if (down) {
  501. cur = build_lora_mm(down, cur);
  502. if (arch == LLM_ARCH_GLM4) {
  503. // GLM4 seems to have numerical issues with half-precision accumulators
  504. ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
  505. }
  506. }
  507. if (down_b) {
  508. cb(cur, "ffn_down", il);
  509. }
  510. if (down_b) {
  511. cur = ggml_add(ctx0, cur, down_b);
  512. }
  513. if (down_s) {
  514. cur = ggml_mul(ctx0, cur, down_s);
  515. cb(cur, "ffn_down_s", il);
  516. }
  517. return cur;
  518. }
  519. ggml_tensor * llm_graph_context::build_moe_ffn(
  520. ggml_tensor * cur,
  521. ggml_tensor * gate_inp,
  522. ggml_tensor * up_exps,
  523. ggml_tensor * gate_exps,
  524. ggml_tensor * down_exps,
  525. ggml_tensor * exp_probs_b,
  526. int64_t n_expert,
  527. int64_t n_expert_used,
  528. llm_ffn_op_type type_op,
  529. bool norm_w,
  530. bool scale_w,
  531. float w_scale,
  532. llama_expert_gating_func_type gating_op,
  533. int il) const {
  534. const int64_t n_embd = cur->ne[0];
  535. const int64_t n_tokens = cur->ne[1];
  536. const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
  537. ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
  538. cb(logits, "ffn_moe_logits", il);
  539. ggml_tensor * probs = nullptr;
  540. switch (gating_op) {
  541. case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
  542. {
  543. probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
  544. } break;
  545. case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
  546. {
  547. probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
  548. } break;
  549. default:
  550. GGML_ABORT("fatal error");
  551. }
  552. cb(probs, "ffn_moe_probs", il);
  553. // add experts selection bias - introduced in DeepSeek V3
  554. // leave probs unbiased as it's later used to get expert weights
  555. ggml_tensor * selection_probs = probs;
  556. if (exp_probs_b != nullptr) {
  557. selection_probs = ggml_add(ctx0, probs, exp_probs_b);
  558. cb(selection_probs, "ffn_moe_probs_biased", il);
  559. }
  560. // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
  561. // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
  562. if (arch == LLM_ARCH_LLAMA4) {
  563. selection_probs = logits;
  564. }
  565. // select experts
  566. ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
  567. cb(selected_experts->src[0], "ffn_moe_argsort", il);
  568. cb(selected_experts, "ffn_moe_topk", il);
  569. ggml_tensor * weights = ggml_get_rows(ctx0,
  570. ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
  571. cb(weights, "ffn_moe_weights", il);
  572. if (norm_w) {
  573. weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
  574. ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
  575. cb(weights_sum, "ffn_moe_weights_sum", il);
  576. weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
  577. cb(weights, "ffn_moe_weights_norm", il);
  578. weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
  579. }
  580. if (scale_w) {
  581. weights = ggml_scale(ctx0, weights, w_scale);
  582. cb(weights, "ffn_moe_weights_scaled", il);
  583. }
  584. cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
  585. if (weight_before_ffn) {
  586. // repeat cur to [n_embd, n_expert_used, n_tokens]
  587. ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
  588. cur = ggml_mul(ctx0, repeated, weights);
  589. cb(cur, "ffn_moe_weighted", il);
  590. }
  591. ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  592. cb(up, "ffn_moe_up", il);
  593. ggml_tensor * experts = nullptr;
  594. if (gate_exps) {
  595. cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  596. cb(cur, "ffn_moe_gate", il);
  597. } else {
  598. cur = up;
  599. }
  600. switch (type_op) {
  601. case LLM_FFN_SILU:
  602. if (gate_exps) {
  603. cur = ggml_swiglu_split(ctx0, cur, up);
  604. cb(cur, "ffn_moe_swiglu", il);
  605. } else {
  606. cur = ggml_silu(ctx0, cur);
  607. cb(cur, "ffn_moe_silu", il);
  608. } break;
  609. case LLM_FFN_GELU:
  610. if (gate_exps) {
  611. cur = ggml_geglu_split(ctx0, cur, up);
  612. cb(cur, "ffn_moe_geglu", il);
  613. } else {
  614. cur = ggml_gelu(ctx0, cur);
  615. cb(cur, "ffn_moe_gelu", il);
  616. } break;
  617. default:
  618. GGML_ABORT("fatal error");
  619. }
  620. experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
  621. cb(experts, "ffn_moe_down", il);
  622. if (!weight_before_ffn) {
  623. experts = ggml_mul(ctx0, experts, weights);
  624. cb(cur, "ffn_moe_weighted", il);
  625. }
  626. // aggregate experts
  627. ggml_tensor * moe_out = nullptr;
  628. for (int i = 0; i < n_expert_used; ++i) {
  629. ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
  630. experts->nb[2], i*experts->nb[1]);
  631. if (i == 0) {
  632. moe_out = cur_expert;
  633. } else {
  634. moe_out = ggml_add(ctx0, moe_out, cur_expert);
  635. }
  636. }
  637. if (n_expert_used == 1) {
  638. // avoid returning a non-contiguous tensor
  639. moe_out = ggml_cont(ctx0, moe_out);
  640. }
  641. cb(moe_out, "ffn_moe_out", il);
  642. return moe_out;
  643. }
  644. // input embeddings with optional lora
  645. ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
  646. const int64_t n_embd = hparams.n_embd;
  647. auto inp = std::make_unique<llm_graph_input_embd>();
  648. ggml_tensor * cur = nullptr;
  649. if (ubatch.token) {
  650. inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
  651. //cb(inp->tokens, "inp_tokens", -1);
  652. ggml_set_input(inp->tokens);
  653. res->t_tokens = inp->tokens;
  654. cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
  655. // apply lora for embedding tokens if needed
  656. for (const auto & lora : *loras) {
  657. llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
  658. if (lw == nullptr) {
  659. continue;
  660. }
  661. const float adapter_scale = lora.second;
  662. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  663. ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
  664. ctx0, lw->b, // non-transposed lora_b
  665. ggml_get_rows(ctx0, lw->a, inp->tokens)
  666. ), scale);
  667. cur = ggml_add(ctx0, cur, inpL_delta);
  668. }
  669. } else {
  670. inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
  671. ggml_set_input(inp->embd);
  672. cur = inp->embd;
  673. }
  674. // For Granite architecture
  675. if (hparams.f_embedding_scale != 0.0f) {
  676. cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
  677. }
  678. cb(cur, "inp_embd", -1);
  679. res->add_input(std::move(inp));
  680. return cur;
  681. }
  682. ggml_tensor * llm_graph_context::build_inp_pos() const {
  683. auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
  684. auto & cur = inp->pos;
  685. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
  686. ggml_set_input(cur);
  687. res->add_input(std::move(inp));
  688. return cur;
  689. }
  690. ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
  691. auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
  692. auto & cur = inp->attn_scale;
  693. // this need to be 1x1xN for broadcasting
  694. cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
  695. ggml_set_input(cur);
  696. res->add_input(std::move(inp));
  697. return cur;
  698. }
  699. ggml_tensor * llm_graph_context::build_inp_out_ids() const {
  700. // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
  701. // but this would make the graph topology depend on the number of output tokens, which can interere with
  702. // features that require constant topology such as pipline parallelism
  703. // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
  704. //if (n_outputs < n_tokens) {
  705. // return nullptr;
  706. //}
  707. auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
  708. auto & cur = inp->out_ids;
  709. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
  710. ggml_set_input(cur);
  711. res->add_input(std::move(inp));
  712. return cur;
  713. }
  714. ggml_tensor * llm_graph_context::build_inp_mean() const {
  715. auto inp = std::make_unique<llm_graph_input_mean>(cparams);
  716. auto & cur = inp->mean;
  717. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
  718. ggml_set_input(cur);
  719. res->add_input(std::move(inp));
  720. return cur;
  721. }
  722. ggml_tensor * llm_graph_context::build_inp_cls() const {
  723. auto inp = std::make_unique<llm_graph_input_cls>(cparams);
  724. auto & cur = inp->cls;
  725. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
  726. ggml_set_input(cur);
  727. res->add_input(std::move(inp));
  728. return cur;
  729. }
  730. ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
  731. auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
  732. auto & cur = inp->cross_embd;
  733. // if we have the output embeddings from the encoder, use them directly
  734. // TODO: needs more work to be correct, for now just use the tensor shape
  735. //if (cross->t_embd) {
  736. // cur = ggml_view_tensor(ctx0, cross->t_embd);
  737. // return cur;
  738. //}
  739. const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
  740. const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  741. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
  742. ggml_set_input(cur);
  743. res->add_input(std::move(inp));
  744. return cur;
  745. }
  746. ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
  747. auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
  748. auto & cur = inp->pos_bucket;
  749. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
  750. ggml_set_input(cur);
  751. res->add_input(std::move(inp));
  752. return cur;
  753. }
  754. ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
  755. const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
  756. auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
  757. const auto n_kv = mctx_cur->get_n_kv();
  758. auto & cur = inp->pos_bucket;
  759. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
  760. ggml_set_input(cur);
  761. res->add_input(std::move(inp));
  762. return cur;
  763. }
  764. ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
  765. ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
  766. cb(pos_bucket_1d, "pos_bucket_1d", -1);
  767. ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
  768. pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
  769. pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
  770. pos_bias = ggml_cont (ctx0, pos_bias);
  771. cb(pos_bias, "pos_bias", -1);
  772. return pos_bias;
  773. }
  774. ggml_tensor * llm_graph_context::build_attn_mha(
  775. ggml_cgraph * gf,
  776. ggml_tensor * q,
  777. ggml_tensor * k,
  778. ggml_tensor * v,
  779. ggml_tensor * kq_b,
  780. ggml_tensor * kq_mask,
  781. ggml_tensor * v_mla,
  782. float kq_scale) const {
  783. const bool v_trans = v->nb[1] > v->nb[2];
  784. q = ggml_permute(ctx0, q, 0, 2, 1, 3);
  785. k = ggml_permute(ctx0, k, 0, 2, 1, 3);
  786. v = ggml_permute(ctx0, v, 0, 2, 1, 3);
  787. const auto n_tokens = q->ne[1];
  788. const auto n_head = q->ne[2];
  789. const auto n_kv = k->ne[1];
  790. ggml_tensor * cur;
  791. // TODO: replace hardcoded padding with ggml-provided padding
  792. if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
  793. GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
  794. if (v_trans) {
  795. v = ggml_transpose(ctx0, v);
  796. }
  797. // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
  798. if (k->type == GGML_TYPE_F32) {
  799. k = ggml_cast(ctx0, k, GGML_TYPE_F16);
  800. }
  801. if (v->type == GGML_TYPE_F32) {
  802. v = ggml_cast(ctx0, v, GGML_TYPE_F16);
  803. }
  804. cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
  805. hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
  806. ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
  807. if (v_mla) {
  808. #if 0
  809. // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
  810. // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
  811. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
  812. cur = ggml_mul_mat(ctx0, v_mla, cur);
  813. #else
  814. // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
  815. // The permutations are noops and only change how the tensor data is interpreted.
  816. cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  817. cur = ggml_mul_mat(ctx0, v_mla, cur);
  818. cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  819. cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
  820. #endif
  821. }
  822. cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
  823. } else {
  824. ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
  825. // note: this op tends to require high floating point range
  826. // while for some models F16 is enough, for others it is not, so we default to F32 here
  827. ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
  828. if (arch == LLM_ARCH_GROK) {
  829. // need to do the following:
  830. // multiply by attn_output_multiplyer of 0.08838834764831845
  831. // and then :
  832. // kq = 30 * tanh(kq / 30)
  833. // before the softmax below
  834. kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
  835. kq = ggml_scale(ctx0, kq, 30);
  836. }
  837. if (hparams.attn_soft_cap) {
  838. kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
  839. kq = ggml_tanh (ctx0, kq);
  840. kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
  841. }
  842. if (kq_b) {
  843. kq = ggml_add(ctx0, kq, kq_b);
  844. }
  845. kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
  846. if (!v_trans) {
  847. // note: avoid this branch
  848. v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
  849. }
  850. ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
  851. // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
  852. if (v_mla) {
  853. kqv = ggml_mul_mat(ctx0, v_mla, kqv);
  854. }
  855. cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
  856. cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
  857. if (!cparams.offload_kqv) {
  858. // all nodes between the KV store and the attention output are run on the CPU
  859. ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
  860. }
  861. }
  862. ggml_build_forward_expand(gf, cur);
  863. return cur;
  864. }
  865. llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
  866. auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
  867. // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
  868. inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  869. ggml_set_input(inp->kq_mask);
  870. inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
  871. return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
  872. }
  873. ggml_tensor * llm_graph_context::build_attn(
  874. llm_graph_input_attn_no_cache * inp,
  875. ggml_cgraph * gf,
  876. ggml_tensor * wo,
  877. ggml_tensor * wo_b,
  878. ggml_tensor * q_cur,
  879. ggml_tensor * k_cur,
  880. ggml_tensor * v_cur,
  881. ggml_tensor * kq_b,
  882. ggml_tensor * v_mla,
  883. float kq_scale,
  884. int il) const {
  885. GGML_UNUSED(n_tokens);
  886. // these nodes are added to the graph together so that they are not reordered
  887. // by doing so, the number of splits in the graph is reduced
  888. ggml_build_forward_expand(gf, q_cur);
  889. ggml_build_forward_expand(gf, k_cur);
  890. ggml_build_forward_expand(gf, v_cur);
  891. const auto & kq_mask = inp->get_kq_mask();
  892. ggml_tensor * q = q_cur;
  893. ggml_tensor * k = k_cur;
  894. ggml_tensor * v = v_cur;
  895. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  896. cb(cur, "kqv_out", il);
  897. if (wo) {
  898. cur = build_lora_mm(wo, cur);
  899. }
  900. if (wo_b) {
  901. //cb(cur, "kqv_wo", il);
  902. }
  903. if (wo_b) {
  904. cur = ggml_add(ctx0, cur, wo_b);
  905. }
  906. return cur;
  907. }
  908. static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
  909. ggml_context * ctx0,
  910. const llama_ubatch & ubatch,
  911. const llama_hparams & hparams,
  912. const llama_cparams & cparams,
  913. const llama_kv_cache_unified_context * mctx_cur) {
  914. auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
  915. {
  916. GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
  917. const auto n_kv = mctx_cur->get_n_kv();
  918. const auto n_tokens = ubatch.n_tokens;
  919. inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
  920. inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
  921. inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  922. ggml_set_input(inp->self_kq_mask);
  923. inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
  924. }
  925. return inp;
  926. }
  927. llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
  928. const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
  929. auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
  930. return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
  931. }
  932. ggml_tensor * llm_graph_context::build_attn(
  933. llm_graph_input_attn_kv_unified * inp,
  934. ggml_cgraph * gf,
  935. ggml_tensor * wo,
  936. ggml_tensor * wo_b,
  937. ggml_tensor * q_cur,
  938. ggml_tensor * k_cur,
  939. ggml_tensor * v_cur,
  940. ggml_tensor * kq_b,
  941. ggml_tensor * v_mla,
  942. float kq_scale,
  943. int il) const {
  944. // these nodes are added to the graph together so that they are not reordered
  945. // by doing so, the number of splits in the graph is reduced
  946. ggml_build_forward_expand(gf, q_cur);
  947. ggml_build_forward_expand(gf, k_cur);
  948. ggml_build_forward_expand(gf, v_cur);
  949. const auto * mctx_cur = inp->mctx;
  950. // store to KV cache
  951. {
  952. const auto & k_idxs = inp->get_k_idxs();
  953. const auto & v_idxs = inp->get_v_idxs();
  954. ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
  955. ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
  956. }
  957. const auto & kq_mask = inp->get_kq_mask();
  958. ggml_tensor * q = q_cur;
  959. ggml_tensor * k = mctx_cur->get_k(ctx0, il);
  960. ggml_tensor * v = mctx_cur->get_v(ctx0, il);
  961. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  962. cb(cur, "kqv_out", il);
  963. if (wo) {
  964. cur = build_lora_mm(wo, cur);
  965. if (arch == LLM_ARCH_GLM4) {
  966. // GLM4 seems to have numerical issues with half-precision accumulators
  967. ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
  968. }
  969. }
  970. if (wo_b) {
  971. cur = ggml_add(ctx0, cur, wo_b);
  972. }
  973. return cur;
  974. }
  975. ggml_tensor * llm_graph_context::build_attn(
  976. llm_graph_input_attn_kv_unified_iswa * inp,
  977. ggml_cgraph * gf,
  978. ggml_tensor * wo,
  979. ggml_tensor * wo_b,
  980. ggml_tensor * q_cur,
  981. ggml_tensor * k_cur,
  982. ggml_tensor * v_cur,
  983. ggml_tensor * kq_b,
  984. ggml_tensor * v_mla,
  985. float kq_scale,
  986. int il) const {
  987. // these nodes are added to the graph together so that they are not reordered
  988. // by doing so, the number of splits in the graph is reduced
  989. ggml_build_forward_expand(gf, q_cur);
  990. if (k_cur) {
  991. ggml_build_forward_expand(gf, k_cur);
  992. }
  993. if (v_cur) {
  994. ggml_build_forward_expand(gf, v_cur);
  995. }
  996. const auto * mctx_iswa = inp->mctx;
  997. const bool is_swa = hparams.is_swa(il);
  998. const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
  999. // optionally store to KV cache
  1000. if (k_cur) {
  1001. const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
  1002. ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
  1003. }
  1004. if (v_cur) {
  1005. const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
  1006. ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
  1007. }
  1008. const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
  1009. ggml_tensor * q = q_cur;
  1010. ggml_tensor * k = mctx_cur->get_k(ctx0, il);
  1011. ggml_tensor * v = mctx_cur->get_v(ctx0, il);
  1012. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  1013. cb(cur, "kqv_out", il);
  1014. if (wo) {
  1015. cur = build_lora_mm(wo, cur);
  1016. }
  1017. if (wo_b) {
  1018. //cb(cur, "kqv_wo", il);
  1019. }
  1020. if (wo_b) {
  1021. cur = ggml_add(ctx0, cur, wo_b);
  1022. }
  1023. return cur;
  1024. }
  1025. llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
  1026. auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
  1027. const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  1028. inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  1029. ggml_set_input(inp->cross_kq_mask);
  1030. inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
  1031. return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
  1032. }
  1033. ggml_tensor * llm_graph_context::build_attn(
  1034. llm_graph_input_attn_cross * inp,
  1035. ggml_cgraph * gf,
  1036. ggml_tensor * wo,
  1037. ggml_tensor * wo_b,
  1038. ggml_tensor * q_cur,
  1039. ggml_tensor * k_cur,
  1040. ggml_tensor * v_cur,
  1041. ggml_tensor * kq_b,
  1042. ggml_tensor * v_mla,
  1043. float kq_scale,
  1044. int il) const {
  1045. // these nodes are added to the graph together so that they are not reordered
  1046. // by doing so, the number of splits in the graph is reduced
  1047. ggml_build_forward_expand(gf, q_cur);
  1048. ggml_build_forward_expand(gf, k_cur);
  1049. ggml_build_forward_expand(gf, v_cur);
  1050. const auto & kq_mask = inp->get_kq_mask_cross();
  1051. ggml_tensor * q = q_cur;
  1052. ggml_tensor * k = k_cur;
  1053. ggml_tensor * v = v_cur;
  1054. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  1055. cb(cur, "kqv_out", il);
  1056. if (wo) {
  1057. cur = build_lora_mm(wo, cur);
  1058. }
  1059. if (wo_b) {
  1060. //cb(cur, "kqv_wo", il);
  1061. }
  1062. if (wo_b) {
  1063. cur = ggml_add(ctx0, cur, wo_b);
  1064. }
  1065. return cur;
  1066. }
  1067. // TODO: maybe separate the inner implementation into a separate function
  1068. // like with the non-sliding window equivalent
  1069. // once sliding-window hybrid caches are a thing.
  1070. llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
  1071. const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
  1072. auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
  1073. {
  1074. const auto n_kv = mctx_cur->get_base()->get_n_kv();
  1075. inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
  1076. inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
  1077. inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  1078. ggml_set_input(inp->self_kq_mask);
  1079. inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
  1080. }
  1081. {
  1082. GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
  1083. const auto n_kv = mctx_cur->get_swa()->get_n_kv();
  1084. inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
  1085. inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
  1086. inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  1087. ggml_set_input(inp->self_kq_mask_swa);
  1088. inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
  1089. }
  1090. return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
  1091. }
// Low-level recurrent-state gather.
//
// s              - state buffer, viewed here as kv_size rows of state_size elements
// state_copy     - I32 row indices into those rows; the first n_seqs entries select the states
//                  returned to the caller, entries [n_seqs, n_kv) select "extra" states
// n_kv, kv_head  - the extra states are written back into s starting at row kv_head + n_seqs
// rs_zero        - row to clear first; when negative the clearing view is zero-sized (no-op)
// get_state_rows - gather function used for the returned states
//
// Returns the gathered {state_size, n_seqs} states; the clearing and write-back ops are
// expanded into gf as side effects.
ggml_tensor * llm_graph_context::build_rs(
        ggml_cgraph * gf,
        ggml_tensor * s,
        ggml_tensor * state_copy,
            int32_t   state_size,
            int32_t   n_seqs,
           uint32_t   n_kv,
           uint32_t   kv_head,
           uint32_t   kv_size,
            int32_t   rs_zero,
        const llm_graph_get_rows_fn & get_state_rows) const {
    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);

    // Clear a single state which will then be copied to the other cleared states.
    // Note that this is a no-op when the view is zero-sized.
    // (multiplying by (rs_zero >= 0) makes both the size and the offset 0 when rs_zero < 0)
    // NOTE(review): clearing is done by scaling by 0, which leaves NaN/Inf values as NaN
    //               (IEEE 754: NaN*0 = NaN, Inf*0 = NaN); confirm the state is always finite.
    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

    // copy states
    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
    // {state_size, kv_size} -> {state_size, n_seqs}
    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
    ggml_build_forward_expand(gf, output_states);

    // copy extra states which won't be changed further (between n_seqs and n_kv)
    // (n_kv - n_seqs is computed in unsigned arithmetic; presumably n_seqs <= n_kv is an invariant)
    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
    ggml_build_forward_expand(gf,
        ggml_cpy(ctx0,
            states_extra,
            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));

    return output_states;
}
  1121. static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
  1122. ggml_context * ctx0,
  1123. const llama_memory_recurrent_context * mctx_cur) {
  1124. auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
  1125. const auto n_rs = mctx_cur->get_n_rs();
  1126. inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
  1127. ggml_set_input(inp->s_copy);
  1128. return inp;
  1129. }
  1130. llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
  1131. const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
  1132. auto inp = build_rs_inp_impl(ctx0, mctx_cur);
  1133. return (llm_graph_input_rs *) res->add_input(std::move(inp));
  1134. }
  1135. ggml_tensor * llm_graph_context::build_rs(
  1136. llm_graph_input_rs * inp,
  1137. ggml_cgraph * gf,
  1138. ggml_tensor * s,
  1139. int32_t state_size,
  1140. int32_t n_seqs,
  1141. const llm_graph_get_rows_fn & get_state_rows) const {
  1142. const auto * kv_state = inp->mctx;
  1143. return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
  1144. }
  1145. ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
  1146. llm_graph_input_rs * inp,
  1147. ggml_cgraph * gf,
  1148. const llama_ubatch & ubatch,
  1149. int il) const {
  1150. const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
  1151. const auto token_shift_count = hparams.token_shift_count;
  1152. const int64_t n_seqs = ubatch.n_seqs;
  1153. ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
  1154. ggml_tensor * token_shift = build_rs(
  1155. inp, gf, token_shift_all,
  1156. hparams.n_embd_r(), n_seqs);
  1157. token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
  1158. return token_shift;
  1159. }
  1160. ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
  1161. ggml_tensor * token_shift,
  1162. const llama_ubatch & ubatch,
  1163. int il) const {
  1164. const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
  1165. const auto token_shift_count = hparams.token_shift_count;
  1166. const auto n_embd = hparams.n_embd;
  1167. const int64_t n_seqs = ubatch.n_seqs;
  1168. const auto kv_head = mctx_cur->get_head();
  1169. return ggml_cpy(
  1170. ctx0,
  1171. ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
  1172. ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
  1173. );
  1174. }
  1175. llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
  1176. const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
  1177. auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
  1178. auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
  1179. auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
  1180. return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
  1181. }
  1182. void llm_graph_context::build_pooling(
  1183. ggml_cgraph * gf,
  1184. ggml_tensor * cls,
  1185. ggml_tensor * cls_b,
  1186. ggml_tensor * cls_out,
  1187. ggml_tensor * cls_out_b) const {
  1188. if (!cparams.embeddings) {
  1189. return;
  1190. }
  1191. ggml_tensor * inp = res->t_embd;
  1192. //// find result_norm tensor for input
  1193. //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
  1194. // inp = ggml_graph_node(gf, i);
  1195. // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
  1196. // break;
  1197. // }
  1198. // inp = nullptr;
  1199. //}
  1200. GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
  1201. ggml_tensor * cur;
  1202. switch (pooling_type) {
  1203. case LLAMA_POOLING_TYPE_NONE:
  1204. {
  1205. cur = inp;
  1206. } break;
  1207. case LLAMA_POOLING_TYPE_MEAN:
  1208. {
  1209. ggml_tensor * inp_mean = build_inp_mean();
  1210. cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
  1211. } break;
  1212. case LLAMA_POOLING_TYPE_CLS:
  1213. case LLAMA_POOLING_TYPE_LAST:
  1214. {
  1215. ggml_tensor * inp_cls = build_inp_cls();
  1216. cur = ggml_get_rows(ctx0, inp, inp_cls);
  1217. } break;
  1218. case LLAMA_POOLING_TYPE_RANK:
  1219. {
  1220. ggml_tensor * inp_cls = build_inp_cls();
  1221. inp = ggml_get_rows(ctx0, inp, inp_cls);
  1222. if (cls) {
  1223. // classification head
  1224. // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
  1225. cur = ggml_mul_mat(ctx0, cls, inp);
  1226. if (cls_b) {
  1227. cur = ggml_add(ctx0, cur, cls_b);
  1228. }
  1229. cur = ggml_tanh(ctx0, cur);
  1230. // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
  1231. // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
  1232. if (cls_out) {
  1233. cur = ggml_mul_mat(ctx0, cls_out, cur);
  1234. if (cls_out_b) {
  1235. cur = ggml_add(ctx0, cur, cls_out_b);
  1236. }
  1237. }
  1238. } else if (cls_out) {
  1239. // Single layer classification head (direct projection)
  1240. // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
  1241. cur = ggml_mul_mat(ctx0, cls_out, inp);
  1242. if (cls_out_b) {
  1243. cur = ggml_add(ctx0, cur, cls_out_b);
  1244. }
  1245. } else {
  1246. GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
  1247. }
  1248. } break;
  1249. default:
  1250. {
  1251. GGML_ABORT("unknown pooling type");
  1252. }
  1253. }
  1254. cb(cur, "result_embd_pooled", -1);
  1255. res->t_embd_pooled = cur;
  1256. ggml_build_forward_expand(gf, cur);
  1257. }
  1258. int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
  1259. // TODO move to hparams if a T5 variant appears that uses a different value
  1260. const int64_t max_distance = 128;
  1261. if (bidirectional) {
  1262. n_buckets >>= 1;
  1263. }
  1264. const int64_t max_exact = n_buckets >> 1;
  1265. int32_t relative_position = x - y;
  1266. int32_t relative_bucket = 0;
  1267. if (bidirectional) {
  1268. relative_bucket += (relative_position > 0) * n_buckets;
  1269. relative_position = abs(relative_position);
  1270. } else {
  1271. relative_position = -std::min<int32_t>(relative_position, 0);
  1272. }
  1273. int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
  1274. relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
  1275. relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
  1276. return relative_bucket;
  1277. }