llama-graph.cpp 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628
  1. #include "llama-graph.h"
  2. #include "llama-impl.h"
  3. #include "llama-batch.h"
  4. #include "llama-cparams.h"
  5. #include "llama-kv-cache.h"
  6. #include <cassert>
  7. #include <cmath>
  8. #include <cstring>
  9. static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
  10. // TODO move to hparams if a T5 variant appears that uses a different value
  11. const int64_t max_distance = 128;
  12. if (bidirectional) {
  13. n_buckets >>= 1;
  14. }
  15. const int64_t max_exact = n_buckets >> 1;
  16. int32_t relative_position = x - y;
  17. int32_t relative_bucket = 0;
  18. if (bidirectional) {
  19. relative_bucket += (relative_position > 0) * n_buckets;
  20. relative_position = abs(relative_position);
  21. } else {
  22. relative_position = -std::min<int32_t>(relative_position, 0);
  23. }
  24. int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
  25. relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
  26. relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
  27. return relative_bucket;
  28. }
  29. void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
  30. if (ubatch->token) {
  31. const int64_t n_tokens = ubatch->n_tokens;
  32. ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
  33. }
  34. if (ubatch->embd) {
  35. const int64_t n_embd = embd->ne[0];
  36. const int64_t n_tokens = ubatch->n_tokens;
  37. ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
  38. }
  39. }
  40. void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
  41. if (ubatch->pos && pos) {
  42. const int64_t n_tokens = ubatch->n_tokens;
  43. ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
  44. }
  45. }
  46. void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
  47. if (pos_bucket) {
  48. const int64_t n_tokens = ubatch->n_tokens;
  49. GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
  50. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  51. int32_t * data = (int32_t *) pos_bucket->data;
  52. for (int h = 0; h < 1; ++h) {
  53. for (int j = 0; j < n_tokens; ++j) {
  54. for (int i = 0; i < n_tokens; ++i) {
  55. data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
  56. }
  57. }
  58. }
  59. }
  60. }
  61. void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
  62. if (pos_bucket) {
  63. const int64_t n_tokens = ubatch->n_tokens;
  64. GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
  65. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  66. int32_t * data = (int32_t *) pos_bucket->data;
  67. const int64_t n_kv = kv_self->n;
  68. for (int h = 0; h < 1; ++h) {
  69. for (int j = 0; j < n_tokens; ++j) {
  70. for (int i = 0; i < n_kv; ++i) {
  71. data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
  72. }
  73. }
  74. }
  75. }
  76. }
  77. void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
  78. if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
  79. //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
  80. if (!out_ids) {
  81. LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
  82. } else {
  83. const int64_t n_tokens = ubatch->n_tokens;
  84. GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
  85. int32_t * data = (int32_t *) out_ids->data;
  86. if (n_outputs == n_tokens) {
  87. for (int i = 0; i < n_tokens; ++i) {
  88. data[i] = i;
  89. }
  90. } else if (ubatch->output) {
  91. int32_t n_outputs = 0;
  92. for (int i = 0; i < n_tokens; ++i) {
  93. if (ubatch->output[i]) {
  94. data[n_outputs++] = i;
  95. }
  96. }
  97. // the graph needs to have been passed the correct number of outputs
  98. GGML_ASSERT(n_outputs == n_outputs);
  99. } else if (n_outputs == 1) {
  100. // only keep last output
  101. data[0] = n_tokens - 1;
  102. } else {
  103. GGML_ASSERT(n_outputs == 0);
  104. }
  105. }
  106. }
  107. }
  108. void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
  109. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
  110. const int64_t n_tokens = ubatch->n_tokens;
  111. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  112. const int64_t n_seqs = ubatch->n_seqs;
  113. GGML_ASSERT(mean);
  114. GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
  115. float * data = (float *) mean->data;
  116. memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
  117. std::vector<uint64_t> sum(n_tokens, 0);
  118. for (int s = 0; s < n_seqs; ++s) {
  119. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  120. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  121. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
  122. sum[seq_id] += ubatch->n_seq_tokens;
  123. }
  124. std::vector<float> div(n_tokens, 0.0f);
  125. for (int i = 0; i < n_tokens; ++i) {
  126. const uint64_t s = sum[i];
  127. if (s > 0) {
  128. div[i] = 1.0f/float(s);
  129. }
  130. }
  131. for (int s = 0; s < n_seqs; ++s) {
  132. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  133. for (int i = 0; i < n_seq_tokens; ++i) {
  134. data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
  135. }
  136. }
  137. }
  138. }
  139. void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
  140. if (cparams.embeddings && (
  141. cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
  142. cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
  143. const int64_t n_tokens = ubatch->n_tokens;
  144. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  145. const int64_t n_seqs = ubatch->n_seqs;
  146. GGML_ASSERT(cls);
  147. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  148. uint32_t * data = (uint32_t *) cls->data;
  149. memset(cls->data, 0, n_tokens * ggml_element_size(cls));
  150. for (int s = 0; s < n_seqs; ++s) {
  151. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  152. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  153. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
  154. for (int i = 0; i < n_seq_tokens; ++i) {
  155. const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
  156. if (pos == 0) {
  157. data[seq_id] = s*n_seq_tokens + i;
  158. }
  159. }
  160. }
  161. }
  162. if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
  163. const int64_t n_tokens = ubatch->n_tokens;
  164. const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  165. const int64_t n_seqs = ubatch->n_seqs;
  166. GGML_ASSERT(cls);
  167. GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
  168. uint32_t * data = (uint32_t *) cls->data;
  169. memset(cls->data, 0, n_tokens * ggml_element_size(cls));
  170. std::vector<int> last_pos(n_tokens, -1);
  171. std::vector<int> last_row(n_tokens, -1);
  172. for (int s = 0; s < n_seqs; ++s) {
  173. const llama_seq_id seq_id = ubatch->seq_id[s][0];
  174. // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
  175. GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
  176. for (int i = 0; i < n_seq_tokens; ++i) {
  177. const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
  178. if (pos >= last_pos[seq_id]) {
  179. last_pos[seq_id] = pos;
  180. last_row[seq_id] = s*n_seq_tokens + i;
  181. }
  182. }
  183. }
  184. for (int i = 0; i < n_tokens; ++i) {
  185. if (last_row[i] >= 0) {
  186. data[i] = last_row[i];
  187. }
  188. }
  189. }
  190. }
  191. void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
  192. GGML_UNUSED(ubatch);
  193. const int64_t n_kv = kv_self->n;
  194. if (s_copy) {
  195. GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
  196. int32_t * data = (int32_t *) s_copy->data;
  197. // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
  198. for (uint32_t i = 0; i < n_kv; ++i) {
  199. const uint32_t cell_id = i + kv_self->head;
  200. //////////////////////////////////////////////
  201. // TODO: this should not mutate the KV cache !
  202. llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
  203. // prevent out-of-bound sources
  204. if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
  205. kv_cell.src = cell_id;
  206. }
  207. data[i] = kv_cell.src;
  208. // TODO: do not mutate the KV cache
  209. // ensure copy only happens once
  210. if (kv_cell.src != (int32_t) cell_id) {
  211. kv_cell.src = cell_id;
  212. }
  213. }
  214. }
  215. }
  216. void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
  217. GGML_UNUSED(ubatch);
  218. const int64_t n_kv = kv_self->n;
  219. if (s_mask) {
  220. GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
  221. float * data = (float *) s_mask->data;
  222. // clear unused states
  223. for (int i = 0; i < n_kv; ++i) {
  224. const uint32_t cell_id = i + kv_self->head;
  225. //////////////////////////////////////////////
  226. // TODO: this should not mutate the KV cache !
  227. llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
  228. data[i] = (float) (kv_cell.src >= 0);
  229. // only clear once
  230. if (kv_cell.src < 0) {
  231. kv_cell.src = cell_id;
  232. }
  233. }
  234. }
  235. }
  236. void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  237. GGML_UNUSED(ubatch);
  238. if (cross_embd && !cross->v_embd.empty()) {
  239. assert(cross_embd->type == GGML_TYPE_F32);
  240. ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
  241. }
  242. }
