Przeglądaj źródła

server : fix assistant prefilling when content is an array (#14360)

Sigbjørn Skjæret 6 miesięcy temu
rodzic
commit
ddef99522d

+ 23 - 0
tools/server/tests/unit/test_chat_completion.py

@@ -132,6 +132,28 @@ def test_chat_template():
     assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
     assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 
 
 
 
+@pytest.mark.parametrize("prefill,re_prefill", [
+    ("Whill", "Whill"),
+    ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
+])
+def test_chat_template_assistant_prefill(prefill, re_prefill):
+    global server
+    server.chat_template = "llama3"
+    server.debug = True  # to get the "__verbose" object in the response
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 8,
+        "messages": [
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+            {"role": "assistant", "content": prefill},
+        ]
+    })
+    assert res.status_code == 200
+    assert "__verbose" in res.body
+    assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
+
+
 def test_apply_chat_template():
 def test_apply_chat_template():
     global server
     global server
     server.chat_template = "command-r"
     server.chat_template = "command-r"
@@ -228,6 +250,7 @@ def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re
     [{"role": "system", "content": 123}],
     [{"role": "system", "content": 123}],
     # [{"content": "hello"}], # TODO: should not be a valid case
     # [{"content": "hello"}], # TODO: should not be a valid case
     [{"role": "system", "content": "test"}, {}],
     [{"role": "system", "content": "test"}, {}],
+    [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
 ])
 ])
 def test_invalid_chat_completion_req(messages):
 def test_invalid_chat_completion_req(messages):
     global server
     global server

+ 7 - 1
tools/server/utils.hpp

@@ -792,7 +792,13 @@ static json oaicompat_chat_params_parse(
 
 
     /* Append assistant prefilled message */
     /* Append assistant prefilled message */
     if (prefill_assistant_message) {
     if (prefill_assistant_message) {
-         chat_params.prompt += last_message.content;
+        if (!last_message.content_parts.empty()) {
+            for (auto & p : last_message.content_parts) {
+                chat_params.prompt += p.text;
+            }
+        } else {
+            chat_params.prompt += last_message.content;
+        }
     }
     }
 
 
     llama_params["chat_format"]      = static_cast<int>(chat_params.format);
     llama_params["chat_format"]      = static_cast<int>(chat_params.format);