1
0
Эх сурвалжийг харах

server: improve speed of speculative decoding (#17808)

* server: improve speed of speculative decoding

* fix small draft case

* add link to the PR

* server : fix generation time measurement

* server : fix draft acceptance logs (add SRV_CNT, SLT_CNT macros)

* server : add comment

* add PR to docs

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Xuan-Son Nguyen 1 сар өмнө
parent
commit
f896d2c34f

+ 1 - 0
tools/server/README-dev.md

@@ -81,6 +81,7 @@ For detailed instructions, see the [test documentation](./tests/README.md).
 - Separation of HTTP logic into dedicated files: https://github.com/ggml-org/llama.cpp/pull/17216
 - Large-scale code base split into smaller files: https://github.com/ggml-org/llama.cpp/pull/17362
 - Introduction of router mode: https://github.com/ggml-org/llama.cpp/pull/17470
+- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808
 
 
 

+ 2 - 0
tools/server/server-common.h

@@ -18,11 +18,13 @@ const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "
 using json = nlohmann::ordered_json;
 
 #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_CNT(slot, fmt, ...) LOG_CNT(""                                 fmt,                                                                __VA_ARGS__)
 #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 #define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 #define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 
 #define SRV_INF(fmt, ...) LOG_INF("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_CNT(fmt, ...) LOG_CNT(""              fmt,               __VA_ARGS__)
 #define SRV_WRN(fmt, ...) LOG_WRN("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 #define SRV_ERR(fmt, ...) LOG_ERR("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 #define SRV_DBG(fmt, ...) LOG_DBG("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)

+ 105 - 76
tools/server/server-context.cpp

@@ -102,6 +102,11 @@ struct server_slot {
     std::string  generated_text;
     llama_tokens generated_tokens;
 
+    // idx of draft tokens in the main batch
+    // non-empty if we went to evaluate draft tokens
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17808
+    std::vector<int32_t> i_batch_dft;
+
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -150,7 +155,8 @@ struct server_slot {
 
     struct common_sampler * smpl = nullptr;
 
-    llama_token sampled;
+    llama_token sampled; // in speculative mode, this is the last accepted token
+    llama_tokens drafted;
 
     // stats
     size_t n_sent_text = 0; // number of sent text character
@@ -180,6 +186,8 @@ struct server_slot {
         stopping_word  = "";
         n_sent_text    = 0;
 
+        drafted.clear();
+        i_batch_dft.clear();
         generated_tokens.clear();
         generated_token_probs.clear();
         json_schema = json();
@@ -255,6 +263,31 @@ struct server_slot {
         generated_token_probs.push_back(token);
     }
 
+    int get_n_draft_max() const {
+        if (!can_speculate()) {
+            return 0;
+        }
+
+        // determine the max draft that fits the current slot state
+        int n_draft_max = task->params.speculative.n_max;
+
+        // note: slot.prompt is not yet expanded with the `id` token sampled above
+        //       also, need to leave space for 1 extra token to allow context shifts
+        n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
+
+        if (n_remaining > 0) {
+            n_draft_max = std::min(n_draft_max, n_remaining - 1);
+        }
+
+        SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
+
+        if (n_draft_max < task->params.speculative.n_min) {
+            SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
+            n_draft_max = 0;
+        }
+        return n_draft_max;
+    }
+
     // note: a slot can also be either a parent or a child
     bool is_parent() const {
         return is_processing() && task->n_children > 0;
@@ -353,8 +386,7 @@ struct server_slot {
 
         if (n_draft_total > 0) {
             const float draft_ratio = (float) n_draft_accepted / n_draft_total;
-            SLT_INF(*this,
-                    "\n"
+            SLT_CNT(*this,
                     "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
                     draft_ratio, n_draft_accepted, n_draft_total
             );
@@ -1774,14 +1806,57 @@ struct server_context_impl {
                 continue;
             }
 
-            slot.i_batch = batch.n_tokens;
+            // generate draft tokens in speculative decoding mode
+            // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
+            //       perform the speculative drafting for all sequences at the same time in a single batch
+            int n_draft_max = slot.get_n_draft_max();
+            if (n_draft_max > 0) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
 
-            common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
+                struct common_speculative_params params_spec;
+                params_spec.n_draft = n_draft_max;
+                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
+                params_spec.p_min   = slot.task->params.speculative.p_min;
+                const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
+                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
+
+                // add the sampled token to the batch
+                slot.i_batch_dft.push_back(batch.n_tokens);
+                common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
+                slot.prompt.tokens.push_back(slot.sampled);
+
+                if (slot.task->params.speculative.n_min > (int) draft.size()) {
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
+                    // fallback to normal decoding
+                    slot.i_batch = slot.i_batch_dft[0];
+                    slot.drafted.clear();
+                    slot.i_batch_dft.clear();
+                } else {
+                    // keep track of total number of drafted tokens tested
+                    slot.n_draft_total += draft.size();
+
+                    // add all drafted tokens to the batch
+                    for (size_t i = 0; i < draft.size(); i++) {
+                        slot.i_batch_dft.push_back(batch.n_tokens);
+                        common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
+                        slot.prompt.tokens.push_back(draft[i]);
+                    }
+                    slot.drafted = std::move(draft);
+                }
+            } else {
+                // no speculative decoding
+                slot.i_batch = batch.n_tokens;
 
-            slot.prompt.tokens.push_back(slot.sampled);
+                common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
 
-            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
-                    slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
+                slot.prompt.tokens.push_back(slot.sampled);
+
+                SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
+                        slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
+            }
         }
 
         // process in chunks of params.n_batch
@@ -2345,6 +2420,10 @@ struct server_context_impl {
             // on successful decode, restore the original batch size
             n_batch = llama_n_batch(ctx);
 
+            // technically, measuring the time here excludes the sampling time for the last batch
+            // but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
+            const int64_t t_current = ggml_time_us();
+
             for (auto & slot : slots) {
                 // may need to copy state to other slots
                 if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
@@ -2399,6 +2478,10 @@ struct server_context_impl {
                     continue; // continue loop of slots
                 }
 
+                if (slot.i_batch_dft.size() > 0) {
+                    continue; // sample using speculative decoding
+                }
+
                 const int tok_idx = slot.i_batch - i;
 
                 llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
@@ -2409,8 +2492,6 @@ struct server_context_impl {
 
                 slot.n_decoded += 1;
 
-                const int64_t t_current = ggml_time_us();
-
                 if (slot.n_decoded == 1) {
                     slot.t_start_generation = t_current;
                     slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
@@ -2439,84 +2520,32 @@ struct server_context_impl {
                 }
             }
 
-            // do speculative decoding
-            // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
-            //       perform the speculative drafting for all sequences at the same time in a single batch
+            // speculative decoding - main model sample and accept
             for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
+                if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
                     continue;
                 }
 
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.task->params.speculative.n_max;
-
-                // note: slot.prompt is not yet expanded with the `id` token sampled above
-                //       also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.task->params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
-                params_spec.p_min   = slot.task->params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.task->params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add  (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
+                size_t n_draft = slot.drafted.size();
 
                 // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
+                slot.i_batch_dft.clear();
+                slot.drafted.clear();
 
                 slot.n_decoded += ids.size();
 
+                slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
+
                 // update how many tokens out of those tested were accepted
                 slot.n_draft_accepted += ids.size() - 1;
 
-                slot.prompt.tokens.push_back(id);
+                // rollback to the state before sampling the draft tokens
+                slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
+
+                // add accepted tokens to the prompt
                 slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
+                slot.sampled = ids.back(); // last accepted token
 
                 llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
 
@@ -2539,7 +2568,7 @@ struct server_context_impl {
                     }
                 }
 
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
+                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens());
             }
         }