// Fill the KQ mask for attention restricted to the tokens of the current ubatch
// (no KV cache). Rows are query tokens (tj), columns are key tokens (ti), both
// indexed as s*n_seq_tokens + j within the ubatch.
// An entry is 0.0f, or -|pos[ti] - pos[tj]| when hparams.use_alibi, if any sequence id
// of the key matches the query's first sequence id; otherwise it is -INFINITY.
// With cparams.causal_attn, keys whose position is greater than the query's are masked too.
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
    if (kq_mask) {
        if (cparams.causal_attn) {
            // the "KV" side is the ubatch itself, so n_kv == n_tokens
            const int64_t n_kv         = ubatch->n_tokens;
            const int64_t n_tokens     = ubatch->n_tokens;
            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
            const int64_t n_seqs       = ubatch->n_seqs;

            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
            float * data = (float *) kq_mask->data;

            // only h == 0 is written
            for (int h = 0; h < 1; ++h) {
                for (int s1 = 0; s1 < n_seqs; ++s1) {
                    // query side: only the first sequence id is considered
                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];

                    for (int j = 0; j < n_seq_tokens; ++j) {
                        const int32_t tj = s1*n_seq_tokens + j;

                        for (int s0 = 0; s0 < n_seqs; ++s0) {
                            for (int i = 0; i < n_seq_tokens; ++i) {
                                const int32_t ti = s0*n_seq_tokens + i;
                                float f = -INFINITY;

                                // key is visible if it shares the sequence and is not in the query's future
                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
                                        if (hparams.use_alibi) {
                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
                                        } else {
                                            f = 0.0f;
                                        }
                                        break;
                                    }
                                }

                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
                            }
                        }
                    }
                }
            }
            // NOTE(review): unlike the non-causal branch, no rows beyond n_tokens are written here
        } else {
            const int64_t n_tokens     = ubatch->n_tokens;
            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
            const int64_t n_seqs       = ubatch->n_seqs;
            // row stride of the mask; equal to n_tokens here
            const int64_t n_stride     = ubatch->n_tokens;

            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));

            float * data = (float *) kq_mask->data;

            // only h == 0 is written
            for (int h = 0; h < 1; ++h) {
                for (int s1 = 0; s1 < n_seqs; ++s1) {
                    // query side: only the first sequence id is considered
                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];

                    for (int j = 0; j < n_seq_tokens; ++j) {
                        const int32_t tj = s1*n_seq_tokens + j;

                        for (int s0 = 0; s0 < n_seqs; ++s0) {
                            for (int i = 0; i < n_seq_tokens; ++i) {
                                const int32_t ti = s0*n_seq_tokens + i;
                                float f = -INFINITY;

                                // non-causal: any key sharing the sequence is visible
                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                                    if (ubatch->seq_id[s0][s] == seq_id) {
                                        if (hparams.use_alibi) {
                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
                                        } else {
                                            f = 0.0f;
                                        }
                                        break;
                                    }
                                }

                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
                            }
                        }

                        // mask the columns beyond n_tokens in this row.
                        // NOTE(review): n_stride == n_tokens above, so this loop currently never runs.
                        for (int i = n_tokens; i < n_stride; ++i) {
                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Fill the self-attention KQ mask(s) against the unified KV cache. Rows are ubatch
// tokens (s*n_seq_tokens + j), columns are the first n_kv cache cells.
//  - self_kq_mask:     0.0f (or the ALiBi penalty -|cell.pos - pos|) for visible cells,
//                      -INFINITY otherwise.
//  - self_kq_mask_swa: the same values, but cells that are hparams.n_swa or more
//                      positions behind the token are also masked.
// Either tensor may be absent; rows from n_tokens up to GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)
// are filled with -INFINITY.
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
    if (self_kq_mask || self_kq_mask_swa) {
        const int64_t n_kv         = kv_self->n;
        const int64_t n_tokens     = ubatch->n_tokens;
        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
        const int64_t n_seqs       = ubatch->n_seqs;

        float * data     = nullptr;
        float * data_swa = nullptr;

        if (self_kq_mask) {
            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
            data = (float *) self_kq_mask->data;
        }

        if (self_kq_mask_swa) {
            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
            data_swa = (float *) self_kq_mask_swa->data;
        }

        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
        //   Causal mask:
        //      xxx-------
        //      xxxx------
        //      xxxxx-----
        //   Non-causal mask:
        //      xxxxx-----
        //      xxxxx-----
        //      xxxxx-----
        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
        // only h == 0 is written
        for (int h = 0; h < 1; ++h) {
            for (int s = 0; s < n_seqs; ++s) {
                // only the first sequence id of each ubatch sequence is used
                const llama_seq_id seq_id = ubatch->seq_id[s][0];

                for (int j = 0; j < n_seq_tokens; ++j) {
                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];

                    for (int i = 0; i < n_kv; ++i) {
                        float f;
                        // mask the token if:
                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
                        ) {
                            f = -INFINITY;
                        } else {
                            if (hparams.use_alibi) {
                                f = -std::abs(kv_self->cells[i].pos - pos);
                            } else {
                                f = 0.0f;
                            }
                        }

                        if (data) {
                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                        }

                        // may need to cut off old tokens for sliding window
                        // (f is reused, so the SWA mask starts from the regular mask value)
                        if (data_swa) {
                            if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
                                f = -INFINITY;
                            }
                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                        }
                    }
                }
            }

            // mask padded tokens
            if (data) {
                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
                    for (int j = 0; j < n_kv; ++j) {
                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                    }
                }
            }

            // mask padded tokens
            if (data_swa) {
                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
                    for (int j = 0; j < n_kv; ++j) {
                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                    }
                }
            }
        }
    }
}
  394. void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  395. if (cross_kq_mask) {
  396. const int64_t n_enc = cross_kq_mask->ne[0];
  397. const int64_t n_tokens = ubatch->n_tokens;
  398. GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
  399. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  400. float * data = (float *) cross_kq_mask->data;
  401. for (int h = 0; h < 1; ++h) {
  402. for (int j = 0; j < n_tokens; ++j) {
  403. for (int i = 0; i < n_enc; ++i) {
  404. float f = -INFINITY;
  405. for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
  406. const llama_seq_id seq_id = ubatch->seq_id[j][s];
  407. if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
  408. f = 0.0f;
  409. }
  410. }
  411. data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
  412. }
  413. }
  414. for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  415. for (int j = 0; j < n_enc; ++j) {
  416. data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
  417. }
  418. }
  419. }
  420. }
  421. }
  422. //
  423. // llm_graph_context
  424. //
