@@ -119,10 +119,10 @@ bool llama_kv_cache_init(
 
 struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
@@ -130,16 +130,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // A slot should be always be contiguous.
 
         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
 
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
@@ -198,7 +198,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -237,7 +237,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -258,7 +258,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];
 
@@ -266,12 +266,12 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
@@ -325,10 +325,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];
 
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
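
For orientation only (not part of the patch), here is a minimal sketch of the sequence-major layout that the indexing above relies on: `pos` holds `n_seqs * n_seq_tokens` entries, token `i` of sequence slot `s` sits at index `s*n_seq_tokens + i`, and the last position of a sequence is `pos[n_seq_tokens*s + n_seq_tokens - 1]`. The struct `ubatch_view` is a hypothetical stand-in that keeps only the `llama_ubatch` fields touched by this diff.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical, reduced stand-in for llama_ubatch: only the fields this patch
// touches, laid out sequence-major as llama_kv_cache_find_slot assumes.
struct ubatch_view {
    uint32_t        n_seqs;       // sequences in the micro-batch
    uint32_t        n_seq_tokens; // new tokens per sequence (equal_seqs case)
    const int32_t * pos;          // n_seqs * n_seq_tokens positions, sequence-major
};

int main() {
    // two sequences, three new tokens each
    std::vector<int32_t> pos = { 10, 11, 12,    // sequence slot 0
                                 20, 21, 22 };  // sequence slot 1
    ubatch_view ub = { 2, 3, pos.data() };

    for (uint32_t s = 0; s < ub.n_seqs; ++s) {
        for (uint32_t i = 0; i < ub.n_seq_tokens; ++i) {
            // flattened index, same as "k = s*n_seq_tokens + i" in the patch
            uint32_t k = s*ub.n_seq_tokens + i;
            printf("seq %u token %u -> pos %d\n", s, i, ub.pos[k]);
        }
        // last position of the sequence, the same expression the patch uses
        // when updating cell.pos for recurrent caches
        const int32_t last_pos = ub.pos[ub.n_seq_tokens*s + ub.n_seq_tokens - 1];
        printf("seq %u last_pos = %d\n", s, last_pos); // prints 12, then 22
    }
    return 0;
}
```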