|
|
@@ -102,6 +102,11 @@ struct server_slot {
|
|
|
std::string generated_text;
|
|
|
llama_tokens generated_tokens;
|
|
|
|
|
|
+ // idx of draft tokens in the main batch
|
|
|
+ // non-empty if we went to evaluate draft tokens
|
|
|
+ // ref: https://github.com/ggml-org/llama.cpp/pull/17808
|
|
|
+ std::vector<int32_t> i_batch_dft;
|
|
|
+
|
|
|
std::vector<completion_token_output> generated_token_probs;
|
|
|
|
|
|
bool has_next_token = true;
|
|
|
@@ -150,7 +155,8 @@ struct server_slot {
|
|
|
|
|
|
struct common_sampler * smpl = nullptr;
|
|
|
|
|
|
- llama_token sampled;
|
|
|
+ llama_token sampled; // in speculative mode, this is the last accepted token
|
|
|
+ llama_tokens drafted;
|
|
|
|
|
|
// stats
|
|
|
size_t n_sent_text = 0; // number of sent text character
|
|
|
@@ -180,6 +186,8 @@ struct server_slot {
|
|
|
stopping_word = "";
|
|
|
n_sent_text = 0;
|
|
|
|
|
|
+ drafted.clear();
|
|
|
+ i_batch_dft.clear();
|
|
|
generated_tokens.clear();
|
|
|
generated_token_probs.clear();
|
|
|
json_schema = json();
|
|
|
@@ -255,6 +263,31 @@ struct server_slot {
|
|
|
generated_token_probs.push_back(token);
|
|
|
}
|
|
|
|
|
|
+ int get_n_draft_max() const {
|
|
|
+ if (!can_speculate()) {
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ // determine the max draft that fits the current slot state
|
|
|
+ int n_draft_max = task->params.speculative.n_max;
|
|
|
+
|
|
|
+ // note: slot.prompt is not yet expanded with the `id` token sampled above
|
|
|
+ // also, need to leave space for 1 extra token to allow context shifts
|
|
|
+ n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
|
|
|
+
|
|
|
+ if (n_remaining > 0) {
|
|
|
+ n_draft_max = std::min(n_draft_max, n_remaining - 1);
|
|
|
+ }
|
|
|
+
|
|
|
+ SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
|
|
|
+
|
|
|
+ if (n_draft_max < task->params.speculative.n_min) {
|
|
|
+ SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
|
|
|
+ n_draft_max = 0;
|
|
|
+ }
|
|
|
+ return n_draft_max;
|
|
|
+ }
|
|
|
+
|
|
|
// note: a slot can also be either a parent or a child
|
|
|
bool is_parent() const {
|
|
|
return is_processing() && task->n_children > 0;
|
|
|
@@ -353,8 +386,7 @@ struct server_slot {
|
|
|
|
|
|
if (n_draft_total > 0) {
|
|
|
const float draft_ratio = (float) n_draft_accepted / n_draft_total;
|
|
|
- SLT_INF(*this,
|
|
|
- "\n"
|
|
|
+ SLT_CNT(*this,
|
|
|
"draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
|
|
|
draft_ratio, n_draft_accepted, n_draft_total
|
|
|
);
|
|
|
@@ -1774,14 +1806,57 @@ struct server_context_impl {
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
- slot.i_batch = batch.n_tokens;
|
|
|
+ // generate draft tokens in speculative decoding mode
|
|
|
+ // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
|
|
+ // perform the speculative drafting for all sequences at the same time in a single batch
|
|
|
+ int n_draft_max = slot.get_n_draft_max();
|
|
|
+ if (n_draft_max > 0) {
|
|
|
+ if (mctx) {
|
|
|
+ // we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
|
|
+ GGML_ABORT("not supported by multimodal");
|
|
|
+ }
|
|
|
|
|
|
- common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
|
|
+ struct common_speculative_params params_spec;
|
|
|
+ params_spec.n_draft = n_draft_max;
|
|
|
+ params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
|
|
|
+ params_spec.p_min = slot.task->params.speculative.p_min;
|
|
|
+ const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
|
|
|
+ llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
|
|
|
+
|
|
|
+ // add the sampled token to the batch
|
|
|
+ slot.i_batch_dft.push_back(batch.n_tokens);
|
|
|
+ common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
|
|
+ slot.prompt.tokens.push_back(slot.sampled);
|
|
|
+
|
|
|
+ if (slot.task->params.speculative.n_min > (int) draft.size()) {
|
|
|
+ SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
|
|
|
+ // fallback to normal decoding
|
|
|
+ slot.i_batch = slot.i_batch_dft[0];
|
|
|
+ slot.drafted.clear();
|
|
|
+ slot.i_batch_dft.clear();
|
|
|
+ } else {
|
|
|
+ // keep track of total number of drafted tokens tested
|
|
|
+ slot.n_draft_total += draft.size();
|
|
|
+
|
|
|
+ // add all drafted tokens to the batch
|
|
|
+ for (size_t i = 0; i < draft.size(); i++) {
|
|
|
+ slot.i_batch_dft.push_back(batch.n_tokens);
|
|
|
+ common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
|
|
|
+ slot.prompt.tokens.push_back(draft[i]);
|
|
|
+ }
|
|
|
+ slot.drafted = std::move(draft);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // no speculative decoding
|
|
|
+ slot.i_batch = batch.n_tokens;
|
|
|
|
|
|
- slot.prompt.tokens.push_back(slot.sampled);
|
|
|
+ common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
|
|
|
|
|
- SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
|
|
|
- slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
|
|
|
+ slot.prompt.tokens.push_back(slot.sampled);
|
|
|
+
|
|
|
+ SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
|
|
|
+ slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// process in chunks of params.n_batch
|
|
|
@@ -2345,6 +2420,10 @@ struct server_context_impl {
|
|
|
// on successful decode, restore the original batch size
|
|
|
n_batch = llama_n_batch(ctx);
|
|
|
|
|
|
+ // technically, measuring the time here excludes the sampling time for the last batch
|
|
|
+ // but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
|
|
|
+ const int64_t t_current = ggml_time_us();
|
|
|
+
|
|
|
for (auto & slot : slots) {
|
|
|
// may need to copy state to other slots
|
|
|
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
|
|
|
@@ -2399,6 +2478,10 @@ struct server_context_impl {
|
|
|
continue; // continue loop of slots
|
|
|
}
|
|
|
|
|
|
+ if (slot.i_batch_dft.size() > 0) {
|
|
|
+ continue; // sample using speculative decoding
|
|
|
+ }
|
|
|
+
|
|
|
const int tok_idx = slot.i_batch - i;
|
|
|
|
|
|
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
|
|
|
@@ -2409,8 +2492,6 @@ struct server_context_impl {
|
|
|
|
|
|
slot.n_decoded += 1;
|
|
|
|
|
|
- const int64_t t_current = ggml_time_us();
|
|
|
-
|
|
|
if (slot.n_decoded == 1) {
|
|
|
slot.t_start_generation = t_current;
|
|
|
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
|
|
@@ -2439,84 +2520,32 @@ struct server_context_impl {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // do speculative decoding
|
|
|
- // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
|
|
- // perform the speculative drafting for all sequences at the same time in a single batch
|
|
|
+ // speculative decoding - main model sample and accept
|
|
|
for (auto & slot : slots) {
|
|
|
- if (!slot.is_processing() || !slot.can_speculate()) {
|
|
|
+ if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
- if (slot.state != SLOT_STATE_GENERATING) {
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- if (mctx) {
|
|
|
- // we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
|
|
- GGML_ABORT("not supported by multimodal");
|
|
|
- }
|
|
|
-
|
|
|
- // determine the max draft that fits the current slot state
|
|
|
- int n_draft_max = slot.task->params.speculative.n_max;
|
|
|
-
|
|
|
- // note: slot.prompt is not yet expanded with the `id` token sampled above
|
|
|
- // also, need to leave space for 1 extra token to allow context shifts
|
|
|
- n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
|
|
|
-
|
|
|
- if (slot.n_remaining > 0) {
|
|
|
- n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
|
|
|
- }
|
|
|
-
|
|
|
- SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
|
|
|
-
|
|
|
- if (n_draft_max < slot.task->params.speculative.n_min) {
|
|
|
- SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);
|
|
|
-
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- llama_token id = slot.sampled;
|
|
|
-
|
|
|
- struct common_speculative_params params_spec;
|
|
|
- params_spec.n_draft = n_draft_max;
|
|
|
- params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
|
|
|
- params_spec.p_min = slot.task->params.speculative.p_min;
|
|
|
-
|
|
|
- const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
|
|
|
- llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
|
|
|
-
|
|
|
- // ignore small drafts
|
|
|
- if (slot.task->params.speculative.n_min > (int) draft.size()) {
|
|
|
- SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
|
|
|
-
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- // keep track of total number of drafted tokens tested
|
|
|
- slot.n_draft_total += draft.size();
|
|
|
-
|
|
|
- // construct the speculation batch
|
|
|
- common_batch_clear(slot.batch_spec);
|
|
|
- common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
|
|
-
|
|
|
- for (size_t i = 0; i < draft.size(); ++i) {
|
|
|
- common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
|
|
|
- }
|
|
|
-
|
|
|
- SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
|
|
|
-
|
|
|
- llama_decode(ctx, slot.batch_spec);
|
|
|
+ size_t n_draft = slot.drafted.size();
|
|
|
|
|
|
// the accepted tokens from the speculation
|
|
|
- const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
|
|
|
+ const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
|
|
|
+ slot.i_batch_dft.clear();
|
|
|
+ slot.drafted.clear();
|
|
|
|
|
|
slot.n_decoded += ids.size();
|
|
|
|
|
|
+ slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
|
|
+
|
|
|
// update how many tokens out of those tested were accepted
|
|
|
slot.n_draft_accepted += ids.size() - 1;
|
|
|
|
|
|
- slot.prompt.tokens.push_back(id);
|
|
|
+ // rollback to the state before sampling the draft tokens
|
|
|
+ slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
|
|
|
+
|
|
|
+ // add accepted tokens to the prompt
|
|
|
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
|
|
|
+ slot.sampled = ids.back(); // last accepted token
|
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
|
|
|
|
|
|
@@ -2539,7 +2568,7 @@ struct server_context_impl {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
|
|
|
+ SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens());
|
|
|
}
|
|
|
}
|
|
|
|