@@ -1,6 +1,7 @@
 #include "llama-kv-cache-unified.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-model.h"
 #include "llama-context.h"
 
@@ -320,16 +321,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_state>(
             this, std::move(sbatch), std::move(heads), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_state>(this);
 }
 
-std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    std::vector<uint32_t> res;
+llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+    bool do_shift = get_has_shift();
+
+    defrag_info dinfo;
+
+    // see if we need to defrag
+    {
+        bool do_defrag = optimize;
+
+        const auto thold = lctx->get_cparams().defrag_thold;
+
+        if (!do_defrag && thold > 0.0f) {
+            const auto n_kv = cells.used_max_p1();
+
+            // - do not defrag small contexts (i.e. < 2048 tokens)
+            // - count the padding towards the number of used tokens
+            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+
+            if (fragmentation > thold) {
+                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+                do_defrag = true;
+            }
+        }
+
+        if (do_defrag) {
+            dinfo = defrag_prepare(lctx->graph_max_nodes());
+        }
+    }
+
+    return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
+}
+
+llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::ubatch_heads res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
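
Note: the fragmentation check in init_update() above estimates the share of holes below the cache's high-water mark (used_max_p1), counting the padding as used so that small gaps do not trigger work. A minimal standalone sketch of the same arithmetic, with illustrative numbers rather than values from the patch:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t n_kv   = 4096; // stand-in for cells.used_max_p1(): highest used cell + 1
        const uint32_t n_used = 3000; // stand-in for cells.get_used():    number of occupied cells
        const uint32_t n_pad  = 32;   // padding counted as "used" to avoid needless defrags

        // same expression as the patch: fraction of holes below the high-water mark
        const float fragmentation = n_kv >= 2048
            ? std::max(0.0f, 1.0f - (float(n_used + n_pad)/n_kv))
            : 0.0f;

        // prints ~0.26; with defrag_thold = 0.1f this would request a defrag
        printf("fragmentation: %.2f\n", fragmentation);
    }
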
@@ -374,12 +408,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
     return res;
 }
 
-bool llama_kv_cache_unified::update(llama_context & lctx) {
+bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
     bool updated = false;
 
-    auto * sched = lctx.get_sched();
+    auto * sched = lctx->get_sched();
 
-    if (cells.get_has_shift()) {
+    if (do_shift) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }
@@ -390,9 +424,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(sched);
 
-            auto * gf = lctx.graph_init();
+            auto * gf = lctx->graph_init();
 
-            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
             if (!res) {
                 LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
                 return updated;
@@ -405,7 +439,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
             res->set_inputs(nullptr);
 
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                 LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
                 return updated;
             }
@@ -416,54 +450,53 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         cells.reset_shift();
     }
 
-    if (do_defrag) {
+    if (!dinfo.empty()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-        if (defrag_prepare(lctx.graph_max_nodes())) {
-            ggml_backend_sched_reset(sched);
-
-            auto * gf = lctx.graph_init();
-
-            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
-            if (!res) {
-                LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
-                return updated;
-            }
+        // apply moves:
+        {
+            const auto n_kv = dinfo.ids.size();
 
-            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
-                return updated;
-            }
+            for (uint32_t i = 0; i < n_kv; ++i) {
+                assert(dinfo.ids[i] <= n_kv);
 
-            res->set_inputs(nullptr);
+                if (dinfo.ids[i] == n_kv) {
+                    continue;
+                }
 
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-                LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
-                return updated;
+                cells.mv(i, dinfo.ids[i]);
             }
 
-            updated = true;
+            // reset the head so we can find the first free slot during the next ubatch
+            head = 0;
         }
 
-        do_defrag = false;
-    }
+        ggml_backend_sched_reset(sched);
 
-    return updated;
-}
+        auto * gf = lctx->graph_init();
 
-void llama_kv_cache_unified::defrag_sched(float thold) {
-    const auto n_kv = cells.used_max_p1();
+        auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
+        if (!res) {
+            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
+            return updated;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
+            return updated;
+        }
 
-    // - do not defrag small contexts (i.e. < 2048 tokens)
-    // - count the padding towards the number of used tokens
-    const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+        res->set_inputs(nullptr);
 
-    // queue defragmentation for next llama_kv_cache_update
-    if (fragmentation > thold) {
-        LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
+            return updated;
+        }
 
-        do_defrag = true;
+        updated = true;
     }
+
+    return updated;
 }
 
 int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
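
Note: update() now consumes a precomputed move table instead of running defrag_prepare() itself. A toy, self-contained sketch of the convention the loop above applies (cell i moves to ids[i]; ids[i] == n_kv, or ids[i] == i, means the cell stays put) — the container and the move here are stand-ins, not the real llama_kv_cells:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        std::vector<char> cells = {'A', '.', 'B', '.', 'C'}; // '.' marks a free cell
        const uint32_t n_kv = (uint32_t) cells.size();

        // plan produced by a defrag_prepare-like pass: keep A, move B -> 1, C -> 2
        const std::vector<uint32_t> ids = {0, n_kv, 1, n_kv, 2};

        for (uint32_t i = 0; i < n_kv; ++i) {
            assert(ids[i] <= n_kv);

            if (ids[i] == n_kv || ids[i] == i) {
                continue; // cell i is not moved
            }

            // stand-in for cells.mv(i, ids[i])
            cells[ids[i]] = cells[i];
            cells[i]     = '.';
        }

        // cells now reads {'A', 'B', 'C', '.', '.'}
    }
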
@@ -612,6 +645,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
     return cells.size();
 }
 
+bool llama_kv_cache_unified::get_has_shift() const {
+    return cells.get_has_shift();
+}
+
 uint32_t llama_kv_cache_unified::get_n_kv() const {
     return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
 }
@@ -941,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 }
 
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-        const llama_cparams & cparams,
-        ggml_context * ctx,
-        ggml_cgraph * gf) const {
+            const llama_cparams & cparams,
+            ggml_context * ctx,
+            ggml_cgraph * gf,
+            const defrag_info & dinfo) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & ids = defrag_info.ids;
+    const auto & ids = dinfo.ids;
 
 #if 0
     // CPU defrag
@@ -1087,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
     return res;
 }
 
-bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv = cells.used_max_p1();
@@ -1108,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
-    //
-    //  cell i moves to ids[i]
-    //
-    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
-    //
-    auto & ids = defrag_info.ids;
+    defrag_info res;
+    auto & ids = res.ids;
 
-    ids.clear();
     ids.resize(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
@@ -1179,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             // this cell goes to (i0 + nf)
             ids[i1] = i0 + nf;
 
-            // move the cell meta data
-            cells.mv(i1, i0 + nf);
-
-            head = n_used;
-
             if (!cont) {
                 n_moves++;
                 cont = true;
@@ -1206,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     }
 
     if (n_moves == 0) {
-        return false;
+        return {};
     }
 
     LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
 
     LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
 
-    return true;
+    return res;
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
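
Note: defrag_prepare() is now const and returns the move plan by value, with `return {}` meaning "no moves". The defrag_info type itself is declared in the header, which is outside this diff; based on how it is used here, it is presumably along these lines (an assumption, shown only for orientation):

    #include <cstdint>
    #include <vector>

    struct defrag_info {
        // cell i moves to ids[i]; ids[i] == i || ids[i] == n_kv -> cell i is not moved
        std::vector<uint32_t> ids;

        bool empty() const {
            return ids.empty();
        }
    };
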
@@ -1636,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-        llama_memory_status status,
-        llama_kv_cache_unified * kv) : status(status), kv(kv) {
-    n_kv = kv->get_size();
-    head = 0;
-    }
+        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+    n_kv = kv->get_size();
+    head = 0;
+}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-        llama_memory_status status,
-        llama_kv_cache_unified * kv,
-        llama_sbatch sbatch,
-        std::vector<uint32_t> heads,
-        std::vector<llama_ubatch> ubatches)
-    : status(status),
-    kv(kv),
-    sbatch(std::move(sbatch)),
-    heads(std::move(heads)),
-    ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified * kv,
+        llama_context * lctx,
+        bool do_shift,
+        defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
+    if (!do_shift && dinfo.empty()) {
+        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
     }
+}
+
+llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+        llama_kv_cache_unified * kv,
+        llama_sbatch sbatch,
+        llama_kv_cache_unified::ubatch_heads heads,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+}
 
 llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
 
@@ -1670,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
 bool llama_kv_cache_unified_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    // no ubatches -> this is a KV cache update
+    if (ubatches.empty()) {
+        kv->update(lctx, do_shift, dinfo);
+
+        return true;
+    }
+
     kv->apply_ubatch(heads[i_next], ubatches[i_next]);
 
     n_kv = kv->get_n_kv();
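
Note: taken together, the intended call pattern becomes: init_update() decides what work is pending (K-shift and/or defrag) and captures it in a state object; apply() then performs it. A hedged sketch of a driver — the real call site lives in llama-context.cpp, outside this diff, so the function name and exact status handling here are assumptions:

    #include "llama-kv-cache-unified.h" // internal llama.cpp header, as included at the top of this file

    // hypothetical driver, for illustration only
    void kv_update_example(llama_kv_cache_unified & kv, llama_context * lctx, bool optimize) {
        // decide what is pending (K-shift and/or defrag) and capture it in a state
        llama_memory_state_ptr state = kv.init_update(lctx, optimize);

        if (state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            return; // e.g. LLAMA_MEMORY_STATUS_NO_UPDATE -> nothing to do
        }

        // with no ubatches attached, apply() routes to kv.update(lctx, do_shift, dinfo)
        state->apply();
    }
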