@@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+    if (!supports_set_rows) {
+        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+    }
 }
 
 void llama_kv_cache_unified::clear(bool data) {
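
The hunk above introduces an opt-in switch: the `LLAMA_SET_ROWS` environment variable decides whether the cache uses the new `ggml_set_rows()` path (non-zero) or keeps the legacy `ggml_cpy()` path (unset or `0`, with a one-time warning). Below is a minimal standalone sketch of that opt-in pattern, assuming nothing beyond the C++ standard library; the helper name `read_env_flag` is illustrative and not part of llama.cpp.

```cpp
// Sketch of an integer env-var toggle with a legacy fallback, as used above.
#include <cstdio>
#include <cstdlib>

static bool read_env_flag(const char * name) {
    const char * val = std::getenv(name);               // nullptr when unset
    const bool enabled = val ? std::atoi(val) != 0 : false;
    if (!enabled) {
        std::fprintf(stderr, "%s=0, using legacy path\n", name);
    }
    return enabled;
}

int main() {
    const bool use_set_rows = read_env_flag("LLAMA_SET_ROWS");
    std::printf("set_rows path: %s\n", use_set_rows ? "enabled" : "disabled");
    return 0;
}
```

Under this scheme, running a process with `LLAMA_SET_ROWS=1` enables the new path, while leaving the variable unset preserves the old behavior by default.
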
@@ -353,13 +360,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads = prepare(ubatches);
-        if (heads.empty()) {
+        auto sinfos = prepare(ubatches);
+        if (sinfos.empty()) {
             break;
         }
 
         return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(heads), std::move(ubatches));
+                this, std::move(sinfos), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -402,12 +409,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::slot_info_vec_t res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
-        uint32_t head_new; // new position of the head, after placing the ubatch
+
+        slot_info sinfo; // slot info for the ubatch
 
         llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
     };
@@ -418,26 +426,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
     bool success = true;
 
     for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const int32_t head_new = find_slot(ubatch);
-        if (head_new < 0) {
+        const auto sinfo_new = find_slot(ubatch, cont);
+        if (sinfo_new.empty()) {
             success = false;
             break;
         }
 
         // remeber the position that we found
-        res.push_back(head_new);
+        res.push_back(sinfo_new);
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+        states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
 
         // now emplace the ubatch
-        apply_ubatch(head_new, ubatch);
+        apply_ubatch(sinfo_new, ubatch);
     }
 
     // iterate backwards and restore the cells to their original state
     for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        cells.set(it->head_new, it->cells);
+        cells.set(it->sinfo.idxs, it->cells);
         head = it->head_old;
     }
 
@@ -539,7 +550,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     uint32_t head_cur = this->head;
@@ -552,7 +563,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return -1;
+        return { };
     }
 
     if (debug > 0) {
@@ -615,15 +626,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     uint32_t n_tested = 0;
 
+    // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+    // for non-continuous slots, we test the tokens one by one
+    const uint32_t n_test = cont ? n_tokens : 1;
+
+    slot_info res;
+
+    auto & idxs = res.idxs;
+
+    idxs.reserve(n_tokens);
+
     while (true) {
-        if (head_cur + n_tokens > cells.size()) {
+        if (head_cur + n_test > cells.size()) {
             n_tested += cells.size() - head_cur;
             head_cur = 0;
             continue;
         }
 
-        bool found = true;
-        for (uint32_t i = 0; i < n_tokens; i++) {
+        for (uint32_t i = 0; i < n_test; i++) {
+            const auto idx = head_cur;
+
             //const llama_pos pos = ubatch.pos[i];
             //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
@@ -633,19 +655,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             // - (disabled) mask causally, if the sequence is the same as the one we are inserting
             // - mask SWA, using current max pos for that sequence in the cache
             // always insert in the cell with minimum pos
-            bool can_use = cells.is_empty(head_cur + i);
+            bool can_use = cells.is_empty(idx);
 
-            if (!can_use && cells.seq_count(head_cur + i) == 1) {
-                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+            if (!can_use && cells.seq_count(idx) == 1) {
+                const llama_pos pos_cell = cells.pos_get(idx);
 
                 // (disabled) causal mask
                 // note: it's better to purge any "future" tokens beforehand
-                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //if (cells.seq_has(idx, seq_id)) {
                 //    can_use = pos_cell >= pos;
                 //}
 
                 if (!can_use) {
-                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+                    const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
                     // SWA mask
                     if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -654,28 +676,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
                     }
                 }
 
-            if (!can_use) {
-                found = false;
-                head_cur += i + 1;
-                n_tested += i + 1;
+            head_cur++;
+            n_tested++;
+
+            if (can_use) {
+                idxs.push_back(idx);
+            } else {
                 break;
             }
         }
 
-        if (found) {
+        if (idxs.size() == n_tokens) {
             break;
         }
 
+        if (cont) {
+            idxs.clear();
+        }
+
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return -1;
+            return { };
         }
     }
 
-    return head_cur;
+    // we didn't find a suitable slot - return empty result
+    if (idxs.size() < n_tokens) {
+        res.clear();
+    }
+
+    return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
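
With this change, `find_slot()` no longer returns a single head offset but a `slot_info` that names one destination cell per token of the ubatch. Its exact definition lives in the accompanying header and is not part of this diff; the sketch below only mirrors the interface the code above relies on (`idxs`, `head()`, `empty()`, `clear()`), so the field layout and helper bodies are assumptions.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for llama.cpp's slot_info, reconstructed from its usage in
// this diff only; the real struct is declared in the kv-cache header and may differ.
struct slot_info_sketch {
    std::vector<uint32_t> idxs; // one destination KV cell per token of the ubatch

    uint32_t head()  const { return idxs.at(0); }  // first cell, used by the legacy contiguous path
    bool     empty() const { return idxs.empty(); }
    void     clear()       { idxs.clear(); }
};

int main() {
    // cont == true  -> consecutive cells, e.g. a fresh prompt placed at the current head
    slot_info_sketch contiguous { { 37, 38, 39, 40 } };

    // cont == false -> cells may be scattered, possible once ggml_set_rows() is available
    slot_info_sketch scattered  { { 12, 57, 58, 93 } };

    std::printf("contiguous head = %u, scattered head = %u\n",
                (unsigned) contiguous.head(), (unsigned) scattered.head());
    return 0;
}
```

When `cont` is true the indices must be consecutive (the legacy contiguous slot); with `ggml_set_rows()` available they are collected cell by cell and may be scattered across the cache.
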
@@ -683,22 +716,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }
 
+    assert(ubatch.n_tokens == sinfo.idxs.size());
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+        const auto idx = sinfo.idxs.at(i);
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos pos = cells.pos_get(head_cur + i);
+        if (!cells.is_empty(idx)) {
+            assert(cells.seq_count(idx) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(idx);
+            const llama_pos pos = cells.pos_get(idx);
 
             seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-            cells.rm(head_cur + i);
+            cells.rm(idx);
         }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+        cells.pos_set(idx, ubatch.pos[i]);
 
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
+            cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
     }
 
@@ -719,7 +756,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     }
 
     // move the head at the end of the slot
-    head = head_cur + ubatch.n_tokens;
+    head = sinfo.idxs.back() + 1;
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
@@ -772,47 +809,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
+    const int64_t n_embd_k_gqa = k->ne[0];
     const int64_t n_tokens = k_cur->ne[2];
 
+    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*hparams.n_embd_k_gqa(il),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const int64_t n_embd_v_gqa = v->ne[0];
     const int64_t n_tokens = v_cur->ne[2];
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+
+    if (v_idxs && supports_set_rows) {
+        if (!v_trans) {
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
+        }
+
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+        // note: the V cache is transposed when not using flash attention
+        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+
+        // note: we can be more explicit here at the cost of extra cont
+        //       however, above we take advantage that a row of single element is always continuous regardless of the row stride
+        //v_cur = ggml_transpose(ctx, v_cur);
+        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+        // we broadcast the KV indices n_embd_v_gqa times
+        // v      [1, n_kv,     n_embd_v_gqa]
+        // v_cur  [1, n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1, 1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
 
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
         v_view = ggml_view_1d(ctx, v,
-                n_tokens*hparams.n_embd_v_gqa(il),
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
     } else {
-        // note: the V cache is transposed when not using flash attention
-        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
-                (v->ne[1])*ggml_element_size(v),
-                (head_cur)*ggml_element_size(v));
-
         v_cur = ggml_transpose(ctx, v_cur);
+
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
+                (v->ne[1]    )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
 
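
For readers unfamiliar with `ggml_set_rows()`: conceptually it scatters the rows of a source tensor into a destination tensor at the row positions given by an I64 index tensor, which is exactly what `cpy_k()` needs to write one K row per token into arbitrary cells. The plain C++ sketch below emulates that scatter on flat buffers; it is illustrative only and does not use the ggml API. For the transposed V cache the hunk above applies the same idea per element: the cache is viewed as rows of a single element and the token indices are broadcast across the `n_embd_v_gqa` dimension.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Emulates the effect of ggml_set_rows(k, k_cur, k_idxs) on flat float buffers:
// row i of `k_cur` is written into row idxs[i] of `k_cache`; all other cells stay untouched.
static void scatter_rows(std::vector<float> & k_cache,       // [n_cells * n_embd]
                         const std::vector<float> & k_cur,   // [n_tokens * n_embd]
                         const std::vector<int64_t> & idxs,  // [n_tokens]
                         size_t n_embd) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        const size_t dst_row = (size_t) idxs[i];
        for (size_t c = 0; c < n_embd; ++c) {
            k_cache[dst_row*n_embd + c] = k_cur[i*n_embd + c];
        }
    }
}

int main() {
    const size_t n_embd  = 4; // toy embedding width
    const size_t n_cells = 8; // toy KV cache size

    std::vector<float>   k_cache(n_cells*n_embd, 0.0f);
    std::vector<float>   k_cur = {1,1,1,1, 2,2,2,2, 3,3,3,3}; // 3 tokens
    std::vector<int64_t> idxs  = {2, 5, 6};                   // scattered destination cells

    scatter_rows(k_cache, k_cur, idxs, n_embd);

    std::printf("cell 5, channel 0: %.1f\n", k_cache[5*n_embd + 0]); // prints 2.0
    return 0;
}
```

The legacy `ggml_cpy()` fallback kept below the `ggml_set_rows()` branch can only write a contiguous range starting at `sinfo.head()`, which is why non-contiguous slots are gated on `supports_set_rows`.
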
@@ -1552,13 +1675,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             ubatch.seq_id[i] = &dest_seq_id;
         }
 
-        const auto head_cur = find_slot(ubatch);
-        if (head_cur < 0) {
+        const auto sinfo = find_slot(ubatch, true);
+        if (sinfo.empty()) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
-        apply_ubatch(head_cur, ubatch);
+        apply_ubatch(sinfo, ubatch);
+
+        const auto head_cur = sinfo.head();
 
         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
@@ -1744,7 +1869,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
-    head = 0;
+
+    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
+    sinfos.resize(1);
+    sinfos[0].idxs.resize(1);
+    sinfos[0].idxs[0] = 0;
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1759,8 +1888,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
@@ -1768,7 +1897,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    if (++i_next >= ubatches.size()) {
+    if (++i_cur >= ubatches.size()) {
         return false;
     }
 
@@ -1785,10 +1914,9 @@ bool llama_kv_cache_unified_context::apply() {
         return true;
     }
 
-    kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
 
     n_kv = kv->get_n_kv();
-    head = heads[i_next];
 
     return true;
 }
@@ -1800,7 +1928,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
 const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return ubatches[i_next];
+    return ubatches[i_cur];
 }
 
 uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@@ -1815,18 +1943,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
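
Taken together, the context wrappers define a per-ubatch flow: `build_input_k_idxs()`/`build_input_v_idxs()` create I64 input tensors at graph-build time, `set_input_k_idxs()`/`set_input_v_idxs()` fill them on the host from the current `slot_info`, and `cpy_k()`/`cpy_v()` consume them via `ggml_set_rows()`. Below is a minimal sketch of just the host-side fill step, with a plain `int64_t` vector standing in for the `GGML_TYPE_I64` tensor; `fill_idxs` and the sample indices are illustrative, not llama.cpp code.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative host-side fill, mirroring set_input_k_idxs()/set_input_v_idxs() above.
static void fill_idxs(int64_t * data, const std::vector<uint32_t> & idxs) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        data[i] = idxs[i]; // one destination cell index per ubatch token
    }
}

int main() {
    const std::vector<uint32_t> sinfo_idxs = {12, 57, 58, 93}; // e.g. a non-contiguous slot
    std::vector<int64_t> k_idxs(sinfo_idxs.size());            // stands in for the I64 input tensor

    fill_idxs(k_idxs.data(), sinfo_idxs);

    for (int64_t v : k_idxs) {
        std::printf("%lld\n", (long long) v);
    }
    return 0;
}
```

When `supports_set_rows` is 0, the fill step is skipped entirely and the graph falls back to the contiguous `ggml_cpy()` views shown earlier.
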