|
|
@@ -424,28 +424,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
|
|
|
return kv_self;
|
|
|
}
|
|
|
|
|
|
-void llama_context::kv_self_update() {
|
|
|
+bool llama_context::kv_self_update() {
|
|
|
if (!memory) {
|
|
|
- return;
|
|
|
+ return false;
|
|
|
}
|
|
|
|
|
|
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
|
|
|
|
- if (kv_self->update(*this)) {
|
|
|
- // if the KV cache did any computation, we have to reserve a new worst-case graph
|
|
|
- const auto kv_state = kv_self->init_full();
|
|
|
- if (!kv_state) {
|
|
|
- throw std::runtime_error("failed to initialize KV cache");
|
|
|
- }
|
|
|
+ if (!kv_self->update(*this)) {
|
|
|
+ // no updates have been performed
|
|
|
+ return false;
|
|
|
+ }
|
|
|
|
|
|
- const uint32_t n_seqs = cparams.n_seq_max;
|
|
|
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
|
+ // if the KV cache did any computation, we have to reserve a new worst-case graph
|
|
|
+ const auto kv_state = kv_self->init_full();
|
|
|
+ if (!kv_state) {
|
|
|
+ throw std::runtime_error("failed to initialize KV cache");
|
|
|
+ }
|
|
|
|
|
|
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
|
|
- if (!gf) {
|
|
|
- LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
|
|
|
- }
|
|
|
+ const uint32_t n_seqs = cparams.n_seq_max;
|
|
|
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
|
+
|
|
|
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
|
|
+ if (!gf) {
|
|
|
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
|
|
|
}
|
|
|
+
|
|
|
+ return true;
|
|
|
}
|
|
|
|
|
|
enum llama_pooling_type llama_context::pooling_type() const {
|
|
|
@@ -933,24 +938,44 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
// handle any pending defrags/shifts
|
|
|
kv_self_update();
|
|
|
|
|
|
- auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
|
|
- if (!kv_state) {
|
|
|
- return -2;
|
|
|
- }
|
|
|
+ llama_memory_state_ptr kv_state;
|
|
|
|
|
|
- switch (kv_state->get_status()) {
|
|
|
- case LLAMA_MEMORY_STATUS_SUCCESS:
|
|
|
- {
|
|
|
- } break;
|
|
|
- case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
|
|
- {
|
|
|
- // not a fatal error, we can re-try with a different batch
|
|
|
- return 1;
|
|
|
- }
|
|
|
- case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
|
|
- {
|
|
|
- return -2;
|
|
|
- }
|
|
|
+ bool did_defrag = false;
|
|
|
+
|
|
|
+ while (true) {
|
|
|
+ kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
|
|
+ if (!kv_state) {
|
|
|
+ return -2;
|
|
|
+ }
|
|
|
+
|
|
|
+ switch (kv_state->get_status()) {
|
|
|
+ case LLAMA_MEMORY_STATUS_SUCCESS:
|
|
|
+ {
|
|
|
+ } break;
|
|
|
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
|
|
+ {
|
|
|
+ if (!did_defrag) {
|
|
|
+ did_defrag = true;
|
|
|
+
|
|
|
+ kv_self->defrag_sched(-1.0f);
|
|
|
+ if (kv_self_update()) {
|
|
|
+ LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
|
|
|
+
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
|
|
|
+
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
|
|
+ {
|
|
|
+ return -2;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ break;
|
|
|
}
|
|
|
|
|
|
// reserve output buffer
|
|
|
@@ -2646,22 +2671,8 @@ int32_t llama_encode(
|
|
|
int32_t llama_decode(
|
|
|
llama_context * ctx,
|
|
|
llama_batch batch) {
|
|
|
- int ret = ctx->decode(batch);
|
|
|
-
|
|
|
- // defrag and try again
|
|
|
- // TODO: distinguish return code when we are sure that even after defrag there is no space available
|
|
|
- if (ret == 1) {
|
|
|
- llama_kv_self_defrag(ctx);
|
|
|
- ret = ctx->decode(batch);
|
|
|
-
|
|
|
- if (ret == 1) {
|
|
|
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
|
|
|
-
|
|
|
- return ret;
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if (ret != 0) {
|
|
|
+ const int ret = ctx->decode(batch);
|
|
|
+ if (ret != 0 && ret != 1) {
|
|
|
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
|
|
}
|
|
|
|