Browse Source

server: fix memory reservations in populate_token_probs (#18787)

Lennart Austenfeld 1 week ago
parent
commit
18361c579c
1 changed files with 8 additions and 5 deletions
  1. 8 5
      tools/server/server-context.cpp

+ 8 - 5
tools/server/server-context.cpp

@@ -1326,11 +1326,12 @@ private:
     }
 
     void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
-        const size_t n_probs = slot.task->params.sampling.n_probs;
+        const size_t n_probs_request = slot.task->params.sampling.n_probs;
 
         if (post_sampling) {
             const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
             const size_t max_probs = cur_p->size;
+            const size_t n_probs = std::min(max_probs, n_probs_request);
 
             // set probability for sampled token
             for (size_t i = 0; i < max_probs; i++) {
@@ -1341,8 +1342,8 @@ private:
             }
 
             // set probability for top n_probs tokens
-            result.probs.reserve(max_probs);
-            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
+            result.probs.reserve(n_probs);
+            for (size_t i = 0; i < n_probs; i++) {
                 result.probs.push_back({
                     cur_p->data[i].id,
                     common_token_to_piece(ctx, cur_p->data[i].id, special),
@@ -1352,9 +1353,11 @@ private:
         } else {
             // TODO: optimize this with min-p optimization
             std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+            const size_t max_probs = cur.size();
+            const size_t n_probs = std::min(max_probs, n_probs_request);
 
             // set probability for sampled token
-            for (size_t i = 0; i < cur.size(); i++) {
+            for (size_t i = 0; i < max_probs; i++) {
                 // set probability for sampled token
                 if (cur[i].id == result.tok) {
                     result.prob = cur[i].p;
@@ -1364,7 +1367,7 @@ private:
 
             // set probability for top n_probs tokens
             result.probs.reserve(n_probs);
-            for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
+            for (size_t i = 0; i < n_probs; i++) {
                 result.probs.push_back({
                     cur[i].id,
                     common_token_to_piece(ctx, cur[i].id, special),