// Capture the graph-build parameters and copy frequently used hyper-parameters into
// plain members so graph-building code can refer to them directly (e.g. `n_embd`).
// NOTE(review): several initializers read the `hparams`, `cparams` and `ubatch` members;
// this relies on those members being declared before the ones that read them — confirm
// against the class declaration.
llm_graph_context::llm_graph_context(const llm_graph_params & params) :
    arch             (params.arch),
    hparams          (params.hparams),
    cparams          (params.cparams),
    ubatch           (params.ubatch),
    n_embd           (hparams.n_embd),
    n_layer          (hparams.n_layer),
    n_rot            (hparams.n_rot),
    n_ctx            (cparams.n_ctx),
    // the context is divided evenly by the maximum number of sequences
    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
    n_head           (hparams.n_head()),
    n_head_kv        (hparams.n_head_kv()),
    n_embd_head_k    (hparams.n_embd_head_k),
    n_embd_k_gqa     (hparams.n_embd_k_gqa()),
    n_embd_head_v    (hparams.n_embd_head_v),
    n_embd_v_gqa     (hparams.n_embd_v_gqa()),
    n_expert         (hparams.n_expert),
    // during warmup every expert is used, not just the usual top-k
    n_expert_used    (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
    freq_base        (cparams.rope_freq_base),
    freq_scale       (cparams.rope_freq_scale),
    ext_factor       (cparams.yarn_ext_factor),
    attn_factor      (cparams.yarn_attn_factor),
    beta_fast        (cparams.yarn_beta_fast),
    beta_slow        (cparams.yarn_beta_slow),
    norm_eps         (hparams.f_norm_eps),
    norm_rms_eps     (hparams.f_norm_rms_eps),
    n_tokens         (ubatch.n_tokens),
    n_outputs        (params.n_outputs),
    n_ctx_orig       (cparams.n_ctx_orig_yarn),
    pooling_type     (cparams.pooling_type),
    rope_type        (hparams.rope_type),
    ctx0             (params.ctx),
    sched            (params.sched),
    backend_cpu      (params.backend_cpu),
    cvec             (params.cvec),
    loras            (params.loras),
    memory           (params.memory),
    cross            (params.cross),
    cb_func          (params.cb),
    res              (std::make_unique<llm_graph_result>()) {
    }
  466. int64_t llm_graph_context::n_pos_per_token() const {
  467. return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
  468. }
  469. void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
  470. if (cb_func) {
  471. cb_func(ubatch, cur, name, il);
  472. }
  473. }
  474. ggml_tensor * llm_graph_context::build_cvec(
  475. ggml_tensor * cur,
  476. int il) const {
  477. return cvec->apply_to(ctx0, cur, il);
  478. }
  479. ggml_tensor * llm_graph_context::build_lora_mm(
  480. ggml_tensor * w,
  481. ggml_tensor * cur) const {
  482. ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
  483. for (const auto & lora : *loras) {
  484. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  485. if (lw == nullptr) {
  486. continue;
  487. }
  488. const float adapter_scale = lora.second;
  489. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  490. ggml_tensor * ab_cur = ggml_mul_mat(
  491. ctx0, lw->b,
  492. ggml_mul_mat(ctx0, lw->a, cur)
  493. );
  494. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  495. res = ggml_add(ctx0, res, ab_cur);
  496. }
  497. return res;
  498. }
  499. ggml_tensor * llm_graph_context::build_lora_mm_id(
  500. ggml_tensor * w, // ggml_tensor * as
  501. ggml_tensor * cur, // ggml_tensor * b
  502. ggml_tensor * ids) const {
  503. ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
  504. for (const auto & lora : *loras) {
  505. llama_adapter_lora_weight * lw = lora.first->get_weight(w);
  506. if (lw == nullptr) {
  507. continue;
  508. }
  509. const float alpha = lora.first->alpha;
  510. const float rank = (float) lw->b->ne[0];
  511. const float scale = alpha ? lora.second * alpha / rank : lora.second;
  512. ggml_tensor * ab_cur = ggml_mul_mat_id(
  513. ctx0, lw->b,
  514. ggml_mul_mat_id(ctx0, lw->a, cur, ids),
  515. ids
  516. );
  517. ab_cur = ggml_scale(ctx0, ab_cur, scale);
  518. res = ggml_add(ctx0, res, ab_cur);
  519. }
  520. return res;
  521. }
  522. ggml_tensor * llm_graph_context::build_norm(
  523. ggml_tensor * cur,
  524. ggml_tensor * mw,
  525. ggml_tensor * mb,
  526. llm_norm_type type,
  527. int il) const {
  528. switch (type) {
  529. case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
  530. case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
  531. case LLM_NORM_GROUP:
  532. {
  533. cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
  534. cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
  535. cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
  536. } break;
  537. }
  538. if (mw || mb) {
  539. cb(cur, "norm", il);
  540. }
  541. if (mw) {
  542. cur = ggml_mul(ctx0, cur, mw);
  543. if (mb) {
  544. cb(cur, "norm_w", il);
  545. }
  546. }
  547. if (mb) {
  548. cur = ggml_add(ctx0, cur, mb);
  549. }
  550. return cur;
  551. }
  552. ggml_tensor * llm_graph_context::build_ffn(
  553. ggml_tensor * cur,
  554. ggml_tensor * up,
  555. ggml_tensor * up_b,
  556. ggml_tensor * up_s,
  557. ggml_tensor * gate,
  558. ggml_tensor * gate_b,
  559. ggml_tensor * gate_s,
  560. ggml_tensor * down,
  561. ggml_tensor * down_b,
  562. ggml_tensor * down_s,
  563. ggml_tensor * act_scales,
  564. llm_ffn_op_type type_op,
  565. llm_ffn_gate_type type_gate,
  566. int il) const {
  567. ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
  568. cb(tmp, "ffn_up", il);
  569. if (up_b) {
  570. tmp = ggml_add(ctx0, tmp, up_b);
  571. cb(tmp, "ffn_up_b", il);
  572. }
  573. if (up_s) {
  574. tmp = ggml_mul(ctx0, tmp, up_s);
  575. cb(tmp, "ffn_up_s", il);
  576. }
  577. if (gate) {
  578. switch (type_gate) {
  579. case LLM_FFN_SEQ:
  580. {
  581. cur = build_lora_mm(gate, tmp);
  582. cb(cur, "ffn_gate", il);
  583. } break;
  584. case LLM_FFN_PAR:
  585. {
  586. cur = build_lora_mm(gate, cur);
  587. cb(cur, "ffn_gate", il);
  588. } break;
  589. }
  590. if (gate_b) {
  591. cur = ggml_add(ctx0, cur, gate_b);
  592. cb(cur, "ffn_gate_b", il);
  593. }
  594. if (gate_s) {
  595. cur = ggml_mul(ctx0, cur, gate_s);
  596. cb(cur, "ffn_gate_s", il);
  597. }
  598. } else {
  599. cur = tmp;
  600. }
  601. switch (type_op) {
  602. case LLM_FFN_SILU:
  603. {
  604. cur = ggml_silu(ctx0, cur);
  605. cb(cur, "ffn_silu", il);
  606. } break;
  607. case LLM_FFN_GELU:
  608. {
  609. cur = ggml_gelu(ctx0, cur);
  610. cb(cur, "ffn_gelu", il);
  611. if (act_scales != NULL) {
  612. cur = ggml_div(ctx0, cur, act_scales);
  613. cb(cur, "ffn_act", il);
  614. }
  615. } break;
  616. case LLM_FFN_RELU:
  617. {
  618. cur = ggml_relu(ctx0, cur);
  619. cb(cur, "ffn_relu", il);
  620. } break;
  621. case LLM_FFN_RELU_SQR:
  622. {
  623. cur = ggml_relu(ctx0, cur);
  624. cb(cur, "ffn_relu", il);
  625. cur = ggml_sqr(ctx0, cur);
  626. cb(cur, "ffn_sqr(relu)", il);
  627. } break;
  628. case LLM_FFN_SWIGLU:
  629. {
  630. // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
  631. int64_t split_point = cur->ne[0] / 2;
  632. ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
  633. ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
  634. x0 = ggml_silu(ctx0, x0);
  635. cb(cur, "ffn_silu", il);
  636. cur = ggml_mul(ctx0, x0, x1);
  637. cb(cur, "ffn_mul", il);
  638. } break;
  639. }
  640. if (type_gate == LLM_FFN_PAR) {
  641. cur = ggml_mul(ctx0, cur, tmp);
  642. cb(cur, "ffn_gate_par", il);
  643. }
  644. if (down) {
  645. cur = build_lora_mm(down, cur);
  646. }
  647. if (down_b) {
  648. cb(cur, "ffn_down", il);
  649. }
  650. if (down_b) {
  651. cur = ggml_add(ctx0, cur, down_b);
  652. }
  653. if (down_s) {
  654. cur = ggml_mul(ctx0, cur, down_s);
  655. cb(cur, "ffn_down_s", il);
  656. }
  657. return cur;
  658. }
