// llama-batch.cpp

#include "llama-batch.h"

#include "llama-impl.h"
#include "llama-cparams.h"
#include "llama-vocab.h"

#include <cassert>
#include <cstring>
#include <algorithm>
#include <sstream>
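// Allocate backing storage for a new micro-batch (ubatch) of up to n_ubatch
// tokens and return a llama_ubatch whose pointers view into that storage.
// A batch carries either token ids or embeddings, never both, so only one of
// the token/embd buffers is sized.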
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
    // clear empty sequences
    // the previous ubatch is assumed to be gone,
    // so nothing should refer to values in these sequences anymore.
    for (size_t i = seq.size(); i-- > 0;) {
        if (seq[i].length == 0) {
            seq.pop_back();
        } else {
            break;
        }
    }

    udatas.push_back({});

    auto & udata = udatas.back();

    udata.token.resize(!has_embd ? n_ubatch : 0);
    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
    udata.pos.resize(n_ubatch);
    udata.n_seq_id.resize(n_ubatch);
    udata.seq_id.resize(n_ubatch);
    udata.output.resize(n_ubatch);

    llama_ubatch ubatch = {
        /*equal_seqs   =*/ true,
        /*n_tokens     =*/ 0,
        /*n_seq_tokens =*/ 0,
        /*n_seqs       =*/ 0,
        /*token        =*/ !has_embd ? udata.token.data() : nullptr,
        /*embd         =*/ has_embd ? udata.embd.data() : nullptr,
        /*pos          =*/ udata.pos.data(),
        /*n_seq_id     =*/ udata.n_seq_id.data(),
        /*seq_id       =*/ udata.seq_id.data(),
        /*output       =*/ udata.output.data(),
    };

    return ubatch;
}
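// Move up to `length` tokens of sequence `seq` into `ubatch`. For equal_seqs
// ubatches the data is gathered element-wise through the sorted `ids`; for
// simple splits the ubatch pointers simply alias the original batch arrays.
// Tokens flagged for output are also recorded in `out_ids`.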
void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
    GGML_ASSERT(batch != nullptr);
    GGML_ASSERT(length <= seq.length);
    // Can only add sequences of equal lengths to a batch,
    // otherwise it isn't clear to which sequence a token belongs
    GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
    GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
    // NOTE: loops are separated for cache-friendliness
    if (batch->token) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
            }
        } else {
            // simple split
            ubatch.token = batch->token + seq.offset;
        }
    } else {
        ubatch.token = nullptr;
    }
    if (batch->embd) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                memcpy(
                    ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
                    batch->embd + (n_embd * ids[seq.offset + i]),
                    n_embd * sizeof(float)
                );
            }
        } else {
            // simple split
            ubatch.embd = batch->embd + (n_embd * seq.offset);
        }
    } else {
        ubatch.embd = nullptr;
    }
    if (ubatch.equal_seqs) {
        for (size_t i = 0; i < length; ++i) {
            ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
        }
    } else {
        // simple split
        ubatch.pos = batch->pos + seq.offset;
    }
    if (ubatch.equal_seqs) {
        ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
        if (seq.seq_id) {
            ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
        }
    } else {
        // simple split
        if (batch->n_seq_id) {
            ubatch.n_seq_id = batch->n_seq_id + seq.offset;
        } else {
            for (size_t i = 0; i < length; ++i) {
                ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
            }
        }
        if (batch->seq_id) {
            ubatch.seq_id = batch->seq_id + seq.offset;
        }
    }
    if (batch->logits) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                size_t id = ids[seq.offset + i];
                int8_t is_output = batch->logits[id];
                ubatch.output[ubatch.n_tokens + i] = is_output;
                if (is_output) { out_ids.push_back(id); }
            }
        } else {
            // simple split
            ubatch.output = batch->logits + seq.offset;
            for (size_t i = 0; i < length; ++i) {
                if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
            }
        }
    } else {
        // only get last output
        for (size_t i = 0; i < length; ++i) {
            size_t id = ids[seq.offset + i];
            int8_t is_last = id == ids.size() - 1;
            ubatch.output[ubatch.n_tokens + i] = is_last;
            if (is_last) { out_ids.push_back(id); }
        }
    }
    if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
        ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
    }
    ubatch.n_tokens += length;
    ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
    seq.offset += length;
    seq.length -= length;
    n_tokens -= length;
    GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
}
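// Slice up to n_ubatch tokens off the front of the batch without reordering;
// only valid for sbatches built with simple_split (a single pseudo-sequence).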
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    ubatch.equal_seqs = false;
    if (!seq.empty()) {
        llama_sbatch_seq & s = seq[0];
        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
        GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
        add_seq_to_ubatch(ubatch, s, length);
    }
    return ubatch;
}
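// Build a ubatch in which every included sequence contributes the same number
// of tokens. Sequences are consumed from the back of `seq` (smallest first,
// per the sort in the constructor below), and shared prompts (n_seq_id > 1)
// are kept in a ubatch of their own.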
llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (!seq.empty()) {
        size_t length = 0;
        size_t n_tokens_in_ubatch = 0;
        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
        // smallest first, because it's easier to split this way;
        // starting from the end to pop in constant time.
        for (size_t i = seq.size(); i-- > 0;) {
            llama_sbatch_seq & s = seq[i];
            GGML_ASSERT(s.length > 0);
            if (length == 0) {
                length = s.length < n_ubatch ? s.length : n_ubatch;
            }
            add_seq_to_ubatch(ubatch, s, length);
            n_tokens_in_ubatch += length;
            // shared prompts can't be mixed with any of their sequences,
            // so it's safer to compute them in their own ubatch
            if (s.n_seq_id > 1) { break; }
            // stop when there isn't enough space for another sequence
            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
        }
    }
    return ubatch;
}
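// Take up to n_ubatch tokens from the last sequence in `seq` only; unlike
// split_equal, a single ubatch never mixes tokens from different entries.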
llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (!seq.empty()) {
        llama_sbatch_seq & s = seq[seq.size() - 1];
        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
        add_seq_to_ubatch(ubatch, s, length);
    }
    return ubatch;
}
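// Build the sequence index for `batch`. With simple_split the whole batch is
// exposed as one pseudo-sequence (n_seq_id == 0). Otherwise the token ids are
// sorted so that tokens with the same seq_id set form contiguous, pos-ordered
// runs, and each run is recorded as an entry in `seq`.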
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
    GGML_ASSERT(batch.n_tokens >= 0);
    this->batch = &batch;
    this->n_embd = n_embd;

    n_tokens = batch.n_tokens;
    ids.resize(n_tokens);
    out_ids.clear();
    // TODO: reserve out_ids and seq

    for (size_t i = 0; i < n_tokens; ++i) {
        ids[i] = i;
    }

    if (simple_split) {
        seq.resize(1);
        llama_sbatch_seq & s = seq[0];
        s.n_seq_id = 0;
        s.seq_id   = nullptr;
        s.offset   = 0;
        s.length   = n_tokens;
        return;
    }

    std::sort(ids.begin(), ids.end(),
        [&batch](size_t a, size_t b) {
            int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
            int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
            // sort by seq_id, then by pos
            if (n_seq_a == n_seq_b) {
                if (batch.seq_id) {
                    for (int32_t i = 0; i < n_seq_a; ++i) {
                        llama_seq_id seq_id_a = batch.seq_id[a][i];
                        llama_seq_id seq_id_b = batch.seq_id[b][i];
                        // smaller seq_ids go first
                        if (seq_id_a != seq_id_b) {
                            return seq_id_a < seq_id_b;
                        }
                    }
                }
                // when all else is equal, sort by pos
                if (batch.pos) {
                    return batch.pos[a] < batch.pos[b];
                }
                // no pos, sort by id
                return a < b;
            }
            // shared prompts go first
            return n_seq_a > n_seq_b;
        }
    );

    // init seq
    llama_sbatch_seq * last_seq = nullptr;

    for (size_t i = 0; i < n_tokens; ++i) {
        const size_t bi = ids[i];
        const int32_t n_seqs = batch.n_seq_id[bi];
        llama_seq_id * seq_ids = batch.seq_id[bi];
        if (last_seq != nullptr) {
            bool same = n_seqs == last_seq->n_seq_id;
            for (int32_t j = 0; same && j < n_seqs; ++j) {
                if (seq_ids[j] != last_seq->seq_id[j]) {
                    same = false;
                }
            }
            if (same) {
                last_seq->length += 1;
                continue;
            }
        }
        llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
        seq.push_back(new_seq);
        last_seq = &seq.back();
    }
    // keep shared prompts at the end (they are consumed first, since the
    // split_* methods iterate from the back), then sort by length descending
    std::sort(seq.begin(), seq.end(),
        [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
            if (a.n_seq_id == b.n_seq_id) {
                return a.length > b.length;
            }
            return a.n_seq_id < b.n_seq_id;
        }
    );
}
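// The batch allocator validates an incoming llama_batch and owns the default
// values for any optional arrays (pos, n_seq_id, seq_id, logits) the caller
// left NULL. The LLAMA_BATCH_DEBUG environment variable enables debug dumps.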
llama_batch_allocr::llama_batch_allocr() {
    const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
    debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
}
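// Validate batch_inp and take a shallow copy of it, substituting defaults for
// missing arrays: positions count up from p0, every token gets the default
// seq_id_0 sequence list, and only the last token requests an output.
// Returns false on invalid input.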
bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
    clear();

    batch = batch_inp;

    GGML_ASSERT(batch.n_tokens > 0);

    if (!batch.pos) {
        if (batch.seq_id) {
            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
            return false;
        }
    }

    if (batch.token) {
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                return false;
            }
        }
    }
    if (batch.seq_id) {
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                if (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES) {
                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
                    return false;
                }
            }
        }
    }
    if (!batch.pos) {
        assert(p0 >= 0);
        pos.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            pos[i] = p0 + i;
        }
        batch.pos = pos.data();
    }

    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }

    if (!batch.seq_id) {
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = NULL;
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            seq_id[i] = seq_id_0.data();
        }
        batch.seq_id = seq_id.data();
    }

    if (!batch.logits) {
        // by default return the output only for the last token
        output.resize(batch.n_tokens);
        output[output.size() - 1] = true;
        batch.logits = output.data();
    }

    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        n_outputs += batch.logits[i] != 0;
    }
    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
        LLAMA_LOG_DEBUG("%s:   n_tokens  = %d\n", __func__, batch.n_tokens);
        LLAMA_LOG_DEBUG("%s:   token     = %p\n", __func__, (void *) batch.token);
        LLAMA_LOG_DEBUG("%s:   embd      = %p\n", __func__, (void *) batch.embd);
        LLAMA_LOG_DEBUG("%s:   pos       = %p\n", __func__, (void *) batch.pos);
        LLAMA_LOG_DEBUG("%s:   n_seq_id  = %p\n", __func__, (void *) batch.n_seq_id);
        LLAMA_LOG_DEBUG("%s:   seq_id    = %p\n", __func__, (void *) batch.seq_id);
        LLAMA_LOG_DEBUG("%s:   logits    = %p\n", __func__, (void *) batch.logits);
        LLAMA_LOG_DEBUG("%s:   n_outputs = %d\n", __func__, n_outputs);
        if (debug > 1) {
            int seq_id_max = 0;
            for (int32_t i = 0; i < batch.n_tokens; ++i) {
                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
                    seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
                }
            }
            ++seq_id_max;

            LLAMA_LOG_DEBUG("%s:   token     = [\n", __func__);
            for (int32_t i = 0; i < batch.n_tokens; ++i) {
                std::vector<int8_t> seq_id(seq_id_max);

                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
                    seq_id[batch.seq_id[i][s]] = 1;
                }

                std::stringstream ss;
                for (int s = 0; s < seq_id_max; ++s) {
                    if (seq_id[s]) {
                        ss << s % 10;
                    } else {
                        ss << ".";
                    }
                }

                LLAMA_LOG_DEBUG("%s:  %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
                        __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
                        batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
            }
            LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
        }
    }

    return true;
}
const llama_batch & llama_batch_allocr::get_batch() const {
    return batch;
}

uint32_t llama_batch_allocr::get_n_outputs() const {
    return n_outputs;
}

void llama_batch_allocr::clear() {
    n_outputs = 0;

    batch = {};
    pos.clear();
    n_seq_id.clear();
    seq_id.clear();
    output.clear();
}
//
// interface implementation
//
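// Wrap an existing token array in a minimal llama_batch without allocating.
// The optional arrays are left NULL; downstream validation (for example
// llama_batch_allocr::init above) is expected to substitute defaults.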
struct llama_batch llama_batch_get_one(
        llama_token * tokens,
            int32_t   n_tokens) {
    return {
        /*n_tokens =*/ n_tokens,
        /*tokens   =*/ tokens,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };
}
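// Allocate an owning llama_batch with room for n_tokens_alloc tokens, each of
// which may belong to up to n_seq_max sequences. A non-zero embd allocates
// per-token embeddings of that size instead of token ids. The result must be
// released with llama_batch_free. Illustrative usage (values are arbitrary):
//
//   llama_batch batch = llama_batch_init(512, /*embd =*/ 0, /*n_seq_max =*/ 1);
//   // ... fill token/pos/n_seq_id/seq_id/logits and set batch.n_tokens ...
//   llama_batch_free(batch);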
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
    llama_batch batch = {
        /*n_tokens =*/ 0,
        /*tokens   =*/ nullptr,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };

    if (embd) {
        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
    } else {
        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
    }

    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
    for (int i = 0; i < n_tokens_alloc; ++i) {
        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
    batch.seq_id[n_tokens_alloc] = nullptr;

    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);

    return batch;
}
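// Free a batch allocated by llama_batch_init; the per-token seq_id arrays are
// walked up to the nullptr sentinel stored at index n_tokens_alloc.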
void llama_batch_free(struct llama_batch batch) {
    if (batch.token)    free(batch.token);
    if (batch.embd)     free(batch.embd);
    if (batch.pos)      free(batch.pos);
    if (batch.n_seq_id) free(batch.n_seq_id);
    if (batch.seq_id) {
        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
            free(batch.seq_id[i]);
        }
        free(batch.seq_id);
    }
    if (batch.logits)   free(batch.logits);
}