#include "llama-batch.h"

#include <cassert>
#include <cstring>
#include <algorithm>

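// prepare storage (in udatas) for a new ubatch of up to n_ubatch tokens and
// return an empty llama_ubatch whose pointers view into that storage;
// exactly one of token/embd is allocated, depending on has_embd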
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
    // clear empty sequences
    // the previous ubatch is assumed to be gone,
    // so nothing should refer to values in these sequences anymore.
    for (size_t i = seq.size(); i-- > 0;) {
        if (seq[i].length == 0) {
            seq.pop_back();
        } else {
            break;
        }
    }

    udatas.push_back({});

    auto & udata = udatas.back();

    udata.token.resize(!has_embd ? n_ubatch : 0);
    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
    udata.pos.resize(n_ubatch);
    udata.n_seq_id.resize(n_ubatch);
    udata.seq_id.resize(n_ubatch);
    udata.output.resize(n_ubatch);

    llama_ubatch ubatch = {
        /*equal_seqs   =*/ true,
        /*n_tokens     =*/ 0,
        /*n_seq_tokens =*/ 0,
        /*n_seqs       =*/ 0,
        /*token        =*/ !has_embd ? udata.token.data() : nullptr,
        /*embd         =*/ has_embd ? udata.embd.data() : nullptr,
        /*pos          =*/ udata.pos.data(),
        /*n_seq_id     =*/ udata.n_seq_id.data(),
        /*seq_id       =*/ udata.seq_id.data(),
        /*output       =*/ udata.output.data(),
    };

    return ubatch;
}

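// append up to `length` tokens of the given sequence to the ubatch;
// with equal_seqs the data is gathered through the `ids` permutation,
// while simple splits just point the ubatch into the original batch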
void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
    GGML_ASSERT(batch != nullptr);
    GGML_ASSERT(length <= seq.length);
    // Can only add sequences of equal lengths to a batch,
    // otherwise it isn't clear to which sequence a token belongs
    GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
    GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
    // NOTE: loops are separated for cache-friendliness
    if (batch->token) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
            }
        } else {
            // simple split
            ubatch.token = batch->token + seq.offset;
        }
    } else {
        ubatch.token = nullptr;
    }
    if (batch->embd) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                memcpy(
                    ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
                    batch->embd + (n_embd * ids[seq.offset + i]),
                    n_embd * sizeof(float)
                );
            }
        } else {
            // simple split
            ubatch.embd = batch->embd + (n_embd * seq.offset);
        }
    } else {
        ubatch.embd = nullptr;
    }
    if (ubatch.equal_seqs) {
        for (size_t i = 0; i < length; ++i) {
            ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
        }
    } else {
        // simple split
        ubatch.pos = batch->pos + seq.offset;
    }
    if (ubatch.equal_seqs) {
        ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
        if (seq.seq_id) {
            ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
        }
    } else {
        // simple split
        if (batch->n_seq_id) {
            ubatch.n_seq_id = batch->n_seq_id + seq.offset;
        } else {
            for (size_t i = 0; i < length; ++i) {
                ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
            }
        }
        if (batch->seq_id) {
            ubatch.seq_id = batch->seq_id + seq.offset;
        }
    }
    if (logits_all) {
        for (size_t i = 0; i < length; ++i) {
            ubatch.output[ubatch.n_tokens + i] = 1;
            out_ids.push_back(ids[seq.offset + i]);
        }
    } else if (batch->logits) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                size_t id = ids[seq.offset + i];
                int8_t is_output = batch->logits[id];
                ubatch.output[ubatch.n_tokens + i] = is_output;
                if (is_output) { out_ids.push_back(id); }
            }
        } else {
            // simple split
            ubatch.output = batch->logits + seq.offset;
            for (size_t i = 0; i < length; ++i) {
                if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
            }
        }
    } else {
        // only get last output
        for (size_t i = 0; i < length; ++i) {
            size_t id = ids[seq.offset + i];
            int8_t is_last = id == ids.size() - 1;
            ubatch.output[ubatch.n_tokens + i] = is_last;
            if (is_last) { out_ids.push_back(id); }
        }
    }
    if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
        ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
    }
    ubatch.n_tokens += length;
    ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
    seq.offset += length;
    seq.length -= length;
    n_tokens -= length;
    GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
}

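// take the next n_ubatch tokens in their original batch order;
// only valid for sbatches constructed with simple_split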
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    ubatch.equal_seqs = false;
    if (!seq.empty()) {
        llama_sbatch_seq & s = seq[0];
        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
        GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
        add_seq_to_ubatch(ubatch, s, length);
    }
    return ubatch;
}

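// build a ubatch where every included sequence contributes the same number
// of tokens, shortest remaining sequences first; a shared prompt
// (n_seq_id > 1) is always computed in its own ubatch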
llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (!seq.empty()) {
        size_t length = 0;
        size_t n_tokens_in_ubatch = 0;
        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
        // smallest first, because it's easier to split this way;
        // starting from the end to pop in constant time.
        for (size_t i = seq.size(); i-- > 0;) {
            llama_sbatch_seq & s = seq[i];
            GGML_ASSERT(s.length > 0);
            if (length == 0) {
                length = s.length < n_ubatch ? s.length : n_ubatch;
            }
            add_seq_to_ubatch(ubatch, s, length);
            n_tokens_in_ubatch += length;
            // shared prompts can't be mixed with any of their sequences,
            // so it's safer to compute them in their own ubatch
            if (s.n_seq_id > 1) { break; }
            // stop when there isn't enough space for another sequence
            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
        }
    }
    return ubatch;
}

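// take a chunk of a single sequence: the last entry of `seq`, which after
// sorting is the most-shared and, within that, the shortest one remaining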
llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (!seq.empty()) {
        llama_sbatch_seq & s = seq[seq.size() - 1];
        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
        add_seq_to_ubatch(ubatch, s, length);
    }
    return ubatch;
}

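// build the sbatch over an external batch: `ids` maps sbatch order to batch
// order, and `seq` holds one entry per run of tokens with identical seq_ids.
// typical driver loop (a sketch; the real loops live in the callers):
//
//   llama_sbatch sbatch(batch, n_embd, /*simple_split=*/true, /*logits_all=*/false);
//   while (sbatch.n_tokens > 0) {
//       llama_ubatch ubatch = sbatch.split_simple(n_ubatch);
//       // ... evaluate ubatch ...
//   }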
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
    GGML_ASSERT(batch.n_tokens >= 0);
    this->batch = &batch;
    this->n_embd = n_embd;
    this->logits_all = logits_all;

    n_tokens = batch.n_tokens;
    ids.resize(n_tokens);
    out_ids.clear();
    // TODO: reserve out_ids and seq

    for (size_t i = 0; i < n_tokens; ++i) {
        ids[i] = i;
    }

    if (simple_split) {
        seq.resize(1);
        llama_sbatch_seq & s = seq[0];
        s.n_seq_id = 0;
        s.seq_id   = nullptr;
        s.offset   = 0;
        s.length   = n_tokens;
        return;
    }

    std::sort(ids.begin(), ids.end(),
        [&batch](size_t a, size_t b) {
            int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
            int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
            // sort by seq_id, then by pos
            if (n_seq_a == n_seq_b) {
                if (batch.seq_id) {
                    for (int32_t i = 0; i < n_seq_a; ++i) {
                        llama_seq_id seq_id_a = batch.seq_id[a][i];
                        llama_seq_id seq_id_b = batch.seq_id[b][i];
                        // smaller seq_ids go first
                        if (seq_id_a != seq_id_b) {
                            return seq_id_a < seq_id_b;
                        }
                    }
                }
                // when all else is equal, sort by pos
                if (batch.pos) {
                    return batch.pos[a] < batch.pos[b];
                }
                // no pos, sort by id
                return a < b;
            }
            // shared prompts go first
            return n_seq_a > n_seq_b;
        }
    );

    // init seq
    llama_sbatch_seq * last_seq = nullptr;
    for (size_t i = 0; i < n_tokens; ++i) {
        const size_t bi = ids[i];
        const int32_t n_seqs = batch.n_seq_id[bi];
        llama_seq_id * seq_ids = batch.seq_id[bi];
        if (last_seq != nullptr) {
            bool same = n_seqs == last_seq->n_seq_id;
            for (int32_t j = 0; same && j < n_seqs; ++j) {
                if (seq_ids[j] != last_seq->seq_id[j]) {
                    same = false;
                }
            }
            if (same) {
                last_seq->length += 1;
                continue;
            }
        }
        llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
        seq.push_back(new_seq);
        last_seq = &seq.back();
    }
    // shared prompts (larger n_seq_id) go to the end so they are split first,
    // then sort by length descending so the shortest sequence is at the back
    std::sort(seq.begin(), seq.end(),
        [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
            if (a.n_seq_id == b.n_seq_id) {
                return a.length > b.length;
            }
            return a.n_seq_id < b.n_seq_id;
        }
    );
}

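// fill the optional fields of an external batch with defaults:
// consecutive positions starting at p0, the single default sequence id
// (seq_id_0), and output enabled only for the last token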
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
    batch = in_batch;
    GGML_ASSERT(batch.n_tokens > 0);
    if (!batch.pos) {
        assert(p0 >= 0);
        pos.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            pos[i] = p0 + i;
        }
        batch.pos = pos.data();
    }
    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }
    if (!batch.seq_id) {
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = NULL;
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            seq_id[i] = seq_id_0.data();
        }
        batch.seq_id = seq_id.data();
    }
    if (!batch.logits) {
        logits.resize(batch.n_tokens);
        logits[logits.size() - 1] = true;
        batch.logits = logits.data();
    }
}

//
// interface implementation
//

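// return a batch that merely views an external token array;
// the other fields stay null and get defaults from llama_batch_allocr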
struct llama_batch llama_batch_get_one(
        llama_token * tokens,
        int32_t       n_tokens) {
    return {
        /*n_tokens =*/ n_tokens,
        /*tokens   =*/ tokens,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };
}

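// allocate an owning batch for up to n_tokens_alloc tokens, holding either
// token ids (embd == 0) or embeddings of size embd per token;
// seq_id gets one extra null entry so llama_batch_free can find its end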
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
    llama_batch batch = {
        /*n_tokens =*/ 0,
        /*tokens   =*/ nullptr,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };

    if (embd) {
        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
    } else {
        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
    }

    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
    for (int i = 0; i < n_tokens_alloc; ++i) {
        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
    batch.seq_id[n_tokens_alloc] = nullptr;

    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);

    return batch;
}

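// free a batch allocated with llama_batch_init (all arrays are owned)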
void llama_batch_free(struct llama_batch batch) {
    if (batch.token)    free(batch.token);
    if (batch.embd)     free(batch.embd);
    if (batch.pos)      free(batch.pos);
    if (batch.n_seq_id) free(batch.n_seq_id);
    if (batch.seq_id) {
        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
            free(batch.seq_id[i]);
        }
        free(batch.seq_id);
    }
    if (batch.logits)   free(batch.logits);
}