// Build a mixture-of-experts FFN:
//  1. route: logits = gate_inp * cur, turned into probabilities by softmax or sigmoid (gating_op);
//  2. select the top n_expert_used experts per token (optionally biased by exp_probs_b,
//     which affects selection only — the unbiased probs are used as weights);
//  3. gather the selected weights, optionally renormalize them per token (norm_w) and
//     scale them by w_scale (scale_w);
//  4. run the selected experts: silu/gelu(gate) * up, then down, weighted and summed
//     over the selected experts.
// The n_expert / n_expert_used parameters shadow the members of the same name.
// NOTE(review): with n_expert_used == 0 the aggregation loop never runs and nullptr is returned.
ggml_tensor * llm_graph_context::build_moe_ffn(
         ggml_tensor * cur,
         ggml_tensor * gate_inp,
         ggml_tensor * up_exps,
         ggml_tensor * gate_exps,
         ggml_tensor * down_exps,
         ggml_tensor * exp_probs_b,
             int64_t   n_expert,
             int64_t   n_expert_used,
     llm_ffn_op_type   type_op,
                bool   norm_w,
                bool   scale_w,
               float   w_scale,
llama_expert_gating_func_type gating_op,
                 int   il) const {
    int64_t n_embd = cur->ne[0];
    int64_t n_tokens = cur->ne[1];

    ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
    cb(logits, "ffn_moe_logits", il);

    ggml_tensor * probs = nullptr;
    switch (gating_op) {
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
            {
                probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
            } break;
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
            {
                probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
            } break;
        default:
            GGML_ABORT("fatal error");
    }
    cb(probs, "ffn_moe_probs", il);

    // add experts selection bias - introduced in DeepSeek V3
    // leave probs unbiased as it's later used to get expert weights
    ggml_tensor * selection_probs = probs;
    if (exp_probs_b != nullptr) {
        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
        cb(selection_probs, "ffn_moe_probs_biased", il);
    }

    // select experts
    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
    cb(selected_experts->src[0], "ffn_moe_argsort", il);
    cb(selected_experts, "ffn_moe_topk", il);

    // gather the (unbiased) probabilities of the selected experts
    ggml_tensor * weights = ggml_get_rows(ctx0,
            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
    cb(weights, "ffn_moe_weights", il);

    if (norm_w) {
        // make the selected weights of each token sum to 1
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);

        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
        cb(weights_sum, "ffn_moe_weights_sum", il);

        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
        cb(weights, "ffn_moe_weights_norm", il);

        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
    }
    if (scale_w) {
        weights = ggml_scale(ctx0, weights, w_scale);
        cb(weights, "ffn_moe_weights_scaled", il);
    }

    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
    cb(up, "ffn_moe_up", il);

    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
    cb(gate, "ffn_moe_gate", il);

    switch (type_op) {
        case LLM_FFN_SILU:
            {
                gate = ggml_silu(ctx0, gate);
                cb(gate, "ffn_moe_silu", il);
            } break;
        case LLM_FFN_GELU:
            {
                gate = ggml_gelu(ctx0, gate);
                cb(gate, "ffn_moe_gelu", il);
            } break;
        default:
            GGML_ABORT("fatal error");
    }

    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
    cb(par, "ffn_moe_gate_par", il);

    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
    cb(experts, "ffn_moe_down", il);

    // weight each expert's output by its routing weight
    experts = ggml_mul(ctx0, experts, weights);

    // aggregate experts
    // each view selects expert slot i for all tokens: [n_embd, n_tokens] with row stride nb[2]
    ggml_tensor * moe_out = nullptr;
    for (int i = 0; i < n_expert_used; ++i) {
        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
                experts->nb[2], i*experts->nb[1]);

        if (i == 0) {
            moe_out = cur_expert;
        } else {
            moe_out = ggml_add(ctx0, moe_out, cur_expert);
        }
    }

    if (n_expert_used == 1) {
        // avoid returning a non-contiguous tensor
        moe_out = ggml_cont(ctx0, moe_out);
    }

    return moe_out;
}
  759. // input embeddings with optional lora
  760. ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
  761. const int64_t n_embd = hparams.n_embd;
  762. auto inp = std::make_unique<llm_graph_input_embd>();
  763. ggml_tensor * cur = nullptr;
  764. if (ubatch.token) {
  765. inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
  766. //cb(inp->tokens, "inp_tokens", -1);
  767. ggml_set_input(inp->tokens);
  768. cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
  769. // apply lora for embedding tokens if needed
  770. for (const auto & lora : *loras) {
  771. llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
  772. if (lw == nullptr) {
  773. continue;
  774. }
  775. const float adapter_scale = lora.second;
  776. const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
  777. ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
  778. ctx0, lw->b, // non-transposed lora_b
  779. ggml_get_rows(ctx0, lw->a, inp->tokens)
  780. ), scale);
  781. cur = ggml_add(ctx0, cur, inpL_delta);
  782. }
  783. } else {
  784. inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
  785. ggml_set_input(inp->embd);
  786. cur = inp->embd;
  787. }
  788. // For Granite architecture
  789. if (hparams.f_embedding_scale != 0.0f) {
  790. cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
  791. }
  792. cb(cur, "inp_embd", -1);
  793. res->add_input(std::move(inp));
  794. return cur;
  795. }
  796. ggml_tensor * llm_graph_context::build_inp_pos() const {
  797. auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
  798. auto & cur = inp->pos;
  799. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
  800. ggml_set_input(cur);
  801. res->add_input(std::move(inp));
  802. return cur;
  803. }
  804. ggml_tensor * llm_graph_context::build_inp_out_ids() const {
  805. auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
  806. auto & cur = inp->out_ids;
  807. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
  808. ggml_set_input(cur);
  809. res->add_input(std::move(inp));
  810. return cur;
  811. }
  812. ggml_tensor * llm_graph_context::build_inp_mean() const {
  813. auto inp = std::make_unique<llm_graph_input_mean>(cparams);
  814. auto & cur = inp->mean;
  815. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
  816. ggml_set_input(cur);
  817. res->add_input(std::move(inp));
  818. return cur;
  819. }
  820. ggml_tensor * llm_graph_context::build_inp_cls() const {
  821. auto inp = std::make_unique<llm_graph_input_cls>(cparams);
  822. auto & cur = inp->cls;
  823. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
  824. ggml_set_input(cur);
  825. res->add_input(std::move(inp));
  826. return cur;
  827. }
  828. ggml_tensor * llm_graph_context::build_inp_s_copy() const {
  829. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  830. auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
  831. const auto n_kv = kv_self->n;
  832. auto & cur = inp->s_copy;
  833. cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
  834. ggml_set_input(cur);
  835. res->add_input(std::move(inp));
  836. return cur;
  837. }
  838. ggml_tensor * llm_graph_context::build_inp_s_mask() const {
  839. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  840. auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
  841. const auto n_kv = kv_self->n;
  842. auto & cur = inp->s_mask;
  843. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
  844. ggml_set_input(cur);
  845. res->add_input(std::move(inp));
  846. return cur;
  847. }
  848. ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
  849. auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
  850. auto & cur = inp->cross_embd;
  851. // if we have the output embeddings from the encoder, use them directly
  852. // TODO: needs more work to be correct, for now just use the tensor shape
  853. //if (cross->t_embd) {
  854. // cur = ggml_view_tensor(ctx0, cross->t_embd);
  855. // return cur;
  856. //}
  857. const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
  858. const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  859. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
  860. ggml_set_input(cur);
  861. res->add_input(std::move(inp));
  862. return cur;
  863. }
  864. ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
  865. auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
  866. auto & cur = inp->pos_bucket;
  867. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
  868. ggml_set_input(cur);
  869. res->add_input(std::move(inp));
  870. return cur;
  871. }
  872. ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
  873. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  874. auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
  875. const auto n_kv = kv_self->n;
  876. auto & cur = inp->pos_bucket;
  877. cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
  878. ggml_set_input(cur);
  879. res->add_input(std::move(inp));
  880. return cur;
  881. }
  882. ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
  883. ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
  884. cb(pos_bucket_1d, "pos_bucket_1d", -1);
  885. ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
  886. pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
  887. pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
  888. pos_bias = ggml_cont (ctx0, pos_bias);
  889. cb(pos_bias, "pos_bias", -1);
  890. return pos_bias;
  891. }
  892. ggml_tensor * llm_graph_context::build_attn_mha(
  893. ggml_cgraph * gf,
  894. ggml_tensor * q,
  895. ggml_tensor * k,
  896. ggml_tensor * v,
  897. ggml_tensor * kq_b,
  898. ggml_tensor * kq_mask,
  899. bool v_trans,
  900. float kq_scale) const {
  901. //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  902. //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  903. //const int64_t n_head = hparams.n_head(il);
  904. //const int64_t n_head_kv = hparams.n_head_kv(il);
  905. //const auto & n_embd_head_k = hparams.n_embd_head_k;
  906. //const auto & n_embd_head_v = hparams.n_embd_head_v;
  907. const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
  908. const auto n_tokens = q->ne[1];
  909. const auto n_head = q->ne[2];
  910. const auto n_kv = k->ne[1];
  911. ggml_tensor * cur;
  912. // TODO: replace hardcoded padding with ggml-provided padding
  913. if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
  914. GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
  915. if (v_trans) {
  916. v = ggml_transpose(ctx0, v);
  917. }
  918. cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
  919. hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
  920. ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
  921. cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
  922. } else {
  923. ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
  924. // note: this op tends to require high floating point range
  925. // while for some models F16 is enough, for others it is not, so we default to F32 here
  926. ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
  927. if (arch == LLM_ARCH_GROK) {
  928. // need to do the following:
  929. // multiply by attn_output_multiplyer of 0.08838834764831845
  930. // and then :
  931. // kq = 30 * tanh(kq / 30)
  932. // before the softmax below
  933. kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
  934. kq = ggml_scale(ctx0, kq, 30);
  935. }
  936. if (hparams.attn_soft_cap) {
  937. kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
  938. kq = ggml_tanh (ctx0, kq);
  939. kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
  940. }
  941. if (kq_b) {
  942. kq = ggml_add(ctx0, kq, kq_b);
  943. }
  944. kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
  945. if (!v_trans) {
  946. // note: avoid this branch
  947. v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
  948. }
  949. ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
  950. ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
  951. cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
  952. if (!cparams.offload_kqv) {
  953. // all nodes between the KV store and the attention output are run on the CPU
  954. ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
  955. }
  956. }
  957. ggml_build_forward_expand(gf, cur);
  958. return cur;
  959. }
  960. llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
  961. auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
  962. // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
  963. inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  964. //cb(inp_kq_mask, "KQ_mask", -1);
  965. ggml_set_input(inp->kq_mask);
  966. inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
  967. return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
  968. }
  969. ggml_tensor * llm_graph_context::build_attn(
  970. llm_graph_input_attn_no_cache * inp,
  971. ggml_cgraph * gf,
  972. ggml_tensor * wo,
  973. ggml_tensor * wo_b,
  974. ggml_tensor * q_cur,
  975. ggml_tensor * k_cur,
  976. ggml_tensor * v_cur,
  977. ggml_tensor * kq_b,
  978. float kq_scale,
  979. int il) const {
  980. GGML_UNUSED(n_tokens);
  981. // these nodes are added to the graph together so that they are not reordered
  982. // by doing so, the number of splits in the graph is reduced
  983. ggml_build_forward_expand(gf, q_cur);
  984. ggml_build_forward_expand(gf, k_cur);
  985. ggml_build_forward_expand(gf, v_cur);
  986. const auto & kq_mask = inp->get_kq_mask();
  987. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  988. //cb(q, "q", il);
  989. ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
  990. //cb(k, "k", il);
  991. ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
  992. //cb(k, "v", il);
  993. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
  994. cb(cur, "kqv_out", il);
  995. if (wo) {
  996. cur = build_lora_mm(wo, cur);
  997. }
  998. if (wo_b) {
  999. //cb(cur, "kqv_wo", il);
  1000. }
  1001. if (wo_b) {
  1002. cur = ggml_add(ctx0, cur, wo_b);
  1003. }
  1004. return cur;
  1005. }
  1006. llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
  1007. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1008. auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
  1009. const auto n_kv = kv_self->n;
  1010. inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1011. //cb(inp->self_kq_mask, "KQ_mask", -1);
  1012. ggml_set_input(inp->self_kq_mask);
  1013. inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
  1014. if (hparams.n_swa_pattern > 1) {
  1015. GGML_ASSERT(hparams.n_swa > 0);
  1016. inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1017. //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
  1018. ggml_set_input(inp->self_kq_mask_swa);
  1019. inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
  1020. }
  1021. return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
  1022. }
  1023. ggml_tensor * llm_graph_context::build_attn(
  1024. llm_graph_input_attn_kv_unified * inp,
  1025. ggml_cgraph * gf,
  1026. ggml_tensor * wo,
  1027. ggml_tensor * wo_b,
  1028. ggml_tensor * q_cur,
  1029. ggml_tensor * k_cur,
  1030. ggml_tensor * v_cur,
  1031. ggml_tensor * kq_b,
  1032. float kq_scale,
  1033. int il) const {
  1034. // these nodes are added to the graph together so that they are not reordered
  1035. // by doing so, the number of splits in the graph is reduced
  1036. ggml_build_forward_expand(gf, q_cur);
  1037. ggml_build_forward_expand(gf, k_cur);
  1038. ggml_build_forward_expand(gf, v_cur);
  1039. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1040. const auto & n_ctx = cparams.n_ctx;
  1041. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1042. const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1043. const auto n_tokens = q_cur->ne[2];
  1044. const bool v_trans = !cparams.flash_attn;
  1045. // store to KV cache
  1046. {
  1047. GGML_ASSERT(!kv_self->recurrent);
  1048. const auto kv_head = kv_self->head;
  1049. GGML_ASSERT(kv_self->size == n_ctx);
  1050. ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
  1051. //cb(k_cache_view, "k_cache_view", il);
  1052. // note: storing RoPE-ed version of K in the KV cache
  1053. ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
  1054. v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
  1055. ggml_tensor * v_cache_view = nullptr;
  1056. if (!v_trans) {
  1057. v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
  1058. } else {
  1059. // note: the V cache is transposed when not using flash attention
  1060. v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
  1061. ( n_ctx)*ggml_element_size(kv_self->v_l[il]),
  1062. (kv_head)*ggml_element_size(kv_self->v_l[il]));
  1063. v_cur = ggml_transpose(ctx0, v_cur);
  1064. }
  1065. //cb(v_cache_view, "v_cache_view", il);
  1066. ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
  1067. }
  1068. const bool is_swa = hparams.is_swa(il);
  1069. const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
  1070. const auto n_kv = kv_self->n;
  1071. const int64_t n_head_kv = hparams.n_head_kv(il);
  1072. const auto & n_embd_head_k = hparams.n_embd_head_k;
  1073. const auto & n_embd_head_v = hparams.n_embd_head_v;
  1074. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1075. //cb(q, "q", il);
  1076. ggml_tensor * k =
  1077. ggml_view_3d(ctx0, kv_self->k_l[il],
  1078. n_embd_head_k, n_kv, n_head_kv,
  1079. ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
  1080. ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
  1081. 0);
  1082. //cb(k, "k", il);
  1083. ggml_tensor * v = !v_trans ?
  1084. ggml_view_3d(ctx0, kv_self->v_l[il],
  1085. n_embd_head_v, n_kv, n_head_kv,
  1086. ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
  1087. ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
  1088. 0) :
  1089. ggml_view_3d(ctx0, kv_self->v_l[il],
  1090. n_kv, n_embd_head_v, n_head_kv,
  1091. ggml_element_size(kv_self->v_l[il])*n_ctx,
  1092. ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
  1093. 0);
  1094. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
  1095. cb(cur, "kqv_out", il);
  1096. if (wo) {
  1097. cur = build_lora_mm(wo, cur);
  1098. }
  1099. if (wo_b) {
  1100. //cb(cur, "kqv_wo", il);
  1101. }
  1102. if (wo_b) {
  1103. cur = ggml_add(ctx0, cur, wo_b);
  1104. }
  1105. return cur;
  1106. }
  1107. llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
  1108. auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
  1109. const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
  1110. inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  1111. ggml_set_input(inp->cross_kq_mask);
  1112. inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
  1113. return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
  1114. }
  1115. ggml_tensor * llm_graph_context::build_attn(
  1116. llm_graph_input_attn_cross * inp,
  1117. ggml_cgraph * gf,
  1118. ggml_tensor * wo,
  1119. ggml_tensor * wo_b,
  1120. ggml_tensor * q_cur,
  1121. ggml_tensor * k_cur,
  1122. ggml_tensor * v_cur,
  1123. ggml_tensor * kq_b,
  1124. float kq_scale,
  1125. int il) const {
  1126. // these nodes are added to the graph together so that they are not reordered
  1127. // by doing so, the number of splits in the graph is reduced
  1128. ggml_build_forward_expand(gf, q_cur);
  1129. ggml_build_forward_expand(gf, k_cur);
  1130. ggml_build_forward_expand(gf, v_cur);
  1131. const auto & kq_mask = inp->get_kq_mask_cross();
  1132. ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
  1133. //cb(q, "q", il);
  1134. ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
  1135. //cb(k, "k", il);
  1136. ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
  1137. //cb(k, "v", il);
  1138. ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
  1139. cb(cur, "kqv_out", il);
  1140. if (wo) {
  1141. cur = build_lora_mm(wo, cur);
  1142. }
  1143. if (wo_b) {
  1144. //cb(cur, "kqv_wo", il);
  1145. }
  1146. if (wo_b) {
  1147. cur = ggml_add(ctx0, cur, wo_b);
  1148. }
  1149. return cur;
  1150. }
