@@ -3502,11 +3502,24 @@ static bool llama_kv_cache_init(
     return true;
 }
 
+// a structure that holds information about the slot found in llama_kv_cache_find_slot
+struct llama_kv_cache_slot_info {
+    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
+    bool found = false;                       // the slot was found
+
+    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
+    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
+
+    operator bool() const { return found; }
+};
+static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
+
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
+// returns a structure holding information about the slot found
 // Note: On success, it's important that cache.head points
 // to the first cell of the slot.
-static bool llama_kv_cache_find_slot(
+static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
        const struct llama_ubatch & batch) {
     const uint32_t n_tokens = batch.n_tokens;
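
A note on the shape of the new return type: because `llama_kv_cache_slot_info` converts to `bool`, existing `if (!llama_kv_cache_find_slot(...))` call sites keep their old form, while callers can now also read back the half-open `[begin, end)` cell range. A self-contained sketch of the idiom; `find_slot_stub` is a hypothetical stand-in, not llama.cpp code:

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>

struct slot_info {
    std::pair<uint32_t, uint32_t> boundaries; // [begin, end), half-open
    bool found = false;

    explicit slot_info(bool found_) : found{found_} {}
    slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}

    operator bool() const { return found; } // truthy exactly when a slot exists
};

// hypothetical stand-in for llama_kv_cache_find_slot
static slot_info find_slot_stub(uint32_t head, uint32_t n_tokens) {
    if (n_tokens == 0) {
        return slot_info(false); // mirrors llama_kv_cache_slot_info_failed
    }
    return slot_info(head, head + n_tokens);
}

int main() {
    const auto slot = find_slot_stub(16, 4);
    if (!slot) {                 // same shape as the old boolean check
        return 1;
    }
    std::printf("slot cells: [%u, %u)\n",
        (unsigned) slot.boundaries.first, (unsigned) slot.boundaries.second);
    return 0;
}
```

A side effect worth noting: the single-argument constructor is `explicit`, so any leftover `return false;` inside `llama_kv_cache_find_slot` no longer compiles, which forces every return path to be converted deliberately.
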
@@ -3534,7 +3547,7 @@ static bool llama_kv_cache_find_slot(
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
                 LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
-                return false;
+                return llama_kv_cache_slot_info_failed;
             }
             if (j > 0) {
                 llama_kv_cell & seq = cache.cells[seq_id];
@@ -3669,15 +3682,17 @@ static bool llama_kv_cache_find_slot(
         // allow getting the range of used cells, from head to head + n
         cache.head = min;
         cache.n    = max - min + 1;
+        cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
+            [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
         // sanity check
-        return cache.n >= n_seqs;
+        return llama_kv_cache_slot_info(cache.n >= n_seqs);
     }
     // otherwise, one cell per token.
 
     if (n_tokens > cache.size) {
         LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
-        return false;
+        return llama_kv_cache_slot_info_failed;
     }
 
     uint32_t n_tested = 0;
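
Why the recount: on the recurrent path, slots are assigned by shuffling cells around rather than by appending, so the hunk recomputes `cache.used` from the cell array instead of adjusting it incrementally. A boiled-down illustration of the same `std::count_if` pattern; `kv_cell` here is a trimmed stand-in, since the real `llama_kv_cell::is_empty()` lives in llama.cpp internals:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct kv_cell {
    int32_t pos = -1;
    bool is_empty() const { return pos == -1; } // stand-in emptiness test
};

// recount how many cells are occupied after slot assignment moved them
static uint32_t recount_used(const std::vector<kv_cell> & cells) {
    return (uint32_t) std::count_if(cells.begin(), cells.end(),
        [](const kv_cell & cell) { return !cell.is_empty(); });
}

int main() {
    std::vector<kv_cell> cells(8);
    cells[2].pos = 5;
    cells[3].pos = 6;
    return recount_used(cells) == 2 ? 0 : 1; // two occupied cells
}
```
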
@@ -3705,7 +3720,7 @@ static bool llama_kv_cache_find_slot(
 
         if (n_tested >= cache.size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return false;
+            return llama_kv_cache_slot_info_failed;
         }
     }
 
@@ -3722,7 +3737,7 @@ static bool llama_kv_cache_find_slot(
 
     cache.used += n_tokens;
 
-    return true;
+    return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
 }
 
 // find how many cells are currently in use
@@ -3998,6 +4013,53 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
     return cparams.flash_attn ? 256u : 32u;
 }
 
+// saves the kv_cache state for future recovery.
+// used to rollback llama_kv_cache_find_slot changes.
+struct llama_kv_slot_restorer {
+    struct llama_kv_cache_state {
+        uint32_t head = 0;
+        uint32_t n    = 0;
+    } old_state;
+
+    // for non-recurrent models only
+    // list of slots to restore
+    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
+
+    bool do_restore = false;
+
+    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
+        old_state.head = cache.head;
+        old_state.n    = cache.n;
+    }
+
+    // saves slot information for future restoration
+    void save(const struct llama_kv_cache_slot_info & slot) {
+        if (slot) {
+            do_restore = true;
+            if (slot.boundaries.first != slot.boundaries.second) {
+                slot_boundaries.push_back(slot.boundaries);
+            }
+        }
+    }
+
+    // must be explicitly called to restore the kv_cache state
+    // and rollback changes from all llama_kv_cache_find_slot calls
+    void restore(struct llama_kv_cache & cache) {
+        if (do_restore) {
+            cache.head = old_state.head;
+            cache.n    = old_state.n;
+
+            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
+                llama_kv_cache_seq_rm(cache, -1, -1, -1);
+            } else {
+                for (auto & slot : slot_boundaries) {
+                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
+                }
+            }
+        }
+    }
+};
+
 //
 // model loading and saving
 //
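
How the pieces fit together (the decode changes further down use exactly this protocol): snapshot `head`/`n` at construction, `save()` each slot that `llama_kv_cache_find_slot` hands out, and call `restore()` only on the failure path. A self-contained toy of that protocol; the `toy_*` types are stand-ins and the failing compute is simulated:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

struct toy_cache { uint32_t head = 0; uint32_t n = 0; };

struct toy_restorer {
    uint32_t old_head, old_n;
    std::vector<std::pair<uint32_t, uint32_t>> slots;
    bool do_restore = false;

    explicit toy_restorer(const toy_cache & c) : old_head(c.head), old_n(c.n) {}

    void save(std::pair<uint32_t, uint32_t> boundaries) {
        do_restore = true;
        slots.push_back(boundaries); // the real code also skips empty ranges
    }

    void restore(toy_cache & c) {
        if (!do_restore) return;
        c.head = old_head;
        c.n    = old_n;
        // the real restorer additionally erases each saved [begin, end)
        // range via llama_kv_cache_seq_rm(cache, -1, begin, end)
    }
};

int main() {
    toy_cache cache{32, 4};
    toy_restorer restorer(cache);       // snapshot head/n up front

    cache.head = 40; cache.n = 8;       // simulate llama_kv_cache_find_slot
    restorer.save({40, 48});

    const bool compute_failed = true;   // simulate a failed graph compute
    if (compute_failed) {
        restorer.restore(cache);        // rollback; success path never calls this
    }
    return cache.head == 32 ? 0 : 1;
}
```

Note that rollback is an explicit call rather than a destructor: since `do_restore` is armed by every successful `save()`, a RAII-style rollback would also fire on the success path.
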
@@ -17181,7 +17243,8 @@ static void llama_output_reorder(struct llama_context * ctx) {
     }
 }
 
-static void llama_graph_compute(
+// returns the result of ggml_backend_sched_graph_compute_async execution
+static enum ggml_status llama_graph_compute(
           llama_context & lctx,
             ggml_cgraph * gf,
                     int   n_threads,
@@ -17196,15 +17259,20 @@ static void llama_graph_compute(
         set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
     }
 
-    auto err = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
-    if (err != GGML_STATUS_SUCCESS) {
-        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
+    auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
     }
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
+
+    return status;
 }
 
 // decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
 //
 //   - lctx:      llama context
 //   - batch:     batch to evaluate
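
`llama_graph_compute` now both logs the failure and returns the `ggml_status`, so each caller can apply its own error policy. The two switches that `llama_decode_internal` and `llama_encode_internal` gain below apply the same mapping; written once as a standalone helper it looks like this (the helper itself is illustrative and not part of the patch, and the enum paraphrases `ggml_status` from ggml.h):

```cpp
// stand-in for ggml_status, mirroring the values declared in ggml.h
enum ggml_status_like {
    STATUS_ALLOC_FAILED = -2,
    STATUS_FAILED       = -1,
    STATUS_SUCCESS      =  0,
    STATUS_ABORTED      =  1,
};

static int status_to_llama_ret(ggml_status_like status) {
    switch (status) {
        case STATUS_SUCCESS:      return 0;  // decoded/encoded normally
        case STATUS_ABORTED:      return 2;  // abort callback stopped the graph
        case STATUS_ALLOC_FAILED: return -2; // compute buffer allocation failed
        case STATUS_FAILED:
        default:                  return -3; // generic backend failure
    }
}

int main() {
    return status_to_llama_ret(STATUS_SUCCESS);
}
```
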
@@ -17254,6 +17322,7 @@ static int llama_decode_internal(
     lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
@@ -17338,9 +17407,11 @@ static int llama_decode_internal(
                 kv_self.head = 0;
             }
 
-            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
+            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+            if (!slot) {
                 return 1;
             }
+            kv_slot_restorer.save(slot);
 
             if (!kv_self.recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17387,7 +17458,19 @@ static int llama_decode_internal(
 
         llama_set_inputs(lctx, ubatch);
 
-        llama_graph_compute(lctx, gf, n_threads, threadpool);
+        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            kv_slot_restorer.restore(kv_self);
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
+        }
 
         // update the kv ring buffer
         {
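
The `GGML_STATUS_ABORTED` branch is what makes the restorer earn its keep: an abort callback can now stop a long decode mid-graph and, with the cache rolled back, the same batch can simply be resubmitted. A hedged sketch of such a callback; `llama_set_abort_callback` and the `bool (*)(void *)` callback shape are my reading of llama.h/ggml.h, and the deadline policy is invented for illustration:

```cpp
#include <chrono>

// data passed to the callback (hypothetical application type)
struct deadline {
    std::chrono::steady_clock::time_point t;
};

// returning true asks ggml to abort the graph -> GGML_STATUS_ABORTED
static bool abort_after_deadline(void * data) {
    const auto * d = (const deadline *) data;
    return std::chrono::steady_clock::now() > d->t;
}

int main() {
    deadline d{std::chrono::steady_clock::now() + std::chrono::seconds(5)};
    // In an application (ctx being a llama_context *):
    //   llama_set_abort_callback(ctx, abort_after_deadline, &d);
    //   if (llama_decode(ctx, batch) == 2) { /* aborted; cache was rolled back */ }
    return abort_after_deadline(&d) ? 1 : 0; // false until the deadline passes
}
```
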
@@ -17624,7 +17707,18 @@ static int llama_encode_internal(
 
         llama_set_inputs(lctx, ubatch);
 
-        llama_graph_compute(lctx, gf, n_threads, threadpool);
+        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+        switch (compute_status) {
+            case GGML_STATUS_SUCCESS:
+                break;
+            case GGML_STATUS_ABORTED:
+                return 2;
+            case GGML_STATUS_ALLOC_FAILED:
+                return -2;
+            case GGML_STATUS_FAILED:
+            default:
+                return -3;
+        }
 
         // extract embeddings
         if (embd) {
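
The encode path applies the same status mapping but needs no restorer, since it never allocates KV slots. From the caller's side, the full set of return codes after this patch looks as follows; the switches above are the source for the codes, while the retry advice is only an example policy:

```cpp
#include <cstdio>

// interprets the return value of llama_decode()/llama_encode() after this patch
static bool handle_result(int ret) {
    switch (ret) {
        case 0:  return true;                        // success
        case 1:  std::puts("no KV slot; shrink the batch or free cache space");
                 return false;                       // decode-only condition
        case 2:  std::puts("aborted by abort callback; KV cache rolled back");
                 return false;                       // safe to retry the batch
        case -2: std::puts("compute buffer allocation failed");
                 return false;
        default: std::puts("graph compute failed");  // -3 and anything else
                 return false;
    }
}

int main() {
    return handle_result(0) ? 0 : 1;
}
```
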