Răsfoiți Sursa

server : various fixes for the prompt field in /completion (#5300)

server : fix deadlock when prompt array contains strings and numbers

server : removed an unnecessary generation when generating multi-prompts

server : removed an unnecessary assert
Niall Coates 1 an în urmă
părinte
comite
4ffc7a17d4
1 a modificat fișierele cu 27 adăugiri și 7 ștergeri
  1. 27 7
      examples/server/server.cpp

+ 27 - 7
examples/server/server.cpp

@@ -1163,13 +1163,30 @@ struct llama_server_context
         task.multitask_id = multitask_id;
 
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
-        if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
-        {
-            split_multiprompt_task(task_id, task);
-        }
-
         // otherwise, it's a single-prompt task, we actually queue it
-        queue_tasks.post(task);
+        // if there's numbers in the prompt array it will be treated as an array of tokens
+        if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
+            bool numbers = false;
+            for (const auto& e : task.data.at("prompt")) {
+                if (e.is_number()) {
+                    numbers = true;
+                    break;
+                }
+            }
+
+            // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
+            // it will completely stall the server. I don't know where the bug for this is.
+            //
+            // if there are numbers, it needs to be treated like a single prompt,
+            // queue_tasks handles a mix of strings and numbers just fine.
+            if (numbers) {
+                queue_tasks.post(task);
+            } else {
+                split_multiprompt_task(task_id, task);
+            }
+        } else {
+            queue_tasks.post(task);
+        }
     }
 
     // for multiple images processing
@@ -1251,7 +1268,10 @@ struct llama_server_context
     void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
     {
         int prompt_count = multiprompt_task.data.at("prompt").size();
-        assert(prompt_count > 1);
+        if (prompt_count <= 1) {
+            send_error(multiprompt_task, "error while handling multiple prompts");
+            return;
+        }
 
         // generate all the ID for subtask
         std::vector<int> subtask_ids(prompt_count);