@@ -764,7 +764,7 @@ struct completion_token_output {
}
};
-struct swa_checkpoint {
+struct ctx_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
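(note: the complete record used by the later hunks is sketched below; the data member is inferred from the it->data.size() and std::vector<uint8_t>(checkpoint_size) usages further down, so this is a sketch rather than the verbatim definition)

// sketch: a position range plus an opaque snapshot of the state that cannot be rolled back
struct ctx_checkpoint {
    llama_pos pos_min;          // lowest position covered by the snapshot
    llama_pos pos_max;          // highest position covered by the snapshot
    std::vector<uint8_t> data;  // serialized state, filled by llama_state_seq_get_data_ext()
};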
@@ -1460,7 +1460,7 @@ struct server_slot {
std::vector<completion_token_output> generated_token_probs;
- std::vector<swa_checkpoint> swa_checkpoints;
+ std::vector<ctx_checkpoint> ctx_checkpoints;
bool has_next_token = true;
bool has_new_line = false;
@@ -3541,7 +3541,11 @@ struct server_context {
slot.n_past = 0;
}
- const auto n_swa = llama_model_n_swa(model);
+ // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+ const auto n_swa = std::max(1, llama_model_n_swa(model));
+
+ // the largest pos_min required for a checkpoint to be useful
+ const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3550,66 +3554,62 @@ struct server_context {
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
}
- const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
- // search for a SWA checkpoint
+ // search for a context checkpoint
const auto it = std::find_if(
- slot.swa_checkpoints.rbegin(),
- slot.swa_checkpoints.rend(),
+ slot.ctx_checkpoints.rbegin(),
+ slot.ctx_checkpoints.rend(),
[&](const auto & cur) {
- return cur.pos_min <= pos_min_thold;
+ // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+ return cur.pos_min < pos_min_thold;
}
);
- bool do_reset = it == slot.swa_checkpoints.rend();
+ bool do_reset = it == slot.ctx_checkpoints.rend();
+ //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");
if (!do_reset) {
- // restore the checkpoint
- const size_t swa_size = it->data.size();
- const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+ // restore the context checkpoint
+ const size_t ctx_checkpoint_size = it->data.size();
+ const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- if (n != swa_size) {
- SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+ if (n != ctx_checkpoint_size) {
+ SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
do_reset = true;
+ //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
- slot.n_past = std::min(slot.n_past, it->pos_max);
-
- SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+ slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+ SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
}
}
if (do_reset) {
- SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+ SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
slot.n_past = 0;
- slot.swa_checkpoints.clear();
}
}
}
- if (n_swa > 0) {
- const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
+ {
// erase any checkpoints with pos_min > pos_min_thold
- for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
- const auto & cur = slot.swa_checkpoints[i];
+ for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+ const auto & cur = slot.ctx_checkpoints[i];
if (cur.pos_min > pos_min_thold) {
- slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
- SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+ SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+ slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
}
}
}
}
+ // [TAG_PROMPT_LOGITS]
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
- SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+ SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
slot.n_past--;
+ SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
}
slot.n_prompt_tokens_cache = slot.n_past;
@@ -3623,9 +3623,9 @@ struct server_context {
}
}
- // keep only the common part
+ // truncate any tokens that are beyond n_past for this slot
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
- // could not partially delete (likely using a non-Transformer model)
+ SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
// there is no common part left
@@ -3633,7 +3633,7 @@ struct server_context {
slot.n_prompt_tokens_cache = 0;
}
- SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+ SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
// remove the non-common part from the cache
slot.cache_tokens.keep_first(slot.n_past);
@@ -3854,37 +3854,38 @@ struct server_context {
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
- // make a checkpoint with the SWA memory
- // checkpoints are needed only if we are not using "--swa-full"
- if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
- if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
- {
- const auto & cur = slot.swa_checkpoints.back();
-
- SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
- cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
- }
-
- slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+ // make a checkpoint of the parts of the memory that cannot be rolled back.
+ // checkpoints are created only if:
+ // - the model uses SWA and we are not using `swa_full`
+ // - the model architecture is marked as recurrent or hybrid
+ //
+ // TODO: try to make this conditional on the context or the memory module, instead of the model type
+ const bool do_checkpoint =
+ (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+ (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+
+ if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+ while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+ // make room for the new checkpoint, if needed
+ const auto & cur = slot.ctx_checkpoints.front();
+ SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+ cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+ slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
}
- const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+ const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+ auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
/*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
/*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
- /*.data = */ std::vector<uint8_t>(swa_size),
+ /*.data = */ std::vector<uint8_t>(checkpoint_size),
});
- llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
- float size_total = 0.0f;
- for (const auto & checkpoint : slot.swa_checkpoints) {
- size_total += (float) checkpoint.data.size() / 1024 / 1024;
- }
+ llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
- cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+ SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+ (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
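(note: below is a condensed, self-contained sketch of the save/restore round-trip that the hunks above implement; it uses only the calls visible in the diff (llama_state_seq_get_size_ext / _get_data_ext / _set_data_ext with LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY), and the helper names are illustrative, not part of the patch)

#include "llama.h"

#include <cstdint>
#include <vector>

struct ctx_checkpoint {
    llama_pos pos_min;
    llama_pos pos_max;
    std::vector<uint8_t> data;
};

// snapshot the non-rollbackable part of a sequence's state (SWA / recurrent memory)
static ctx_checkpoint checkpoint_save(llama_context * ctx, llama_seq_id seq_id) {
    ctx_checkpoint cp;

    cp.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), seq_id);
    cp.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id);

    cp.data.resize(llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY));
    llama_state_seq_get_data_ext(ctx, cp.data.data(), cp.data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    return cp;
}

// restore a previously saved snapshot; returns false if the context did not accept all of it
static bool checkpoint_restore(llama_context * ctx, llama_seq_id seq_id, const ctx_checkpoint & cp) {
    const size_t n = llama_state_seq_set_data_ext(ctx, cp.data.data(), cp.data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    return n == cp.data.size();
}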