llama-batch.cpp 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830
  1. #include "llama-batch.h"
  2. #include "llama-impl.h"
  3. #include "llama-vocab.h"
  4. #include "llama-memory.h"
  5. #include <cassert>
  6. #include <cstring>
  7. #include <algorithm>
  8. #include <sstream>
  9. llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
  10. const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
  11. debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
  12. seq_pos.resize(LLAMA_MAX_SEQ);
  13. seq_cpl.resize(LLAMA_MAX_SEQ);
  14. for (auto & cur : seq_cpl) {
  15. cur.resize(LLAMA_MAX_SEQ);
  16. }
  17. seq_idx.resize(LLAMA_MAX_SEQ, -1);
  18. }
  19. bool llama_batch_allocr::init(
  20. const llama_batch & batch_inp,
  21. const llama_vocab & vocab,
  22. const llama_memory_i * memory,
  23. uint32_t n_embd,
  24. bool output_all) {
  25. clear();
  26. batch = batch_inp;
  27. this->vocab = &vocab;
  28. GGML_ASSERT(batch.n_tokens > 0);
  29. //
  30. // validate input batch
  31. //
  32. if (batch.token) {
  33. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  34. if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
  35. LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
  36. return false;
  37. }
  38. }
  39. }
  40. if (batch.seq_id) {
  41. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  42. for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
  43. if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
  44. LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
  45. return false;
  46. }
  47. }
  48. }
  49. }
  50. //
  51. // auto-generate missing fields
  52. //
  53. if (!batch.n_seq_id) {
  54. n_seq_id.resize(batch.n_tokens);
  55. for (int32_t i = 0; i < batch.n_tokens; i++) {
  56. n_seq_id[i] = seq_id_0.size();
  57. }
  58. batch.n_seq_id = n_seq_id.data();
  59. }
  60. if (!batch.seq_id) {
  61. seq_id.resize(batch.n_tokens + 1);
  62. seq_id[batch.n_tokens] = NULL;
  63. for (int32_t i = 0; i < batch.n_tokens; i++) {
  64. seq_id[i] = seq_id_0.data();
  65. }
  66. batch.seq_id = seq_id.data();
  67. }
  68. if (!batch.pos) {
  69. pos.resize(batch.n_tokens);
  70. // initialize the starting position for each sequence based on the positions in the memory
  71. llama_pos p0[LLAMA_MAX_SEQ];
  72. for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  73. if (!memory) {
  74. // if no memory -> start from 0
  75. p0[s] = 0;
  76. } else {
  77. p0[s] = memory->seq_pos_max(s) + 1;
  78. }
  79. }
  80. for (int32_t i = 0; i < batch.n_tokens; i++) {
  81. const llama_seq_id seq_id = batch.seq_id[i][0];
  82. pos[i] = p0[seq_id];
  83. // update the starting position for all sequences that are assigned to the this token
  84. for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
  85. const llama_seq_id seq_id = batch.seq_id[i][s];
  86. p0[seq_id] = pos[i] + 1;
  87. }
  88. }
  89. batch.pos = pos.data();
  90. }
  91. if (!batch.logits) {
  92. if (output_all) {
  93. // return the output for all tokens
  94. output.resize(batch.n_tokens, true);
  95. } else {
  96. // return the output only for the last token
  97. output.resize(batch.n_tokens, false);
  98. output[output.size() - 1] = true;
  99. }
  100. batch.logits = output.data();
  101. } else if (output_all) {
  102. bool warn = false;
  103. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  104. if (batch.logits[i] == 0) {
  105. warn = true;
  106. }
  107. }
  108. if (warn) {
  109. LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
  110. output.resize(batch.n_tokens, true);
  111. batch.logits = output.data();
  112. }
  113. }
  114. //
  115. // compute stats
  116. //
  117. this->n_embd = n_embd;
  118. // count the outputs in this batch
  119. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  120. n_outputs += batch.logits[i] != 0;
  121. }
  122. // determine coupled sequences
  123. // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
  124. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  125. const llama_seq_id s0 = batch.seq_id[i][0];
  126. for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
  127. const llama_seq_id s1 = batch.seq_id[i][s];
  128. seq_pos[s1].insert(batch.pos[i]);
  129. if (s > 0) {
  130. // mark that sequence s1 is coupled to s0
  131. seq_cpl[s1][s0] = true;
  132. // note: tracking the other way around is not necessary for now
  133. //seq_cpl[s0][s1] = true;
  134. }
  135. }
  136. }
  137. // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
  138. {
  139. seq_set_t seq_set_unq;
  140. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  141. seq_set_t cur;
  142. for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
  143. const llama_seq_id seq_id = batch.seq_id[i][s];
  144. cur .set(seq_id);
  145. seq_set_unq.set(seq_id);
  146. }
  147. seq_set.push_back(cur);
  148. seq_set_map[cur].push_back(i);
  149. }
  150. for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  151. if (seq_set_unq.test(s)) {
  152. seq_idx[s] = seq_id_unq.size();
  153. seq_id_unq.push_back(s);
  154. }
  155. }
  156. }
  157. if (debug > 0) {
  158. LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
  159. llama_ubatch ubatch {
  160. /*.equal_seqs =*/ false,
  161. /*.n_tokens =*/ (uint32_t) batch.n_tokens,
  162. /*.n_seq_tokens =*/ (uint32_t) 1,
  163. /*.n_seqs =*/ (uint32_t) batch.n_tokens,
  164. /*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
  165. /*.token =*/ batch.token,
  166. /*.embd =*/ batch.embd,
  167. /*.pos =*/ batch.pos,
  168. /*.n_seq_id =*/ batch.n_seq_id,
  169. /*.seq_id =*/ batch.seq_id,
  170. /*.seq_id_unq =*/ this->seq_id_unq.data(),
  171. /*.seq_idx =*/ this->seq_idx.data(),
  172. /*.output =*/ batch.logits,
  173. };
  174. ubatch_print(ubatch, debug);
  175. LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
  176. for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
  177. if (seq_pos[s0].empty()) {
  178. continue;
  179. }
  180. std::stringstream ss;
  181. for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
  182. if (seq_cpl[s0][s1]) {
  183. ss << s1 << " ";
  184. }
  185. }
  186. LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
  187. __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
  188. }
  189. LLAMA_LOG_DEBUG("%s: ]\n", __func__);
  190. }
  191. //
  192. // consistency checks
  193. //
  194. for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  195. if (seq_pos[s].empty()) {
  196. continue;
  197. }
  198. if (memory) {
  199. if (batch.token) {
  200. if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
  201. LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
  202. return false;
  203. }
  204. } else {
  205. assert(batch.embd);
  206. // for embeddings (typically used as vision input), we allow them to have repeating positions
  207. // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
  208. if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
  209. LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
  210. return false;
  211. }
  212. }
  213. }
  214. if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
  215. LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
  216. return false;
  217. }
  218. }
  219. if (memory) {
  220. for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
  221. for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
  222. if (seq_cpl[s0][s1]) {
  223. if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
  224. memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
  225. LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
  226. return false;
  227. }
  228. }
  229. }
  230. }
  231. }
  232. // disallow partial sequence sub-sets:
  233. //
  234. // invalid: x
  235. // i: 0 1 2 ...
  236. // ---------------------------------------
  237. // seq_id[i][0]: 0 0 1
  238. // seq_id[i][1]: 1 1 2
  239. // seq_id[i][2]: 2
  240. //
  241. // disallow decreasing sequence positions:
  242. //
  243. // invalid: x
  244. // i: 0 1 2 3 4 5 6 ...
  245. // ---------------------------------------
  246. // pos[i]: 4 5 0 1 6 2 3
  247. // seq_id[i][0]: 0 0 1 1 0 1 0
  248. //
  249. {
  250. seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
  251. for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  252. cur_seq_set[s].set();
  253. }
  254. llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
  255. for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  256. cur_seq_pos[s] = -1;
  257. }
  258. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  259. const llama_pos pos = batch.pos[i];
  260. for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
  261. const llama_seq_id seq_id = batch.seq_id[i][s];
  262. cur_seq_set[seq_id] &= seq_set[i];
  263. if (cur_seq_set[seq_id].none()) {
  264. LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
  265. return false;
  266. }
  267. if (pos < cur_seq_pos[seq_id]) {
  268. LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
  269. return false;
  270. }
  271. }
  272. }
  273. }
  274. split_reset();
  275. return true;
  276. }
  277. llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
  278. const uint32_t n_tokens = n_seq_tokens*n_seqs;
  279. clear();
  280. split_reset();
  281. ubatches.emplace_back();
  282. auto & ubatch = ubatches.back();
  283. ubatch.token .resize(n_tokens);
  284. ubatch.embd .clear();
  285. ubatch.pos .resize(n_tokens);
  286. ubatch.n_seq_id .resize(n_tokens);
  287. ubatch.seq_id .resize(n_tokens);
  288. ubatch.seq_id_unq.resize(0);
  289. ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
  290. ubatch.output .resize(n_tokens);
  291. for (uint32_t s = 0; s < n_seqs; ++s) {
  292. ubatch.seq_idx[s] = s;
  293. ubatch.seq_id_unq.push_back(s);
  294. }
  295. llama_ubatch res {
  296. /*.equal_seqs =*/ true,
  297. /*.n_tokens =*/ n_tokens,
  298. /*.n_seq_tokens =*/ n_seq_tokens,
  299. /*.n_seqs =*/ n_seqs,
  300. /*.n_seqs_unq =*/ n_seqs,
  301. /*.token =*/ ubatch.token.data(),
  302. /*.embd =*/ nullptr,
  303. /*.pos =*/ ubatch.pos.data(),
  304. /*.n_seq_id =*/ ubatch.n_seq_id.data(),
  305. /*.seq_id =*/ ubatch.seq_id.data(),
  306. /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
  307. /*.seq_idx =*/ ubatch.seq_idx.data(),
  308. /*.output =*/ ubatch.output.data(),
  309. };
  310. return res;
  311. }
  312. const llama_batch & llama_batch_allocr::get_batch() const {
  313. return batch;
  314. }
  315. uint32_t llama_batch_allocr::get_n_tokens() const {
  316. return batch.n_tokens;
  317. }
  318. uint32_t llama_batch_allocr::get_n_outputs() const {
  319. return n_outputs;
  320. }
  321. std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
  322. return out_ids;
  323. }
  324. llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
  325. return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
  326. }
  327. llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
  328. return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
  329. }
  330. void llama_batch_allocr::split_reset() {
  331. out_ids.clear();
  332. used.clear();
  333. used.resize(get_n_tokens(), false);
  334. ubatches.clear();
  335. }
  336. llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
  337. // find the first unused token
  338. uint32_t cur_idx = 0;
  339. while (cur_idx < used.size() && used[cur_idx]) {
  340. ++cur_idx;
  341. }
  342. // we are done
  343. if (cur_idx >= used.size()) {
  344. return {};
  345. }
  346. std::vector<int32_t> idxs;
  347. while (true) {
  348. idxs.push_back(cur_idx);
  349. used[cur_idx] = true;
  350. ++cur_idx;
  351. if (cur_idx >= used.size()) {
  352. break;
  353. }
  354. if (idxs.size() >= n_ubatch) {
  355. break;
  356. }
  357. }
  358. return ubatch_add(idxs, idxs.size(), false);
  359. }
  360. llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
  361. std::vector<seq_set_t> cur_seq_set;
  362. // determine the non-overlapping sequence sets participating in this ubatch
  363. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  364. if (used[i]) {
  365. continue;
  366. }
  367. bool add = true;
  368. for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
  369. // no overlap with existing sequence sets:
  370. if (!(cur_seq_set[s] & seq_set[i]).none()) {
  371. add = false;
  372. break;
  373. }
  374. }
  375. if (add) {
  376. cur_seq_set.push_back(seq_set[i]);
  377. if (cur_seq_set.size() > n_ubatch) {
  378. break;
  379. }
  380. }
  381. }
  382. const uint32_t n_seqs = cur_seq_set.size();
  383. // we are done
  384. if (n_seqs == 0) {
  385. return {};
  386. }
  387. // the current batch index of each sequence set
  388. std::vector<int32_t> cur_idx(n_seqs, 0);
  389. for (uint32_t s = 0; s < n_seqs; ++s) {
  390. while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
  391. ++cur_idx[s];
  392. }
  393. }
  394. // the list of batch indices for each sequence set
  395. // at the end we will concat these to get the final ubatch
  396. std::vector<idx_vec_t> idxs_per_seq(n_seqs);
  397. while (true) {
  398. // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
  399. // if we haven't reached n_ubatch
  400. bool can_expand = true;
  401. for (uint32_t s = 0; s < n_seqs; ++s) {
  402. if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
  403. can_expand = false;
  404. break;
  405. }
  406. }
  407. if (!can_expand) {
  408. break;
  409. }
  410. for (uint32_t s = 0; s < n_seqs; ++s) {
  411. const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
  412. idxs_per_seq[s].push_back(idx);
  413. used[idx] = true;
  414. ++cur_idx[s];
  415. }
  416. if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
  417. break;
  418. }
  419. }
  420. // concat the per-sequence-set lists
  421. std::vector<int32_t> idxs;
  422. for (uint32_t s = 0; s < n_seqs; ++s) {
  423. idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
  424. }
  425. return ubatch_add(idxs, n_seqs, true);
  426. }
  427. llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
  428. // find the first unused token
  429. uint32_t cur_idx = 0;
  430. while (cur_idx < used.size() && used[cur_idx]) {
  431. ++cur_idx;
  432. }
  433. // we are done
  434. if (cur_idx >= used.size()) {
  435. return {};
  436. }
  437. // this is the starting sequence set
  438. // we allow adding tokens only if their sequence set is a subset of the current sequence set
  439. auto cur_seq_set = seq_set[cur_idx];
  440. std::vector<int32_t> idxs;
  441. while (true) {
  442. idxs.push_back(cur_idx);
  443. used[cur_idx] = true;
  444. if (idxs.size() >= n_ubatch) {
  445. break;
  446. }
  447. do {
  448. ++cur_idx;
  449. } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
  450. if (cur_idx == get_n_tokens()) {
  451. break;
  452. }
  453. cur_seq_set = seq_set[cur_idx];
  454. }
  455. return ubatch_add(idxs, 1, true);
  456. }
  457. void llama_batch_allocr::clear() {
  458. n_outputs = 0;
  459. batch = {};
  460. pos .clear();
  461. n_seq_id .clear();
  462. seq_id .clear();
  463. seq_id_unq.clear();
  464. output .clear();
  465. for (auto & cur : seq_pos) {
  466. cur.clear();
  467. }
  468. for (auto & cur : seq_cpl) {
  469. std::fill(cur.begin(), cur.end(), false);
  470. }
  471. seq_set.clear();
  472. seq_set_map.clear();
  473. std::fill(seq_idx.begin(), seq_idx.end(), -1);
  474. }
  475. llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
  476. const uint32_t n_tokens = idxs.size();
  477. assert(n_tokens%n_seqs == 0);
  478. ubatches.emplace_back();
  479. auto & ubatch = ubatches.back();
  480. const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
  481. const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
  482. const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
  483. ubatch.token .resize(n_tokens);
  484. ubatch.embd .resize(n_embd_all);
  485. ubatch.pos .resize(n_pos_all);
  486. ubatch.n_seq_id .resize(n_tokens);
  487. ubatch.seq_id .resize(n_tokens);
  488. ubatch.seq_id_unq.resize(0);
  489. ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
  490. ubatch.output .resize(n_tokens);
  491. seq_set_t seq_set_unq;
  492. for (size_t i = 0; i < idxs.size(); ++i) {
  493. if (batch.token) {
  494. ubatch.token[i] = batch.token[idxs[i]];
  495. }
  496. if (batch.embd) {
  497. memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
  498. }
  499. for (int j = 0; j < n_pos_cur; ++j) {
  500. ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
  501. }
  502. ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
  503. ubatch.seq_id[i] = batch.seq_id[idxs[i]];
  504. ubatch.output[i] = batch.logits[idxs[i]];
  505. for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
  506. seq_set_unq.set(ubatch.seq_id[i][s]);
  507. }
  508. if (ubatch.output[i]) {
  509. out_ids.push_back(idxs[i]);
  510. }
  511. }
  512. for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  513. if (seq_set_unq.test(s)) {
  514. ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
  515. ubatch.seq_id_unq.push_back(s);
  516. }
  517. }
  518. llama_ubatch res {
  519. /*.equal_seqs =*/ equal_seqs,
  520. /*.n_tokens =*/ n_tokens,
  521. /*.n_seq_tokens =*/ n_tokens/n_seqs,
  522. /*.n_seqs =*/ n_seqs,
  523. /*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(),
  524. /*.token =*/ batch.token ? ubatch.token.data() : nullptr,
  525. /*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr,
  526. /*.pos =*/ ubatch.pos.data(),
  527. /*.n_seq_id =*/ ubatch.n_seq_id.data(),
  528. /*.seq_id =*/ ubatch.seq_id.data(),
  529. /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
  530. /*.seq_idx =*/ ubatch.seq_idx.data(),
  531. /*.output =*/ ubatch.output.data(),
  532. };
  533. if (debug > 0) {
  534. LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
  535. ubatch_print(res, debug);
  536. }
  537. return res;
  538. }
  539. void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
  540. if (debug > 0) {
  541. LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
  542. LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
  543. LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
  544. LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
  545. LLAMA_LOG_DEBUG("%s: n_seqs_unq = %d\n", __func__, ubatch.n_seqs_unq);
  546. std::stringstream ss_seq_id_unq;
  547. std::stringstream ss_seq_idx;
  548. ss_seq_id_unq << "[ ";
  549. ss_seq_idx << "[";
  550. for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
  551. ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
  552. }
  553. for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  554. if (ubatch.seq_idx[s] >= 0) {
  555. ss_seq_idx << ubatch.seq_idx[s]%10;
  556. } else {
  557. ss_seq_idx << ".";
  558. }
  559. }
  560. ss_seq_id_unq << "]";
  561. ss_seq_idx << "]";
  562. LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
  563. LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
  564. LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
  565. LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
  566. LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
  567. LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
  568. LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
  569. LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
  570. LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
  571. if (debug > 1) {
  572. int seq_id_max = 0;
  573. for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
  574. for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
  575. for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
  576. seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
  577. }
  578. }
  579. }
  580. ++seq_id_max;
  581. LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
  582. for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
  583. std::vector<int8_t> seq_id(seq_id_max);
  584. for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
  585. seq_id[ubatch.seq_id[i][s]] = 1;
  586. }
  587. std::stringstream ss;
  588. for (int s = 0; s < seq_id_max; ++s) {
  589. if (seq_id[s]) {
  590. ss << s%10;
  591. } else {
  592. ss << ".";
  593. }
  594. }
  595. if (ubatch.token) {
  596. LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
  597. __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
  598. ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
  599. } else {
  600. LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
  601. __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
  602. }
  603. }
  604. LLAMA_LOG_DEBUG("%s: ]\n", __func__);
  605. }
  606. }
  607. }
  608. //
  609. // interface implementation
  610. //
  611. struct llama_batch llama_batch_get_one(
  612. llama_token * tokens,
  613. int32_t n_tokens) {
  614. return {
  615. /*n_tokens =*/ n_tokens,
  616. /*tokens =*/ tokens,
  617. /*embd =*/ nullptr,
  618. /*pos =*/ nullptr,
  619. /*n_seq_id =*/ nullptr,
  620. /*seq_id =*/ nullptr,
  621. /*logits =*/ nullptr,
  622. };
  623. }
  624. struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
  625. llama_batch batch = {
  626. /*n_tokens =*/ 0,
  627. /*tokens =*/ nullptr,
  628. /*embd =*/ nullptr,
  629. /*pos =*/ nullptr,
  630. /*n_seq_id =*/ nullptr,
  631. /*seq_id =*/ nullptr,
  632. /*logits =*/ nullptr,
  633. };
  634. if (embd) {
  635. batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
  636. } else {
  637. batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
  638. }
  639. batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
  640. batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
  641. batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
  642. for (int i = 0; i < n_tokens_alloc; ++i) {
  643. batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
  644. }
  645. batch.seq_id[n_tokens_alloc] = nullptr;
  646. batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
  647. return batch;
  648. }
  649. void llama_batch_free(struct llama_batch batch) {
  650. if (batch.token) free(batch.token);
  651. if (batch.embd) free(batch.embd);
  652. if (batch.pos) free(batch.pos);
  653. if (batch.n_seq_id) free(batch.n_seq_id);
  654. if (batch.seq_id) {
  655. for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
  656. free(batch.seq_id[i]);
  657. }
  658. free(batch.seq_id);
  659. }
  660. if (batch.logits) free(batch.logits);
  661. }