llama-batch.h

#pragma once

#include "llama.h"
#include "llama-cparams.h"

#include <array>
#include <vector>
#include <set>
#include <bitset>
#include <memory>
#include <unordered_map>
// keep this struct lightweight
struct llama_ubatch {
    bool equal_seqs() const {
        return b_equal_seqs != 0;
    }

    // typical for M-RoPE cases:
    //   0 - sequential position of the tokens/embeddings in the sequence
    //   1 - y position in the image
    //   2 - x position in the image
    //   3 - other
    bool is_pos_2d() const {
        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
        return n_pos >= 3;
    }
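
    // example (illustrative sketch, not part of the API): reading the 2D position
    // of token i from an M-RoPE ubatch `ub`, assuming each position dimension is
    // stored as a contiguous section of n_tokens entries, as the [n_tokens*n_pos]
    // shape of `pos` below suggests:
    //
    //   if (ub.is_pos_2d()) {
    //       const llama_pos pos_y = ub.pos[1*ub.n_tokens + i];
    //       const llama_pos pos_x = ub.pos[2*ub.n_tokens + i];
    //   }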
    uint32_t b_equal_seqs; // note: this is a boolean, but we use a uint32_t for alignment
                           //       otherwise address sanitizer complains
    // TODO: whole_seqs for embeddings?

    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence set
    uint32_t n_seqs;       // sequence sets in the ubatch
    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
    uint32_t n_pos;        // number of position inputs for each token/embedding
    // seq_id_unq: unique sequence ids in the ubatch
    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
    //             used for extracting sequence pooled embeddings

    //                          // size               | idx | val
    llama_token  *  token;      // [n_tokens]         | i   | id, token
    float        *  embd;       // [n_embd, n_tokens] | i   | embd
    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
    int8_t       *  output;     // [n_tokens]         | i   | -
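
    // example (illustrative sketch): recovering the row of each sequence in a
    // pooled-embeddings output via the seq_idx mapping described above:
    //
    //   for (uint32_t s = 0; s < ub.n_seqs_unq; ++s) {
    //       const llama_seq_id seq_id = ub.seq_id_unq[s];
    //       const int32_t      row    = ub.seq_idx[seq_id]; // == (int32_t) s
    //       // the pooled embedding of seq_id lives at row `row` of the output
    //   }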
    struct data_t {
        std::vector<llama_token>    token;
        std::vector<float>          embd;
        std::vector<llama_pos>      pos;
        std::vector<int32_t>        n_seq_id;
        std::vector<llama_seq_id *> seq_id;
        std::vector<llama_seq_id>   seq_id_unq;
        std::vector<int32_t>        seq_idx;
        std::vector<int8_t>         output;
    };

    // the llama_ubatch pointers above point to this data if it is set. otherwise they point to non-owning data
    std::shared_ptr<data_t> data;
};
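
// example (illustrative sketch): walking the tokens of a ubatch `ub` together
// with their per-token sequence ids (ub.token is set for token batches, ub.embd
// for embedding batches):
//
//   for (uint32_t i = 0; i < ub.n_tokens; ++i) {
//       const llama_token id = ub.token[i];
//       for (int32_t s = 0; s < ub.n_seq_id[i]; ++s) {
//           const llama_seq_id seq_id = ub.seq_id[i][s];
//           // token i belongs to sequence seq_id
//       }
//   }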
// a helper for sanitizing, fulfilling and splitting a batch
class llama_batch_allocr {
public:
    llama_batch_allocr(uint32_t n_pos_per_embd);

    // sanitize and auto-gen missing data in the input batch
    // memory is optional. if provided, it will be used to check for sequence continuity and to determine the positions
    bool init(
            const llama_batch & batch_inp,
            const llama_vocab & vocab,
            const llama_memory_i * memory,
            uint32_t n_embd,
            uint32_t n_seq_max,
            bool output_all);
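
    // example (illustrative sketch): sanitizing a user-provided llama_batch before
    // splitting it (`batch`, `vocab`, `n_embd` and `n_seq_max` come from the caller):
    //
    //   llama_batch_allocr balloc(/*n_pos_per_embd =*/ 1);
    //   if (!balloc.init(batch, vocab, /*memory =*/ nullptr, n_embd, n_seq_max, /*output_all =*/ false)) {
    //       // the batch is malformed - abort processing
    //   }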
    const llama_batch & get_batch() const;

    uint32_t get_n_tokens()  const;
    uint32_t get_n_outputs() const;
    uint32_t get_n_used()    const;

    // the array of output indices in the order they were encountered during the ubatch splitting
    std::vector<int32_t> & get_out_ids();

    // min/max positions of each sequence in the current ubatch
    llama_pos seq_pos_min(llama_seq_id seq_id) const;
    llama_pos seq_pos_max(llama_seq_id seq_id) const;

    // call once before splitting the batch to reset the internal state
    void split_reset();

    // simple split, unknown number of sequence sets of unequal lengths
    llama_ubatch split_simple(uint32_t n_ubatch);

    // make ubatches of equal-length sequence sets
    // if sequential == true, the tokens in the ubatch will have increasing, sequential sequence ids
    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);

    // sequence-set-wise split - each ubatch contains a single sequence-set
    llama_ubatch split_seq(uint32_t n_ubatch);
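
    // example (illustrative sketch): the typical splitting loop - reset once, then
    // keep requesting ubatches until the batch is consumed (an empty ubatch, i.e.
    // n_tokens == 0, signals that there is nothing left to split):
    //
    //   balloc.split_reset();
    //   while (true) {
    //       llama_ubatch ub = balloc.split_simple(n_ubatch);
    //       if (ub.n_tokens == 0) {
    //           break;
    //       }
    //       // process ub ...
    //   }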
    // a helper method for creating a well-defined ubatch of tokens
    // TODO: support embeddings if needed in the future
    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);

private:
    void clear();

    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);

    // for debugging, start with LLAMA_BATCH_DEBUG=2
    void ubatch_print(const llama_ubatch & ubatch, int debug);
    llama_batch batch;

    // only for debugging purposes
    const llama_vocab * vocab;

    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
    const uint32_t n_pos_per_embd;

    uint32_t n_embd;
    uint32_t n_seq_max;
    uint32_t n_outputs;

    std::array<llama_seq_id, 1> seq_id_0 = {{ 0 }}; // default sequence id

    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<llama_seq_id>   seq_id_unq;
    std::vector<int32_t>        seq_idx;
    std::vector<int8_t>         output;

    using pos_set_t = std::set<llama_pos>;
    using seq_cpl_t = std::vector<bool>;

    // helper flag to quickly determine if there are any coupled sequences in the batch
    bool has_cpl = false;

    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: true if sequence s0 is coupled to sequence s1

    using idx_vec_t = std::vector<int32_t>;
    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;

    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i

    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears

    // batch indices of the output
    std::vector<int32_t> out_ids;

    uint32_t n_used;

    // used[i] indicates if token i has already been used in a previous ubatch
    std::vector<bool> used;

    int debug;
};