// Prepares a recurrent state tensor s for this ubatch and returns the part that
// the caller will update.
//
// - s is viewed as [n_state, kv_self->size] and its rows are gathered with
//   state_copy (ggml_get_rows), then multiplied element-wise by state_mask.
// - Rows [n_seqs, n_kv) of that result are copied straight back into s at rows
//   [kv_head + n_seqs, kv_head + n_kv); this copy is expanded into gf here.
// - The returned view covers rows [0, n_seqs) of the gathered states: [n_state, n_seqs].
// NOTE(review): the returned rows are presumably written back to s by the caller — confirm.
ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
         ggml_tensor * state_mask,
             int32_t   n_state,
             int32_t   n_seqs) const {
    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);

    const auto n_kv    = kv_self->n;
    const auto kv_head = kv_self->head;

    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);

    // copy states
    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
    // this shrinks the tensor's ne[1] to n_kv
    states = ggml_get_rows(ctx0, states, state_copy);

    // clear states of sequences which are starting at the beginning of this batch
    // FIXME: zero-out NANs?
    states = ggml_mul(ctx0, states, state_mask);

    // copy states which won't be changed further (between n_seqs and n_kv)
    // source offset: row n_seqs of states; destination offset: row kv_head + n_seqs of s
    ggml_build_forward_expand(gf,
        ggml_cpy(ctx0,
            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs          )*n_state*ggml_element_size(states)),
            ggml_view_1d(ctx0, s,      n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));

    // the part of the states that will be used and modified
    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
}
  1177. ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
  1178. ggml_cgraph * gf,
  1179. ggml_tensor * state_copy,
  1180. ggml_tensor * state_mask,
  1181. const llama_ubatch & ubatch,
  1182. int il) const {
  1183. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1184. const auto token_shift_count = hparams.token_shift_count;
  1185. const int64_t n_seqs = ubatch.n_seqs;
  1186. ggml_tensor * token_shift_all = kv_self->k_l[il];
  1187. ggml_tensor * token_shift = build_copy_mask_state(
  1188. gf, token_shift_all, state_copy, state_mask,
  1189. hparams.n_embd_k_s(), n_seqs);
  1190. token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
  1191. return token_shift;
  1192. }
  1193. ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
  1194. ggml_tensor * token_shift,
  1195. const llama_ubatch & ubatch,
  1196. int il) const {
  1197. const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
  1198. const auto token_shift_count = hparams.token_shift_count;
  1199. const auto n_embd = hparams.n_embd;
  1200. const int64_t n_seqs = ubatch.n_seqs;
  1201. const auto kv_head = kv_self->head;
  1202. return ggml_cpy(
  1203. ctx0,
  1204. ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
  1205. ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
  1206. );
  1207. }
  1208. void llm_graph_context::build_pooling(
  1209. ggml_cgraph * gf,
  1210. ggml_tensor * cls,
  1211. ggml_tensor * cls_b,
  1212. ggml_tensor * cls_out,
  1213. ggml_tensor * cls_out_b) const {
  1214. if (!cparams.embeddings) {
  1215. return;
  1216. }
  1217. ggml_tensor * inp = res->t_embd;
  1218. //// find result_norm tensor for input
  1219. //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
  1220. // inp = ggml_graph_node(gf, i);
  1221. // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
  1222. // break;
  1223. // }
  1224. // inp = nullptr;
  1225. //}
  1226. GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
  1227. ggml_tensor * cur;
  1228. switch (pooling_type) {
  1229. case LLAMA_POOLING_TYPE_NONE:
  1230. {
  1231. cur = inp;
  1232. } break;
  1233. case LLAMA_POOLING_TYPE_MEAN:
  1234. {
  1235. ggml_tensor * inp_mean = build_inp_mean();
  1236. cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
  1237. } break;
  1238. case LLAMA_POOLING_TYPE_CLS:
  1239. case LLAMA_POOLING_TYPE_LAST:
  1240. {
  1241. ggml_tensor * inp_cls = build_inp_cls();
  1242. cur = ggml_get_rows(ctx0, inp, inp_cls);
  1243. } break;
  1244. case LLAMA_POOLING_TYPE_RANK:
  1245. {
  1246. ggml_tensor * inp_cls = build_inp_cls();
  1247. inp = ggml_get_rows(ctx0, inp, inp_cls);
  1248. // classification head
  1249. // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
  1250. GGML_ASSERT(cls != nullptr);
  1251. GGML_ASSERT(cls_b != nullptr);
  1252. cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
  1253. cur = ggml_tanh(ctx0, cur);
  1254. // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
  1255. // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
  1256. if (cls_out) {
  1257. GGML_ASSERT(cls_out_b != nullptr);
  1258. cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
  1259. }
  1260. } break;
  1261. default:
  1262. {
  1263. GGML_ABORT("unknown pooling type");
  1264. }
  1265. }
  1266. cb(cur, "result_embd_pooled", -1);
  1267. res->t_embd_pooled = cur;
  1268. ggml_build_forward_expand(gf, cur);
  1269. }