llama-batch.cpp 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875
  1. #include "llama-batch.h"
  2. #include "llama-impl.h"
  3. #include "llama-vocab.h"
  4. #include "llama-memory.h"
  5. #include <cassert>
  6. #include <cstring>
  7. #include <algorithm>
  8. #include <sstream>
  9. llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
  10. const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
  11. debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
  12. seq_pos.resize(LLAMA_MAX_SEQ);
  13. seq_cpl.resize(LLAMA_MAX_SEQ);
  14. for (auto & cur : seq_cpl) {
  15. cur.resize(LLAMA_MAX_SEQ);
  16. }
  17. seq_idx.resize(LLAMA_MAX_SEQ, -1);
  18. }
  19. bool llama_batch_allocr::init(
  20. const llama_batch & batch_inp,
  21. const llama_vocab & vocab,
  22. const llama_memory_i * memory,
  23. uint32_t n_embd,
  24. uint32_t n_seq_max,
  25. bool output_all) {
  26. clear();
  27. batch = batch_inp;
  28. this->vocab = &vocab;
  29. GGML_ASSERT(batch.n_tokens > 0);
  30. //
  31. // validate input batch
  32. //
  33. if (n_seq_max > LLAMA_MAX_SEQ) {
  34. LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
  35. return false;
  36. }
  37. if (batch.token) {
  38. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  39. if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
  40. LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
  41. return false;
  42. }
  43. }
  44. }
  45. if (batch.seq_id) {
  46. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  47. for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
  48. if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
  49. LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
  50. return false;
  51. }
  52. }
  53. }
  54. }
  55. //
  56. // auto-generate missing fields
  57. //
  58. if (!batch.n_seq_id) {
  59. n_seq_id.resize(batch.n_tokens);
  60. for (int32_t i = 0; i < batch.n_tokens; i++) {
  61. n_seq_id[i] = seq_id_0.size();
  62. }
  63. batch.n_seq_id = n_seq_id.data();
  64. }
  65. if (!batch.seq_id) {
  66. seq_id.resize(batch.n_tokens + 1);
  67. seq_id[batch.n_tokens] = NULL;
  68. for (int32_t i = 0; i < batch.n_tokens; i++) {
  69. seq_id[i] = seq_id_0.data();
  70. }
  71. batch.seq_id = seq_id.data();
  72. }
  73. if (!batch.pos) {
  74. pos.resize(batch.n_tokens);
  75. // initialize the starting position for each sequence based on the positions in the memory
  76. llama_pos p0[LLAMA_MAX_SEQ];
  77. for (uint32_t s = 0; s < n_seq_max; ++s) {
  78. if (!memory) {
  79. // if no memory -> start from 0
  80. p0[s] = 0;
  81. } else {
  82. p0[s] = memory->seq_pos_max(s) + 1;
  83. }
  84. }
  85. for (int32_t i = 0; i < batch.n_tokens; i++) {
  86. const llama_seq_id seq_id = batch.seq_id[i][0];
  87. pos[i] = p0[seq_id];
  88. // update the starting position for all sequences that are assigned to the this token
  89. for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
  90. const llama_seq_id seq_id = batch.seq_id[i][s];
  91. p0[seq_id] = pos[i] + 1;
  92. }
  93. }
  94. batch.pos = pos.data();
  95. }
  96. if (!batch.logits) {
  97. if (output_all) {
  98. // return the output for all tokens
  99. output.resize(batch.n_tokens, true);
  100. } else {
  101. // return the output only for the last token
  102. output.resize(batch.n_tokens, false);
  103. output[output.size() - 1] = true;
  104. }
  105. batch.logits = output.data();
  106. } else if (output_all) {
  107. bool warn = false;
  108. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  109. if (batch.logits[i] == 0) {
  110. warn = true;
  111. }
  112. }
  113. if (warn) {
  114. LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
  115. output.resize(batch.n_tokens, true);
  116. batch.logits = output.data();
  117. }
  118. }
  119. //
  120. // compute stats
  121. //
  122. this->n_embd = n_embd;
  123. this->n_seq_max = n_seq_max;
  124. // count the outputs in this batch
  125. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  126. n_outputs += batch.logits[i] != 0;
  127. }
  128. has_cpl = false;
  129. // determine coupled sequences
  130. // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
  131. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  132. const llama_seq_id s0 = batch.seq_id[i][0];
  133. for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
  134. const llama_seq_id s1 = batch.seq_id[i][s];
  135. seq_pos[s1].insert(batch.pos[i]);
  136. if (s > 0) {
  137. // mark that sequence s1 is coupled to s0
  138. seq_cpl[s1][s0] = true;
  139. // note: tracking the other way around is not necessary for now
  140. //seq_cpl[s0][s1] = true;
  141. has_cpl = true;
  142. }
  143. }
  144. }
  145. // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
  146. {
  147. seq_set_t seq_set_unq;
  148. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  149. seq_set_t cur;
  150. for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
  151. const llama_seq_id seq_id = batch.seq_id[i][s];
  152. cur .set(seq_id);
  153. seq_set_unq.set(seq_id);
  154. }
  155. seq_set.push_back(cur);
  156. seq_set_map[cur].push_back(i);
  157. }
  158. for (uint32_t s = 0; s < n_seq_max; ++s) {
  159. if (seq_set_unq.test(s)) {
  160. seq_idx[s] = seq_id_unq.size();
  161. seq_id_unq.push_back(s);
  162. }
  163. }
  164. }
  165. if (debug > 0) {
  166. LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
  167. llama_ubatch ubatch {
  168. /*.b_equal_seqs =*/ false,
  169. /*.n_tokens =*/ (uint32_t) batch.n_tokens,
  170. /*.n_seq_tokens =*/ (uint32_t) 1,
  171. /*.n_seqs =*/ (uint32_t) batch.n_tokens,
  172. /*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
  173. /*.token =*/ batch.token,
  174. /*.embd =*/ batch.embd,
  175. /*.pos =*/ batch.pos,
  176. /*.n_seq_id =*/ batch.n_seq_id,
  177. /*.seq_id =*/ batch.seq_id,
  178. /*.seq_id_unq =*/ this->seq_id_unq.data(),
  179. /*.seq_idx =*/ this->seq_idx.data(),
  180. /*.output =*/ batch.logits,
  181. /*.data =*/ {},
  182. };
  183. ubatch_print(ubatch, debug);
  184. LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
  185. for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
  186. if (seq_pos[s0].empty()) {
  187. continue;
  188. }
  189. std::stringstream ss;
  190. for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
  191. if (seq_cpl[s0][s1]) {
  192. ss << s1 << " ";
  193. }
  194. }
  195. LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
  196. __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
  197. }
  198. LLAMA_LOG_DEBUG("%s: ]\n", __func__);
  199. }
  200. //
  201. // consistency checks
  202. //
  203. for (uint32_t s = 0; s < n_seq_max; ++s) {
  204. if (seq_pos[s].empty()) {
  205. continue;
  206. }
  207. const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
  208. if (p0 >= 0) {
  209. bool ok = true;
  210. if (batch.token) {
  211. if (seq_pos_min(s) != p0 + 1) {
  212. ok = false;
  213. }
  214. } else {
  215. assert(batch.embd);
  216. // for embeddings (typically used as vision input), we allow them to have repeating positions
  217. // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
  218. if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
  219. ok = false;
  220. }
  221. }
  222. if (!ok) {
  223. LLAMA_LOG_ERROR(
  224. "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
  225. " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
  226. " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
  227. " it is required that the sequence positions remain consecutive: Y = X + 1\n",
  228. __func__, s, s, p0, s, seq_pos_min(s));
  229. return false;
  230. }
  231. }
  232. if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
  233. LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
  234. return false;
  235. }
  236. }
  237. if (memory) {
  238. for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
  239. for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
  240. if (seq_cpl[s0][s1]) {
  241. if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
  242. memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
  243. LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
  244. return false;
  245. }
  246. }
  247. }
  248. }
  249. }
  250. // disallow partial sequence sub-sets:
  251. //
  252. // invalid: x
  253. // i: 0 1 2 ...
  254. // ---------------------------------------
  255. // seq_id[i][0]: 0 0 1
  256. // seq_id[i][1]: 1 1 2
  257. // seq_id[i][2]: 2
  258. //
  259. // disallow decreasing sequence positions:
  260. //
  261. // invalid: x
  262. // i: 0 1 2 3 4 5 6 ...
  263. // ---------------------------------------
  264. // pos[i]: 4 5 0 1 6 2 3
  265. // seq_id[i][0]: 0 0 1 1 0 1 0
  266. //
  267. {
  268. seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
  269. for (uint32_t s = 0; s < n_seq_max; ++s) {
  270. cur_seq_set[s].set();
  271. }
  272. llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
  273. for (uint32_t s = 0; s < n_seq_max; ++s) {
  274. cur_seq_pos[s] = -1;
  275. }
  276. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  277. const llama_pos pos = batch.pos[i];
  278. for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
  279. const llama_seq_id seq_id = batch.seq_id[i][s];
  280. cur_seq_set[seq_id] &= seq_set[i];
  281. if (cur_seq_set[seq_id].none()) {
  282. LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
  283. return false;
  284. }
  285. if (pos < cur_seq_pos[seq_id]) {
  286. LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
  287. return false;
  288. }
  289. }
  290. }
  291. }
  292. split_reset();
  293. return true;
  294. }
  295. llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
  296. const uint32_t n_tokens = n_seq_tokens*n_seqs;
  297. clear();
  298. split_reset();
  299. auto udata = std::make_shared<llama_ubatch::data_t>();
  300. udata->token .resize(n_tokens);
  301. udata->embd .clear();
  302. udata->pos .resize(n_tokens);
  303. udata->n_seq_id .resize(n_tokens);
  304. udata->seq_id .resize(n_tokens);
  305. udata->seq_id_unq.resize(0);
  306. udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
  307. udata->output .resize(n_tokens);
  308. for (uint32_t s = 0; s < n_seqs; ++s) {
  309. udata->seq_idx[s] = s;
  310. udata->seq_id_unq.push_back(s);
  311. }
  312. llama_ubatch res {
  313. /*.b_equal_seqs =*/ true,
  314. /*.n_tokens =*/ n_tokens,
  315. /*.n_seq_tokens =*/ n_seq_tokens,
  316. /*.n_seqs =*/ n_seqs,
  317. /*.n_seqs_unq =*/ n_seqs,
  318. /*.token =*/ udata->token.data(),
  319. /*.embd =*/ nullptr,
  320. /*.pos =*/ udata->pos.data(),
  321. /*.n_seq_id =*/ udata->n_seq_id.data(),
  322. /*.seq_id =*/ udata->seq_id.data(),
  323. /*.seq_id_unq =*/ udata->seq_id_unq.data(),
  324. /*.seq_idx =*/ udata->seq_idx.data(),
  325. /*.output =*/ udata->output.data(),
  326. /*.data =*/ std::move(udata),
  327. };
  328. return res;
  329. }
  330. const llama_batch & llama_batch_allocr::get_batch() const {
  331. return batch;
  332. }
  333. uint32_t llama_batch_allocr::get_n_tokens() const {
  334. return batch.n_tokens;
  335. }
  336. uint32_t llama_batch_allocr::get_n_outputs() const {
  337. return n_outputs;
  338. }
  339. uint32_t llama_batch_allocr::get_n_used() const {
  340. return n_used;
  341. }
  342. std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
  343. return out_ids;
  344. }
  345. llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
  346. return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
  347. }
  348. llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
  349. return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
  350. }
  351. void llama_batch_allocr::split_reset() {
  352. out_ids.clear();
  353. n_used = 0;
  354. used.clear();
  355. used.resize(get_n_tokens(), false);
  356. }
  357. llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
  358. // find the first unused token
  359. uint32_t cur_idx = 0;
  360. while (cur_idx < used.size() && used[cur_idx]) {
  361. ++cur_idx;
  362. }
  363. // we are done
  364. if (cur_idx >= used.size()) {
  365. return {};
  366. }
  367. std::vector<int32_t> idxs;
  368. while (true) {
  369. idxs.push_back(cur_idx);
  370. used[cur_idx] = true;
  371. ++n_used;
  372. ++cur_idx;
  373. if (cur_idx >= used.size()) {
  374. break;
  375. }
  376. if (idxs.size() >= n_ubatch) {
  377. break;
  378. }
  379. }
  380. return ubatch_add(idxs, idxs.size(), false);
  381. }
  382. llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
  383. if (sequential && has_cpl) {
  384. LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
  385. return {};
  386. }
  387. std::vector<seq_set_t> cur_seq_set;
  388. llama_seq_id last_seq_id = -1;
  389. // determine the non-overlapping sequence sets participating in this ubatch
  390. for (int32_t i = 0; i < batch.n_tokens; ++i) {
  391. if (used[i]) {
  392. continue;
  393. }
  394. bool add = true;
  395. for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
  396. // no overlap with existing sequence sets:
  397. if (!(cur_seq_set[s] & seq_set[i]).none()) {
  398. add = false;
  399. break;
  400. }
  401. }
  402. // accept only increasing sequence ids
  403. if (sequential) {
  404. add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
  405. }
  406. if (add) {
  407. cur_seq_set.push_back(seq_set[i]);
  408. last_seq_id = batch.seq_id[i][0];
  409. if (cur_seq_set.size() > n_ubatch) {
  410. break;
  411. }
  412. }
  413. }
  414. const uint32_t n_seqs = cur_seq_set.size();
  415. // we are done
  416. if (n_seqs == 0) {
  417. return {};
  418. }
  419. // the current batch index of each sequence set
  420. std::vector<int32_t> cur_idx(n_seqs, 0);
  421. for (uint32_t s = 0; s < n_seqs; ++s) {
  422. while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
  423. ++cur_idx[s];
  424. }
  425. }
  426. // the list of batch indices for each sequence set
  427. // at the end we will concat these to get the final ubatch
  428. std::vector<idx_vec_t> idxs_per_seq(n_seqs);
  429. while (true) {
  430. // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
  431. // if we haven't reached n_ubatch
  432. bool can_expand = true;
  433. for (uint32_t s = 0; s < n_seqs; ++s) {
  434. if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
  435. can_expand = false;
  436. break;
  437. }
  438. }
  439. if (!can_expand) {
  440. break;
  441. }
  442. for (uint32_t s = 0; s < n_seqs; ++s) {
  443. const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
  444. idxs_per_seq[s].push_back(idx);
  445. used[idx] = true;
  446. ++n_used;
  447. ++cur_idx[s];
  448. }
  449. if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
  450. break;
  451. }
  452. }
  453. // concat the per-sequence-set lists
  454. std::vector<int32_t> idxs;
  455. for (uint32_t s = 0; s < n_seqs; ++s) {
  456. idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
  457. }
  458. return ubatch_add(idxs, n_seqs, true);
  459. }
  460. llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
  461. // find the first unused token
  462. uint32_t cur_idx = 0;
  463. while (cur_idx < used.size() && used[cur_idx]) {
  464. ++cur_idx;
  465. }
  466. // we are done
  467. if (cur_idx >= used.size()) {
  468. return {};
  469. }
  470. // this is the starting sequence set
  471. // we allow adding tokens only if their sequence set is a subset of the current sequence set
  472. auto cur_seq_set = seq_set[cur_idx];
  473. std::vector<int32_t> idxs;
  474. while (true) {
  475. idxs.push_back(cur_idx);
  476. used[cur_idx] = true;
  477. ++n_used;
  478. if (idxs.size() >= n_ubatch) {
  479. break;
  480. }
  481. do {
  482. ++cur_idx;
  483. } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
  484. if (cur_idx == get_n_tokens()) {
  485. break;
  486. }
  487. cur_seq_set = seq_set[cur_idx];
  488. }
  489. return ubatch_add(idxs, 1, true);
  490. }
  491. void llama_batch_allocr::clear() {
  492. n_outputs = 0;
  493. batch = {};
  494. pos .clear();
  495. n_seq_id .clear();
  496. seq_id .clear();
  497. seq_id_unq.clear();
  498. output .clear();
  499. for (auto & cur : seq_pos) {
  500. cur.clear();
  501. }
  502. for (auto & cur : seq_cpl) {
  503. std::fill(cur.begin(), cur.end(), false);
  504. }
  505. seq_set.clear();
  506. seq_set_map.clear();
  507. std::fill(seq_idx.begin(), seq_idx.end(), -1);
  508. }
  509. llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
  510. const uint32_t n_tokens = idxs.size();
  511. assert(n_tokens%n_seqs == 0);
  512. auto udata = std::make_shared<llama_ubatch::data_t>();
  513. const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
  514. const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
  515. const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
  516. udata->token .resize(n_tokens);
  517. udata->embd .resize(n_embd_all);
  518. udata->pos .resize(n_pos_all);
  519. udata->n_seq_id .resize(n_tokens);
  520. udata->seq_id .resize(n_tokens);
  521. udata->seq_id_unq.resize(0);
  522. udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
  523. udata->output .resize(n_tokens);
  524. seq_set_t seq_set_unq;
  525. for (size_t i = 0; i < idxs.size(); ++i) {
  526. if (batch.token) {
  527. udata->token[i] = batch.token[idxs[i]];
  528. }
  529. if (batch.embd) {
  530. memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
  531. }
  532. for (int j = 0; j < n_pos_cur; ++j) {
  533. udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
  534. }
  535. udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
  536. udata->seq_id[i] = batch.seq_id[idxs[i]];
  537. udata->output[i] = batch.logits[idxs[i]];
  538. for (int s = 0; s < udata->n_seq_id[i]; ++s) {
  539. seq_set_unq.set(udata->seq_id[i][s]);
  540. }
  541. if (udata->output[i]) {
  542. out_ids.push_back(idxs[i]);
  543. }
  544. }
  545. for (uint32_t s = 0; s < n_seq_max; ++s) {
  546. if (seq_set_unq.test(s)) {
  547. udata->seq_idx[s] = udata->seq_id_unq.size();
  548. udata->seq_id_unq.push_back(s);
  549. }
  550. }
  551. llama_ubatch res {
  552. /*.b_equal_seqs =*/ equal_seqs,
  553. /*.n_tokens =*/ n_tokens,
  554. /*.n_seq_tokens =*/ n_tokens/n_seqs,
  555. /*.n_seqs =*/ n_seqs,
  556. /*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
  557. /*.token =*/ batch.token ? udata->token.data() : nullptr,
  558. /*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
  559. /*.pos =*/ udata->pos.data(),
  560. /*.n_seq_id =*/ udata->n_seq_id.data(),
  561. /*.seq_id =*/ udata->seq_id.data(),
  562. /*.seq_id_unq =*/ udata->seq_id_unq.data(),
  563. /*.seq_idx =*/ udata->seq_idx.data(),
  564. /*.output =*/ udata->output.data(),
  565. /*.data =*/ std::move(udata),
  566. };
  567. if (debug > 0) {
  568. LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
  569. ubatch_print(res, debug);
  570. }
  571. return res;
  572. }
  573. void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
  574. if (debug > 0) {
  575. LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs());
  576. LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
  577. LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
  578. LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
  579. LLAMA_LOG_DEBUG("%s: n_seqs_unq = %d\n", __func__, ubatch.n_seqs_unq);
  580. std::stringstream ss_seq_id_unq;
  581. std::stringstream ss_seq_idx;
  582. ss_seq_id_unq << "[ ";
  583. ss_seq_idx << "[";
  584. for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
  585. ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
  586. }
  587. for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  588. if (ubatch.seq_idx[s] >= 0) {
  589. ss_seq_idx << ubatch.seq_idx[s]%10;
  590. } else {
  591. ss_seq_idx << ".";
  592. }
  593. }
  594. ss_seq_id_unq << "]";
  595. ss_seq_idx << "]";
  596. LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
  597. LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
  598. LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
  599. LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
  600. LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
  601. LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
  602. LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
  603. LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
  604. LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
  605. if (debug > 1) {
  606. int seq_id_max = 0;
  607. for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
  608. for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
  609. for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
  610. seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
  611. }
  612. }
  613. }
  614. ++seq_id_max;
  615. LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
  616. for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
  617. std::vector<int8_t> seq_id(seq_id_max);
  618. for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
  619. seq_id[ubatch.seq_id[i][s]] = 1;
  620. }
  621. std::stringstream ss;
  622. for (int s = 0; s < seq_id_max; ++s) {
  623. if (seq_id[s]) {
  624. ss << s%10;
  625. } else {
  626. ss << ".";
  627. }
  628. }
  629. if (ubatch.token) {
  630. LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
  631. __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
  632. ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
  633. } else {
  634. LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
  635. __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
  636. }
  637. }
  638. LLAMA_LOG_DEBUG("%s: ]\n", __func__);
  639. }
  640. }
  641. }
  642. //
  643. // interface implementation
  644. //
  645. struct llama_batch llama_batch_get_one(
  646. llama_token * tokens,
  647. int32_t n_tokens) {
  648. return {
  649. /*n_tokens =*/ n_tokens,
  650. /*tokens =*/ tokens,
  651. /*embd =*/ nullptr,
  652. /*pos =*/ nullptr,
  653. /*n_seq_id =*/ nullptr,
  654. /*seq_id =*/ nullptr,
  655. /*logits =*/ nullptr,
  656. };
  657. }
  658. struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
  659. llama_batch batch = {
  660. /*n_tokens =*/ 0,
  661. /*tokens =*/ nullptr,
  662. /*embd =*/ nullptr,
  663. /*pos =*/ nullptr,
  664. /*n_seq_id =*/ nullptr,
  665. /*seq_id =*/ nullptr,
  666. /*logits =*/ nullptr,
  667. };
  668. if (embd) {
  669. batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
  670. } else {
  671. batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
  672. }
  673. batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
  674. batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
  675. batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
  676. for (int i = 0; i < n_tokens_alloc; ++i) {
  677. batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
  678. }
  679. batch.seq_id[n_tokens_alloc] = nullptr;
  680. batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
  681. return batch;
  682. }
  683. void llama_batch_free(struct llama_batch batch) {
  684. if (batch.token) free(batch.token);
  685. if (batch.embd) free(batch.embd);
  686. if (batch.pos) free(batch.pos);
  687. if (batch.n_seq_id) free(batch.n_seq_id);
  688. if (batch.seq_id) {
  689. for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
  690. free(batch.seq_id[i]);
  691. }
  692. free(batch.seq_id);
  693. }
  694. if (batch.logits) free(batch.logits);
  695. }