
server : fix crash when system prompt is bigger than batch size (#5714)

The system prompt is now decoded in batches.
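
Before this change, all system prompt tokens went to llama_decode() in one llama_batch, so a prompt longer than params.n_batch made the decode fail. A minimal sketch of the chunked decode follows, assuming the llama.cpp C API of this period (llama_batch laid out as token/embd/pos/n_seq_id/seq_id/logits plus three legacy fields); the helper name decode_in_chunks is hypothetical.

#include <algorithm>
#include <cstdio>

#include "llama.h"

// Decode `batch` in slices of at most n_batch tokens, so a long prompt no
// longer goes through llama_decode in a single oversized call.
static bool decode_in_chunks(llama_context * ctx, const llama_batch & batch, int32_t n_batch) {
    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr,          // no embeddings, tokens are used instead
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
            0, 0, 0,          // unused legacy fields
        };
        if (llama_decode(ctx, batch_view) != 0) {
            fprintf(stderr, "llama_decode() failed at chunk offset %d\n", i);
            return false;
        }
    }
    return true;
}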

* server : fix off-by-one n_past when start of prompt matches whole cache

The tokens right after the matching part would otherwise skip a pos value.
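
A standalone sketch of the off-by-one, with hypothetical token values; common_part here mirrors the server.cpp helper that returns the length of the shared prefix between the cached tokens and the new prompt.

#include <cstdint>
#include <vector>

static size_t common_part(const std::vector<int32_t> & a, const std::vector<int32_t> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) { i++; }
    return i;
}

int main() {
    // The last cached token was sampled and pushed into cache_tokens, but it
    // has not been decoded yet: only positions 0..3 exist in the KV cache.
    std::vector<int32_t> cache_tokens  = {11, 22, 33, 44, 55};
    std::vector<int32_t> prompt_tokens = {11, 22, 33, 44, 55, 66};

    int32_t n_past = (int32_t) common_part(cache_tokens, prompt_tokens); // 5

    // Without the fix, token 66 would be decoded at pos 5 while pos 4 was
    // never written, leaving a hole in the sequence. Re-decoding the last
    // matching token closes it:
    if (n_past > 0 && n_past == (int32_t) cache_tokens.size()) {
        n_past -= 1; // now 4: token 55 is decoded again at pos 4
    }

    return n_past == 4 ? 0 : 1; // exits 0 when the adjustment applied
}

With the adjustment, the next llama_decode call writes token 55 at pos 4 and token 66 at pos 5, keeping the positions contiguous.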
compilade 1 year ago
Commit f7625019c5
1 changed file with 25 additions and 3 deletions

+ 25 - 3
examples/server/server.cpp

@@ -902,10 +902,24 @@ struct llama_server_context
                 llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
             }
 
-            if (llama_decode(ctx, batch) != 0)
+            for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
             {
-                LOG_TEE("%s: llama_decode() failed\n", __func__);
-                return;
+                const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
+                llama_batch batch_view = {
+                    n_tokens,
+                    batch.token    + i,
+                    nullptr,
+                    batch.pos      + i,
+                    batch.n_seq_id + i,
+                    batch.seq_id   + i,
+                    batch.logits   + i,
+                    0, 0, 0, // unused
+                };
+                if (llama_decode(ctx, batch_view) != 0)
+                {
+                    LOG_TEE("%s: llama_decode() failed\n", __func__);
+                    return;
+                }
             }
 
             // assign the system KV cache to all parallel sequences
@@ -1785,6 +1799,14 @@ struct llama_server_context
                         }
 
                         slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+
+                        // the last token of the cache is not in the KV cache until the next call to llama_decode
+                        // (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
+                        if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size())
+                        {
+                            slot.n_past -= 1;
+                        }
+
                         slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
 
                         if (slot.ga_n != 1)