// llama-graph.cpp
#include "llama-graph.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-cparams.h"

#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache-recurrent.h"

#include <cassert>
#include <cmath>
#include <cstring>
#include <memory>
#include <vector>
  11. void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
  12. if (ubatch->token) {
  13. const int64_t n_tokens = ubatch->n_tokens;
  14. ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
  15. }
  16. if (ubatch->embd) {
  17. const int64_t n_embd = embd->ne[0];
  18. const int64_t n_tokens = ubatch->n_tokens;
  19. ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
  20. }
  21. }
  22. void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
  23. if (ubatch->pos && pos) {
  24. const int64_t n_tokens = ubatch->n_tokens;
  25. if (ubatch->token && n_pos_per_embd == 4) {
  26. // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
  27. // the 3 first dims are the same, and 4th dim is all 0
  28. std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
  29. // copy the first dimension
  30. for (int i = 0; i < n_tokens; ++i) {
  31. pos_data[ i] = ubatch->pos[i];
  32. pos_data[ n_tokens + i] = ubatch->pos[i];
  33. pos_data[2 * n_tokens + i] = ubatch->pos[i];
  34. pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
  35. }
  36. ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
  37. } else {
  38. ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
  39. }
  40. }
  41. }
  42. void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
  43. if (ubatch->pos && attn_scale) {
  44. const int64_t n_tokens = ubatch->n_tokens;
  45. std::vector<float> attn_scale_data(n_tokens, 0.0f);
  46. for (int i = 0; i < n_tokens; ++i) {
  47. const float pos = ubatch->pos[i];
  48. attn_scale_data[i] = std::log(
  49. std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
  50. ) * f_attn_temp_scale + 1.0;
  51. }
  52. ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
  53. }
  54. }
  55. void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
  56. if (pos_bucket) {
  57. const int64_t n_tokens = ubatch->n_tokens;
  58. GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
  59. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  60. int32_t * data = (int32_t *) pos_bucket->data;
  61. for (int h = 0; h < 1; ++h) {
  62. for (int j = 0; j < n_tokens; ++j) {
  63. for (int i = 0; i < n_tokens; ++i) {
  64. data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
  65. }
  66. }
  67. }
  68. }
  69. }
  70. void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
  71. if (pos_bucket) {
  72. kv_state->set_input_pos_bucket(pos_bucket, ubatch);
  73. }
  74. }
  75. void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
  76. if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
  77. //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
  78. if (!out_ids) {
  79. LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
  80. } else {
  81. const int64_t n_tokens = ubatch->n_tokens;
  82. GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
  83. int32_t * data = (int32_t *) out_ids->data;
  84. if (n_outputs == n_tokens) {
  85. for (int i = 0; i < n_tokens; ++i) {
  86. data[i] = i;
  87. }
  88. } else if (ubatch->output) {
  89. int32_t n_outputs = 0;
  90. for (int i = 0; i < n_tokens; ++i) {
  91. if (ubatch->output[i]) {
  92. data[n_outputs++] = i;
  93. }
  94. }
  95. // the graph needs to have been passed the correct number of outputs
  96. GGML_ASSERT(n_outputs == n_outputs);
  97. } else if (n_outputs == 1) {
  98. // only keep last output
  99. data[0] = n_tokens - 1;
  100. } else {
  101. GGML_ASSERT(n_outputs == 0);
  102. }
  103. }
  104. }
  105. }
  106. void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
  107. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
  108. const int64_t n_tokens = ubatch->n_tokens;
  109. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  110. const int64_t n_seqs = ubatch->n_seqs;
  111. GGML_ASSERT(mean);
  112. GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
  113. float * data = (float *) mean->data;
  114. memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
  115. std::vector<uint64_t> sum(n_tokens, 0);
  116. for (int s = 0; s < n_seqs; ++s) {
  117. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  118. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  119. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
  120. sum[seq_id] += ubatch->n_seq_tokens;
  121. }
  122. std::vector<float> div(n_tokens, 0.0f);
  123. for (int i = 0; i < n_tokens; ++i) {
  124. const uint64_t s = sum[i];
  125. if (s > 0) {
  126. div[i] = 1.0f/float(s);
  127. }
  128. }
  129. for (int s = 0; s < n_seqs; ++s) {
  130. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  131. for (int i = 0; i < n_seq_tokens; ++i) {
  132. data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
  133. }
  134. }
  135. }
  136. }
  137. void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
  138. if (cparams.embeddings && (
  139. cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
  140. cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
  141. const int64_t n_tokens = ubatch->n_tokens;
  142. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  143. const int64_t n_seqs = ubatch->n_seqs;
  144. GGML_ASSERT(cls);
  145. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  146. uint32_t * data = (uint32_t *) cls->data;
  147. memset(cls->data, 0, n_tokens * ggml_element_size(cls));
  148. for (int s = 0; s < n_seqs; ++s) {
  149. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  150. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  151. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
  152. for (int i = 0; i < n_seq_tokens; ++i) {
  153. const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
  154. if (pos == 0) {
  155. data[seq_id] = s*n_seq_tokens + i;
  156. }
  157. }
  158. }
  159. }
  160. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
  161. const int64_t n_tokens = ubatch->n_tokens;
  162. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  163. const int64_t n_seqs = ubatch->n_seqs;
  164. GGML_ASSERT(cls);
  165. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  166. uint32_t * data = (uint32_t *) cls->data;
  167. memset(cls->data, 0, n_tokens * ggml_element_size(cls));
  168. std::vector<int> last_pos(n_tokens, -1);
  169. std::vector<int> last_row(n_tokens, -1);
  170. for (int s = 0; s < n_seqs; ++s) {
  171. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  172. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  173. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
  174. for (int i = 0; i < n_seq_tokens; ++i) {
  175. const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
  176. if (pos >= last_pos[seq_id]) {
  177. last_pos[seq_id] = pos;
  178. last_row[seq_id] = s*n_seq_tokens + i;
  179. }
  180. }
  181. }
  182. for (int i = 0; i < n_tokens; ++i) {
  183. if (last_row[i] >= 0) {
  184. data[i] = last_row[i];
  185. }
  186. }
  187. }
  188. }
  189. void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
  190. GGML_UNUSED(ubatch);
  191. const int64_t n_kv = kv_state->get_n_kv();
  192. if (s_copy) {
  193. GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
  194. int32_t * data = (int32_t *) s_copy->data;
  195. // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
  196. for (uint32_t i = 0; i < n_kv; ++i) {
  197. data[i] = kv_state->s_copy(i);
  198. }
  199. }
  200. }
  201. void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
  202. GGML_UNUSED(ubatch);
  203. const int64_t n_kv = kv_state->get_n_kv();
  204. if (s_mask) {
  205. GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
  206. float * data = (float *) s_mask->data;
  207. // clear unused states
  208. for (int i = 0; i < n_kv; ++i) {
  209. data[i] = kv_state->s_mask(i);
  210. }
  211. }
  212. }
  213. void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  214. GGML_UNUSED(ubatch);
  215. if (cross_embd && !cross->v_embd.empty()) {
  216. assert(cross_embd->type == GGML_TYPE_F32);
  217. ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
  218. }
  219. }
  220. void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  221. if (kq_mask) {
  222. if (cparams.causal_attn) {
  223. const int64_t n_kv = ubatch->n_tokens;
  224. const int64_t n_tokens = ubatch->n_tokens;
  225. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  226. const int64_t n_seqs = ubatch->n_seqs;
  227. GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
  228. float * data = (float *) kq_mask->data;
  229. for (int h = 0; h < 1; ++h) {
  230. for (int s1 = 0; s1 < n_seqs; ++s1) {
  231. const llama_seq_id seq_id = ubatch->seq_id[s1][0];
  232. for (int j = 0; j < n_seq_tokens; ++j) {
  233. const int32_t tj = s1*n_seq_tokens + j;
  234. for (int s0 = 0; s0 < n_seqs; ++s0) {
  235. for (int i = 0; i < n_seq_tokens; ++i) {
  236. const int32_t ti = s0*n_seq_tokens + i;
  237. float f = -INFINITY;
  238. for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
  239. if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
  240. if (hparams.use_alibi) {
  241. f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
  242. } else {
  243. f = 0.0f;
  244. }
  245. break;
  246. }
  247. }
  248. data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
  249. }
  250. }
  251. }
  252. }
  253. }
  254. } else {
  255. const int64_t n_tokens = ubatch->n_tokens;
  256. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  257. const int64_t n_seqs = ubatch->n_seqs;
  258. const int64_t n_stride = ubatch->n_tokens;
  259. GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
  260. float * data = (float *) kq_mask->data;
  261. for (int h = 0; h < 1; ++h) {
  262. for (int s1 = 0; s1 < n_seqs; ++s1) {
  263. const llama_seq_id seq_id = ubatch->seq_id[s1][0];
  264. for (int j = 0; j < n_seq_tokens; ++j) {
  265. const int32_t tj = s1*n_seq_tokens + j;
  266. for (int s0 = 0; s0 < n_seqs; ++s0) {
  267. for (int i = 0; i < n_seq_tokens; ++i) {
  268. const int32_t ti = s0*n_seq_tokens + i;
  269. float f = -INFINITY;
  270. for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
  271. if (ubatch->seq_id[s0][s] == seq_id) {
  272. if (hparams.use_alibi) {
  273. f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
  274. } else {
  275. f = 0.0f;
  276. }
  277. break;
  278. }
  279. }
  280. data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
  281. }
  282. }
  283. for (int i = n_tokens; i < n_stride; ++i) {
  284. data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
  285. }
  286. }
  287. }
  288. }
  289. }
  290. }
  291. }
  292. void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
  293. if (self_kq_mask) {
  294. kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
  295. }
  296. }
  297. void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
  298. if (self_kq_mask) {
  299. kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
  300. }
  301. if (self_kq_mask_swa) {
  302. kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
  303. }
  304. }
  305. void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  306. if (cross_kq_mask) {
  307. const int64_t n_enc = cross_kq_mask->ne[0];
  308. const int64_t n_tokens = ubatch->n_tokens;
  309. GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
  310. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  311. float * data = (float *) cross_kq_mask->data;
  312. for (int h = 0; h < 1; ++h) {
  313. for (int j = 0; j < n_tokens; ++j) {
  314. for (int i = 0; i < n_enc; ++i) {
  315. float f = -INFINITY;
  316. for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
  317. const llama_seq_id seq_id = ubatch->seq_id[j][s];
  318. if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
  319. f = 0.0f;
  320. }
  321. }
  322. data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
  323. }
  324. }
  325. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  326. for (int j = 0; j < n_enc; ++j) {
  327. data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
  328. }
  329. }
  330. }
  331. }
  332. }
  333. //
  334. // llm_graph_context
  335. //
  336. llm_graph_context::llm_graph_context(const llm_graph_params & params) :
  337. arch (params.arch),
  338. hparams (params.hparams),
  339. cparams (params.cparams),
  340. ubatch (params.ubatch),
  341. n_embd (hparams.n_embd),
  342. n_layer (hparams.n_layer),
  343. n_rot (hparams.n_rot),
  344. n_ctx (cparams.n_ctx),
  345. n_head (hparams.n_head()),
  346. n_head_kv (hparams.n_head_kv()),
  347. n_embd_head_k (hparams.n_embd_head_k),
  348. n_embd_k_gqa (hparams.n_embd_k_gqa()),
  349. n_embd_head_v (hparams.n_embd_head_v),
  350. n_embd_v_gqa (hparams.n_embd_v_gqa()),
  351. n_expert (hparams.n_expert),
  352. n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
  353. freq_base (cparams.rope_freq_base),
  354. freq_scale (cparams.rope_freq_scale),
  355. ext_factor (cparams.yarn_ext_factor),
  356. attn_factor (cparams.yarn_attn_factor),
  357. beta_fast (cparams.yarn_beta_fast),
  358. beta_slow (cparams.yarn_beta_slow),
  359. norm_eps (hparams.f_norm_eps),
  360. norm_rms_eps (hparams.f_norm_rms_eps),
  361. n_tokens (ubatch.n_tokens),
  362. n_outputs (params.n_outputs),
  363. n_ctx_orig (cparams.n_ctx_orig_yarn),
  364. pooling_type (cparams.pooling_type),
  365. rope_type (hparams.rope_type),
  366. ctx0 (params.ctx),
  367. sched (params.sched),
  368. backend_cpu (params.backend_cpu),
  369. cvec (params.cvec),
  370. loras (params.loras),
  371. mstate (params.mstate),
  372. cross (params.cross),
  373. cb_func (params.cb),
  374. res (std::make_unique<llm_graph_result>()) {
  375. }
  376. int64_t llm_graph_context::n_pos_per_embd() const {
  377. return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
  378. }
  379. void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
  380. if (cb_func) {
  381. cb_func(ubatch, cur, name, il);
  382. }
  383. }
  384. ggml_tensor * llm_graph_context::build_cvec(
  385. ggml_tensor * cur,
  386. int il) const {
  387. return cvec->apply_to(ctx0, cur, il);
  388. }
  389. ggml_tensor * llm_graph_context::build_lora_mm(
  390. ggml_tensor * w,
  391. ggml_tensor * cur) const {
  392. ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
  393. for (const auto & lora : *loras) {
  394. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  395. if (lw == nullptr) {
  396. continue;
  397. }
  398. const float adapter_scale = lora.second;
  399. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  400. ggml_tensor * ab_cur = ggml_mul_mat(
  401. ctx0, lw->b,
  402. ggml_mul_mat(ctx0, lw->a, cur)
  403. );
  404. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  405. res = ggml_add(ctx0, res, ab_cur);
  406. }
  407. return res;
  408. }
  409. ggml_tensor * llm_graph_context::build_lora_mm_id(
  410. ggml_tensor * w, // ggml_tensor * as
  411. ggml_tensor * cur, // ggml_tensor * b
  412. ggml_tensor * ids) const {
  413. ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
  414. for (const auto & lora : *loras) {
  415. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  416. if (lw == nullptr) {
  417. continue;
  418. }
  419. const float alpha = lora.first->alpha;
  420. const float rank = (float) lw->b->ne[0];
  421. const float scale = alpha ? lora.second * alpha / rank : lora.second;
  422. ggml_tensor * ab_cur = ggml_mul_mat_id(
  423. ctx0, lw->b,
  424. ggml_mul_mat_id(ctx0, lw->a, cur, ids),
  425. ids
  426. );
  427. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  428. res = ggml_add(ctx0, res, ab_cur);
  429. }
  430. return res;
  431. }
  432. ggml_tensor * llm_graph_context::build_norm(
  433. ggml_tensor * cur,
  434. ggml_tensor * mw,
  435. ggml_tensor * mb,
  436. llm_norm_type type,
  437. int il) const {
  438. switch (type) {
  439. case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
  440. case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
  441. case LLM_NORM_GROUP:
  442. {
  443. cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
  444. cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
  445. cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
  446. } break;
  447. }
  448. if (mw || mb) {
  449. cb(cur, "norm", il);
  450. }
  451. if (mw) {
  452. cur = ggml_mul(ctx0, cur, mw);
  453. if (mb) {
  454. cb(cur, "norm_w", il);
  455. }
  456. }
  457. if (mb) {
  458. cur = ggml_add(ctx0, cur, mb);
  459. }
  460. return cur;
  461. }
  462. ggml_tensor * llm_graph_context::build_ffn(
  463. ggml_tensor * cur,
  464. ggml_tensor * up,
  465. ggml_tensor * up_b,
  466. ggml_tensor * up_s,
  467. ggml_tensor * gate,
  468. ggml_tensor * gate_b,
  469. ggml_tensor * gate_s,
  470. ggml_tensor * down,
  471. ggml_tensor * down_b,
  472. ggml_tensor * down_s,
  473. ggml_tensor * act_scales,
  474. llm_ffn_op_type type_op,
  475. llm_ffn_gate_type type_gate,
  476. int il) const {
  477. ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
  478. cb(tmp, "ffn_up", il);
  479. if (up_b) {
  480. tmp = ggml_add(ctx0, tmp, up_b);
  481. cb(tmp, "ffn_up_b", il);
  482. }
  483. if (up_s) {
  484. tmp = ggml_mul(ctx0, tmp, up_s);
  485. cb(tmp, "ffn_up_s", il);
  486. }
  487. if (gate) {
  488. switch (type_gate) {
  489. case LLM_FFN_SEQ:
  490. {
  491. cur = build_lora_mm(gate, tmp);
  492. cb(cur, "ffn_gate", il);
  493. } break;
  494. case LLM_FFN_PAR:
  495. {
  496. cur = build_lora_mm(gate, cur);
  497. cb(cur, "ffn_gate", il);
  498. } break;
  499. }
  500. if (gate_b) {
  501. cur = ggml_add(ctx0, cur, gate_b);
  502. cb(cur, "ffn_gate_b", il);
  503. }
  504. if (gate_s) {
  505. cur = ggml_mul(ctx0, cur, gate_s);
  506. cb(cur, "ffn_gate_s", il);
  507. }
  508. } else {
  509. cur = tmp;
  510. }
  511. switch (type_op) {
  512. case LLM_FFN_SILU:
  513. {
  514. cur = ggml_silu(ctx0, cur);
  515. cb(cur, "ffn_silu", il);
  516. } break;
  517. case LLM_FFN_GELU:
  518. {
  519. cur = ggml_gelu(ctx0, cur);
  520. cb(cur, "ffn_gelu", il);
  521. if (act_scales != NULL) {
  522. cur = ggml_div(ctx0, cur, act_scales);
  523. cb(cur, "ffn_act", il);
  524. }
  525. } break;
  526. case LLM_FFN_RELU:
  527. {
  528. cur = ggml_relu(ctx0, cur);
  529. cb(cur, "ffn_relu", il);
  530. } break;
  531. case LLM_FFN_RELU_SQR:
  532. {
  533. cur = ggml_relu(ctx0, cur);
  534. cb(cur, "ffn_relu", il);
  535. cur = ggml_sqr(ctx0, cur);
  536. cb(cur, "ffn_sqr(relu)", il);
  537. } break;
  538. case LLM_FFN_SWIGLU:
  539. {
  540. // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
  541. int64_t split_point = cur->ne[0] / 2;
  542. // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
  543. ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
  544. ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
  545. x0 = ggml_silu(ctx0, x0);
  546. cb(cur, "ffn_silu", il);
  547. cur = ggml_mul(ctx0, x0, x1);
  548. cb(cur, "ffn_mul", il);
  549. } break;
  550. case LLM_FFN_GEGLU:
  551. {
  552. // Split into two equal parts
  553. int64_t split_point = cur->ne[0] / 2;
  554. // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
  555. ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
  556. ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
  557. x0 = ggml_gelu(ctx0, x0);
  558. cb(x0, "ffn_gelu", il);
  559. cur = ggml_mul(ctx0, x0, x1);
  560. cb(cur, "ffn_geglu", il);
  561. } break;
  562. }
  563. if (gate && type_gate == LLM_FFN_PAR) {
  564. cur = ggml_mul(ctx0, cur, tmp);
  565. cb(cur, "ffn_gate_par", il);
  566. }
  567. if (down) {
  568. cur = build_lora_mm(down, cur);
  569. if (arch == LLM_ARCH_GLM4) {
  570. // GLM4 seems to have numerical issues with half-precision accumulators
  571. ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
  572. }
  573. }
  574. if (down_b) {
  575. cb(cur, "ffn_down", il);
  576. }
  577. if (down_b) {
  578. cur = ggml_add(ctx0, cur, down_b);
  579. }
  580. if (down_s) {
  581. cur = ggml_mul(ctx0, cur, down_s);
  582. cb(cur, "ffn_down_s", il);
  583. }
  584. return cur;
  585. }
  586. ggml_tensor * llm_graph_context::build_moe_ffn(
  587. ggml_tensor * cur,
  588. ggml_tensor * gate_inp,
  589. ggml_tensor * up_exps,
  590. ggml_tensor * gate_exps,
  591. ggml_tensor * down_exps,
  592. ggml_tensor * exp_probs_b,
  593. int64_t n_expert,
  594. int64_t n_expert_used,
  595. llm_ffn_op_type type_op,
  596. bool norm_w,
  597. bool scale_w,
  598. float w_scale,
  599. llama_expert_gating_func_type gating_op,
  600. int il) const {
  601. const int64_t n_embd = cur->ne[0];
  602. const int64_t n_tokens = cur->ne[1];
  603. const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
  604. ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
  605. cb(logits, "ffn_moe_logits", il);
  606. ggml_tensor * probs = nullptr;
  607. switch (gating_op) {
  608. case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
  609. {
  610. probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
  611. } break;
  612. case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
  613. {
  614. probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
  615. } break;
  616. default:
  617. GGML_ABORT("fatal error");
  618. }
  619. cb(probs, "ffn_moe_probs", il);
  620. // add experts selection bias - introduced in DeepSeek V3
  621. // leave probs unbiased as it's later used to get expert weights
  622. ggml_tensor * selection_probs = probs;
  623. if (exp_probs_b != nullptr) {
  624. selection_probs = ggml_add(ctx0, probs, exp_probs_b);
  625. cb(selection_probs, "ffn_moe_probs_biased", il);
  626. }
  627. // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
  628. // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
  629. if (arch == LLM_ARCH_LLAMA4) {
  630. selection_probs = logits;
  631. }
  632. // select experts
  633. ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
  634. cb(selected_experts->src[0], "ffn_moe_argsort", il);
  635. cb(selected_experts, "ffn_moe_topk", il);
  636. ggml_tensor * weights = ggml_get_rows(ctx0,
  637. ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
  638. cb(weights, "ffn_moe_weights", il);
  639. if (norm_w) {
  640. weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
  641. ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
  642. cb(weights_sum, "ffn_moe_weights_sum", il);
  643. weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
  644. cb(weights, "ffn_moe_weights_norm", il);
  645. weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
  646. }
  647. if (scale_w) {
  648. weights = ggml_scale(ctx0, weights, w_scale);
  649. cb(weights, "ffn_moe_weights_scaled", il);
  650. }
  651. cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
  652. if (weight_before_ffn) {
  653. // repeat cur to [n_embd, n_expert_used, n_tokens]
  654. ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
  655. cur = ggml_mul(ctx0, repeated, weights);
  656. cb(cur, "ffn_moe_weighted", il);
  657. }
  658. ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  659. cb(up, "ffn_moe_up", il);
  660. ggml_tensor * experts = nullptr;
  661. if (gate_exps) {
  662. cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  663. cb(cur, "ffn_moe_gate", il);
  664. } else {
  665. cur = up;
  666. }
  667. switch (type_op) {
  668. case LLM_FFN_SILU:
  669. {
  670. cur = ggml_silu(ctx0, cur);
  671. cb(cur, "ffn_moe_silu", il);
  672. } break;
  673. case LLM_FFN_GELU:
  674. {
  675. cur = ggml_gelu(ctx0, cur);
  676. cb(cur, "ffn_moe_gelu", il);
  677. } break;
  678. default:
  679. GGML_ABORT("fatal error");
  680. }
  681. if (gate_exps) {
  682. cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
  683. cb(cur, "ffn_moe_gate_par", il);
  684. }
  685. experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
  686. cb(experts, "ffn_moe_down", il);
  687. if (!weight_before_ffn) {
  688. experts = ggml_mul(ctx0, experts, weights);
  689. cb(cur, "ffn_moe_weighted", il);
  690. }
  691. // aggregate experts
  692. ggml_tensor * moe_out = nullptr;
  693. for (int i = 0; i < n_expert_used; ++i) {
  694. ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
  695. experts->nb[2], i*experts->nb[1]);
  696. if (i == 0) {
  697. moe_out = cur_expert;
  698. } else {
  699. moe_out = ggml_add(ctx0, moe_out, cur_expert);
  700. }
  701. }
  702. if (n_expert_used == 1) {
  703. // avoid returning a non-contiguous tensor
  704. moe_out = ggml_cont(ctx0, moe_out);
  705. }
  706. cb(moe_out, "ffn_moe_out", il);
  707. return moe_out;
  708. }
  709. // input embeddings with optional lora
  710. ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
  711. const int64_t n_embd = hparams.n_embd;
  712. auto inp = std::make_unique<llm_graph_input_embd>();
  713. ggml_tensor * cur = nullptr;
  714. if (ubatch.token) {
  715. inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
  716. //cb(inp->tokens, "inp_tokens", -1);
  717. ggml_set_input(inp->tokens);
  718. res->t_tokens = inp->tokens;
  719. cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
  720. // apply lora for embedding tokens if needed
  721. for (const auto & lora : *loras) {
  722. llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
  723. if (lw == nullptr) {
  724. continue;
  725. }
  726. const float adapter_scale = lora.second;
  727. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  728. ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
  729. ctx0, lw->b, // non-transposed lora_b
  730. ggml_get_rows(ctx0, lw->a, inp->tokens)
  731. ), scale);
  732. cur = ggml_add(ctx0, cur, inpL_delta);
  733. }
  734. } else {
  735. inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
  736. ggml_set_input(inp->embd);
  737. cur = inp->embd;
  738. }
  739. // For Granite architecture
  740. if (hparams.f_embedding_scale != 0.0f) {
  741. cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
  742. }
  743. cb(cur, "inp_embd", -1);
  744. res->add_input(std::move(inp));
  745. return cur;
  746. }
  747. ggml_tensor * llm_graph_context::build_inp_pos() const {
  748. auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
  749. auto & cur = inp->pos;
  750. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
  751. ggml_set_input(cur);
  752. res->add_input(std::move(inp));
  753. return cur;
  754. }
  755. ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
  756. auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
  757. auto & cur = inp->attn_scale;
  758. // this need to be 1x1xN for broadcasting
  759. cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
  760. ggml_set_input(cur);
  761. res->add_input(std::move(inp));
  762. return cur;
  763. }
  764. ggml_tensor * llm_graph_context::build_inp_out_ids() const {
  765. auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
  766. auto & cur = inp->out_ids;
  767. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
  768. ggml_set_input(cur);
  769. res->add_input(std::move(inp));
  770. return cur;
  771. }
  772. ggml_tensor * llm_graph_context::build_inp_mean() const {
  773. auto inp = std::make_unique<llm_graph_input_mean>(cparams);
  774. auto & cur = inp->mean;
  775. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
  776. ggml_set_input(cur);
  777. res->add_input(std::move(inp));
  778. return cur;
  779. }
  780. ggml_tensor * llm_graph_context::build_inp_cls() const {
  781. auto inp = std::make_unique<llm_graph_input_cls>(cparams);
  782. auto & cur = inp->cls;
  783. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
  784. ggml_set_input(cur);
  785. res->add_input(std::move(inp));
  786. return cur;
  787. }
  788. ggml_tensor * llm_graph_context::build_inp_s_copy() const {
  789. const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
  790. auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
  791. const auto n_kv = kv_state->get_n_kv();
  792. auto & cur = inp->s_copy;
  793. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
  794. ggml_set_input(cur);
  795. res->add_input(std::move(inp));
  796. return cur;
  797. }
  798. ggml_tensor * llm_graph_context::build_inp_s_mask() const {
  799. const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
  800. auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
  801. const auto n_kv = kv_state->get_n_kv();
  802. auto & cur = inp->s_mask;
  803. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
  804. ggml_set_input(cur);
  805. res->add_input(std::move(inp));
  806. return cur;
  807. }
  808. ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
  809. auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
  810. auto & cur = inp->cross_embd;
  811. // if we have the output embeddings from the encoder, use them directly
  812. // TODO: needs more work to be correct, for now just use the tensor shape
  813. //if (cross->t_embd) {
  814. // cur = ggml_view_tensor(ctx0, cross->t_embd);
  815. // return cur;
  816. //}
  817. const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
  818. const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  819. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
  820. ggml_set_input(cur);
  821. res->add_input(std::move(inp));
  822. return cur;
  823. }
  824. ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
  825. auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
  826. auto & cur = inp->pos_bucket;
  827. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
  828. ggml_set_input(cur);
  829. res->add_input(std::move(inp));
  830. return cur;
  831. }
  832. ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
  833. const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
  834. auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
  835. const auto n_kv = kv_state->get_n_kv();
  836. auto & cur = inp->pos_bucket;
  837. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
  838. ggml_set_input(cur);
  839. res->add_input(std::move(inp));
  840. return cur;
  841. }
  842. ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
  843. ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
  844. cb(pos_bucket_1d, "pos_bucket_1d", -1);
  845. ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
  846. pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
  847. pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
  848. pos_bias = ggml_cont (ctx0, pos_bias);
  849. cb(pos_bias, "pos_bias", -1);
  850. return pos_bias;
  851. }
  852. ggml_tensor * llm_graph_context::build_attn_mha(
  853. ggml_cgraph * gf,
  854. ggml_tensor * q,
  855. ggml_tensor * k,
  856. ggml_tensor * v,
  857. ggml_tensor * kq_b,
  858. ggml_tensor * kq_mask,
  859. ggml_tensor * v_mla,
  860. float kq_scale) const {
  861. const bool v_trans = v->nb[1] > v->nb[2];
  862. q = ggml_permute(ctx0, q, 0, 2, 1, 3);
  863. k = ggml_permute(ctx0, k, 0, 2, 1, 3);
  864. v = ggml_permute(ctx0, v, 0, 2, 1, 3);
  865. const auto n_tokens = q->ne[1];
  866. const auto n_head = q->ne[2];
  867. const auto n_kv = k->ne[1];
  868. ggml_tensor * cur;
  869. // TODO: replace hardcoded padding with ggml-provided padding
  870. if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
  871. GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
  872. if (v_trans) {
  873. v = ggml_transpose(ctx0, v);
  874. }
  875. // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
  876. if (k->type == GGML_TYPE_F32) {
  877. k = ggml_cast(ctx0, k, GGML_TYPE_F16);
  878. }
  879. if (v->type == GGML_TYPE_F32) {
  880. v = ggml_cast(ctx0, v, GGML_TYPE_F16);
  881. }
  882. cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
  883. hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
  884. ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
  885. if (v_mla) {
  886. #if 0
  887. // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
  888. // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
  889. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
  890. cur = ggml_mul_mat(ctx0, v_mla, cur);
  891. #else
  892. // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
  893. // The permutations are noops and only change how the tensor data is interpreted.
  894. cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  895. cur = ggml_mul_mat(ctx0, v_mla, cur);
  896. cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  897. cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
  898. #endif
  899. }
  900. cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
  901. } else {
  902. ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
  903. // note: this op tends to require high floating point range
  904. // while for some models F16 is enough, for others it is not, so we default to F32 here
  905. ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
  906. if (arch == LLM_ARCH_GROK) {
  907. // need to do the following:
  908. // multiply by attn_output_multiplyer of 0.08838834764831845
  909. // and then :
  910. // kq = 30 * tanh(kq / 30)
  911. // before the softmax below
  912. kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
  913. kq = ggml_scale(ctx0, kq, 30);
  914. }
  915. if (hparams.attn_soft_cap) {
  916. kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
  917. kq = ggml_tanh (ctx0, kq);
  918. kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
  919. }
  920. if (kq_b) {
  921. kq = ggml_add(ctx0, kq, kq_b);
  922. }
  923. kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
  924. if (!v_trans) {
  925. // note: avoid this branch
  926. v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
  927. }
  928. ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
  929. // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
  930. if (v_mla) {
  931. kqv = ggml_mul_mat(ctx0, v_mla, kqv);
  932. }
  933. cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
  934. cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
  935. if (!cparams.offload_kqv) {
  936. // all nodes between the KV store and the attention output are run on the CPU
  937. ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
  938. }
  939. }
  940. ggml_build_forward_expand(gf, cur);
  941. return cur;
  942. }
  943. llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
  944. auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
  945. // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
  946. inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  947. //cb(inp_kq_mask, "KQ_mask", -1);
  948. ggml_set_input(inp->kq_mask);
  949. inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
  950. return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
  951. }
  952. ggml_tensor * llm_graph_context::build_attn(
  953. llm_graph_input_attn_no_cache * inp,
  954. ggml_cgraph * gf,
  955. ggml_tensor * wo,
  956. ggml_tensor * wo_b,
  957. ggml_tensor * q_cur,
  958. ggml_tensor * k_cur,
  959. ggml_tensor * v_cur,
  960. ggml_tensor * kq_b,
  961. ggml_tensor * v_mla,
  962. float kq_scale,
  963. int il) const {
  964. GGML_UNUSED(n_tokens);
  965. // these nodes are added to the graph together so that they are not reordered
  966. // by doing so, the number of splits in the graph is reduced
  967. ggml_build_forward_expand(gf, q_cur);
  968. ggml_build_forward_expand(gf, k_cur);
  969. ggml_build_forward_expand(gf, v_cur);
  970. const auto & kq_mask = inp->get_kq_mask();
  971. ggml_tensor * q = q_cur;
  972. ggml_tensor * k = k_cur;
  973. ggml_tensor * v = v_cur;
  974. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  975. cb(cur, "kqv_out", il);
  976. if (wo) {
  977. cur = build_lora_mm(wo, cur);
  978. }
  979. if (wo_b) {
  980. //cb(cur, "kqv_wo", il);
  981. }
  982. if (wo_b) {
  983. cur = ggml_add(ctx0, cur, wo_b);
  984. }
  985. return cur;
  986. }
  987. llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
  988. const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
  989. auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
  990. {
  991. GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
  992. const auto n_kv = kv_state->get_n_kv();
  993. inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  994. //cb(inp->self_kq_mask, "KQ_mask", -1);
  995. ggml_set_input(inp->self_kq_mask);
  996. inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
  997. }
  998. return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
  999. }
  1000. ggml_tensor * llm_graph_context::build_attn(
  1001. llm_graph_input_attn_kv_unified * inp,
  1002. ggml_cgraph * gf,
  1003. ggml_tensor * wo,
  1004. ggml_tensor * wo_b,
  1005. ggml_tensor * q_cur,
  1006. ggml_tensor * k_cur,
  1007. ggml_tensor * v_cur,
  1008. ggml_tensor * kq_b,
  1009. ggml_tensor * v_mla,
  1010. float kq_scale,
  1011. int il) const {
  1012. // these nodes are added to the graph together so that they are not reordered
  1013. // by doing so, the number of splits in the graph is reduced
  1014. ggml_build_forward_expand(gf, q_cur);
  1015. ggml_build_forward_expand(gf, k_cur);
  1016. ggml_build_forward_expand(gf, v_cur);
  1017. const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
  1018. // store to KV cache
  1019. {
  1020. ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
  1021. ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
  1022. }
  1023. const auto & kq_mask = inp->get_kq_mask();
  1024. ggml_tensor * q = q_cur;
  1025. ggml_tensor * k = kv_state->get_k(ctx0, il);
  1026. ggml_tensor * v = kv_state->get_v(ctx0, il);
  1027. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  1028. cb(cur, "kqv_out", il);
  1029. if (wo) {
  1030. cur = build_lora_mm(wo, cur);
  1031. if (arch == LLM_ARCH_GLM4) {
  1032. // GLM4 seems to have numerical issues with half-precision accumulators
  1033. ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
  1034. }
  1035. }
  1036. if (wo_b) {
  1037. cur = ggml_add(ctx0, cur, wo_b);
  1038. }
  1039. return cur;
  1040. }
  1041. llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
  1042. const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
  1043. auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
  1044. {
  1045. const auto n_kv = kv_state->get_base()->get_n_kv();
  1046. inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1047. //cb(inp->self_kq_mask, "KQ_mask", -1);
  1048. ggml_set_input(inp->self_kq_mask);
  1049. inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
  1050. }
  1051. {
  1052. GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
  1053. const auto n_kv = kv_state->get_swa()->get_n_kv();
  1054. inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1055. //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
  1056. ggml_set_input(inp->self_kq_mask_swa);
  1057. inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
  1058. }
  1059. return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
  1060. }
  1061. ggml_tensor * llm_graph_context::build_attn(
  1062. llm_graph_input_attn_kv_unified_iswa * inp,
  1063. ggml_cgraph * gf,
  1064. ggml_tensor * wo,
  1065. ggml_tensor * wo_b,
  1066. ggml_tensor * q_cur,
  1067. ggml_tensor * k_cur,
  1068. ggml_tensor * v_cur,
  1069. ggml_tensor * kq_b,
  1070. ggml_tensor * v_mla,
  1071. float kq_scale,
  1072. int il) const {
  1073. // these nodes are added to the graph together so that they are not reordered
  1074. // by doing so, the number of splits in the graph is reduced
  1075. ggml_build_forward_expand(gf, q_cur);
  1076. ggml_build_forward_expand(gf, k_cur);
  1077. ggml_build_forward_expand(gf, v_cur);
  1078. const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
  1079. const bool is_swa = hparams.is_swa(il);
  1080. const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
  1081. // store to KV cache
  1082. {
  1083. ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
  1084. ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
  1085. }
  1086. const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
  1087. ggml_tensor * q = q_cur;
  1088. ggml_tensor * k = kv_state->get_k(ctx0, il);
  1089. ggml_tensor * v = kv_state->get_v(ctx0, il);
  1090. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  1091. cb(cur, "kqv_out", il);
  1092. if (wo) {
  1093. cur = build_lora_mm(wo, cur);
  1094. }
  1095. if (wo_b) {
  1096. //cb(cur, "kqv_wo", il);
  1097. }
  1098. if (wo_b) {
  1099. cur = ggml_add(ctx0, cur, wo_b);
  1100. }
  1101. return cur;
  1102. }
  1103. llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
  1104. auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
  1105. const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  1106. inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1107. ggml_set_input(inp->cross_kq_mask);
  1108. inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
  1109. return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
  1110. }
  1111. ggml_tensor * llm_graph_context::build_attn(
  1112. llm_graph_input_attn_cross * inp,
  1113. ggml_cgraph * gf,
  1114. ggml_tensor * wo,
  1115. ggml_tensor * wo_b,
  1116. ggml_tensor * q_cur,
  1117. ggml_tensor * k_cur,
  1118. ggml_tensor * v_cur,
  1119. ggml_tensor * kq_b,
  1120. ggml_tensor * v_mla,
  1121. float kq_scale,
  1122. int il) const {
  1123. // these nodes are added to the graph together so that they are not reordered
  1124. // by doing so, the number of splits in the graph is reduced
  1125. ggml_build_forward_expand(gf, q_cur);
  1126. ggml_build_forward_expand(gf, k_cur);
  1127. ggml_build_forward_expand(gf, v_cur);
  1128. const auto & kq_mask = inp->get_kq_mask_cross();
  1129. ggml_tensor * q = q_cur;
  1130. ggml_tensor * k = k_cur;
  1131. ggml_tensor * v = v_cur;
  1132. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  1133. cb(cur, "kqv_out", il);
  1134. if (wo) {
  1135. cur = build_lora_mm(wo, cur);
  1136. }
  1137. if (wo_b) {
  1138. //cb(cur, "kqv_wo", il);
  1139. }
  1140. if (wo_b) {
  1141. cur = ggml_add(ctx0, cur, wo_b);
  1142. }
  1143. return cur;
  1144. }
  1145. ggml_tensor * llm_graph_context::build_copy_mask_state(
  1146. ggml_cgraph * gf,
  1147. ggml_tensor * s,
  1148. ggml_tensor * state_copy,
  1149. ggml_tensor * state_mask,
  1150. int32_t n_state,
  1151. int32_t n_seqs) const {
  1152. const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
  1153. const auto n_kv = kv_state->get_n_kv();
  1154. const auto kv_head = kv_state->get_head();
  1155. ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
  1156. // copy states
  1157. // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
  1158. // this shrinks the tensors's ne[1] to n_kv
  1159. states = ggml_get_rows(ctx0, states, state_copy);
  1160. // clear states of sequences which are starting at the beginning of this batch
  1161. // FIXME: zero-out NANs?
  1162. states = ggml_mul(ctx0, states, state_mask);
  1163. // copy states which won't be changed further (between n_seqs and n_kv)
  1164. ggml_build_forward_expand(gf,
  1165. ggml_cpy(ctx0,
  1166. ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
  1167. ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
  1168. // the part of the states that will be used and modified
  1169. return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
  1170. }
  1171. ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
  1172. ggml_cgraph * gf,
  1173. ggml_tensor * state_copy,
  1174. ggml_tensor * state_mask,
  1175. const llama_ubatch & ubatch,
  1176. int il) const {
  1177. const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
  1178. const auto token_shift_count = hparams.token_shift_count;
  1179. const int64_t n_seqs = ubatch.n_seqs;
  1180. ggml_tensor * token_shift_all = kv_state->get_k_l(il);
  1181. ggml_tensor * token_shift = build_copy_mask_state(
  1182. gf, token_shift_all, state_copy, state_mask,
  1183. hparams.n_embd_k_s(), n_seqs);
  1184. token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
  1185. return token_shift;
  1186. }
  1187. ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
  1188. ggml_tensor * token_shift,
  1189. const llama_ubatch & ubatch,
  1190. int il) const {
  1191. const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
  1192. const auto token_shift_count = hparams.token_shift_count;
  1193. const auto n_embd = hparams.n_embd;
  1194. const int64_t n_seqs = ubatch.n_seqs;
  1195. const auto kv_head = kv_state->get_head();
  1196. return ggml_cpy(
  1197. ctx0,
  1198. ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
  1199. ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
  1200. );
  1201. }
  1202. void llm_graph_context::build_pooling(
  1203. ggml_cgraph * gf,
  1204. ggml_tensor * cls,
  1205. ggml_tensor * cls_b,
  1206. ggml_tensor * cls_out,
  1207. ggml_tensor * cls_out_b) const {
  1208. if (!cparams.embeddings) {
  1209. return;
  1210. }
  1211. ggml_tensor * inp = res->t_embd;
  1212. //// find result_norm tensor for input
  1213. //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
  1214. // inp = ggml_graph_node(gf, i);
  1215. // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
  1216. // break;
  1217. // }
  1218. // inp = nullptr;
  1219. //}
  1220. GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
  1221. ggml_tensor * cur;
  1222. switch (pooling_type) {
  1223. case LLAMA_POOLING_TYPE_NONE:
  1224. {
  1225. cur = inp;
  1226. } break;
  1227. case LLAMA_POOLING_TYPE_MEAN:
  1228. {
  1229. ggml_tensor * inp_mean = build_inp_mean();
  1230. cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
  1231. } break;
  1232. case LLAMA_POOLING_TYPE_CLS:
  1233. case LLAMA_POOLING_TYPE_LAST:
  1234. {
  1235. ggml_tensor * inp_cls = build_inp_cls();
  1236. cur = ggml_get_rows(ctx0, inp, inp_cls);
  1237. } break;
  1238. case LLAMA_POOLING_TYPE_RANK:
  1239. {
  1240. ggml_tensor * inp_cls = build_inp_cls();
  1241. inp = ggml_get_rows(ctx0, inp, inp_cls);
  1242. if (cls != nullptr && cls_b != nullptr) {
  1243. // classification head
  1244. // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
  1245. cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
  1246. cur = ggml_tanh(ctx0, cur);
  1247. // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
  1248. // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
  1249. if (cls_out) {
  1250. GGML_ASSERT(cls_out_b != nullptr);
  1251. cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
  1252. }
  1253. } else if (cls_out) {
  1254. // Single layer classification head (direct projection)
  1255. // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
  1256. GGML_ASSERT(cls_out_b != nullptr);
  1257. cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
  1258. } else {
  1259. GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
  1260. }
  1261. } break;
  1262. default:
  1263. {
  1264. GGML_ABORT("unknown pooling type");
  1265. }
  1266. }
  1267. cb(cur, "result_embd_pooled", -1);
  1268. res->t_embd_pooled = cur;
  1269. ggml_build_forward_expand(gf, cur);
  1270. }
  1271. int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
  1272. // TODO move to hparams if a T5 variant appears that uses a different value
  1273. const int64_t max_distance = 128;
  1274. if (bidirectional) {
  1275. n_buckets >>= 1;
  1276. }
  1277. const int64_t max_exact = n_buckets >> 1;
  1278. int32_t relative_position = x - y;
  1279. int32_t relative_bucket = 0;
  1280. if (bidirectional) {
  1281. relative_bucket += (relative_position > 0) * n_buckets;
  1282. relative_position = abs(relative_position);
  1283. } else {
  1284. relative_position = -std::min<int32_t>(relative_position, 0);
  1285. }
  1286. int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
  1287. relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
  1288. relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
  1289. return relative_bucket;
  1290. }