Browse source

server : (refactoring) do not rely on JSON internally (#10643)

* server : (refactoring) reduce usage of json internally

* move all response types to struct

* wip [no ci]

* many fixes

* add virtual function

* fix index

* minor style fix

* add std::move

* refactor handle_completions_generic

* add virtual functions

* remove server.hpp

* clarify server_sent_event RFC specs

* apply review comments

* fix model_alias and completion_probabilities

* small clean up

* remove virtual for to_json_oai_compat()

* naming oai_compat --> oaicompat

* fix unwanted recursive call

* update docs
Xuan Son Nguyen 1 year ago
Parent
Commit
6c5bc0625f

+ 1 - 1
common/common.h

@@ -215,7 +215,7 @@ struct common_params {
     struct common_params_speculative speculative;
 
     std::string model                = ""; // model path                                                    // NOLINT
-    std::string model_alias          = "unknown"; // model alias                                            // NOLINT
+    std::string model_alias          = ""; // model alias                                                   // NOLINT
     std::string model_url            = ""; // model url to download                                         // NOLINT
     std::string hf_token             = ""; // HF token                                                      // NOLINT
     std::string hf_repo              = ""; // HF repo                                                       // NOLINT
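Note on this change: with the alias now defaulting to an empty string, the OAI-compatible endpoints fall back to `DEFAULT_OAICOMPAT_MODEL` when neither an alias nor a request-side `model` is given (see the `#define` in `utils.hpp` and the assertions in `test_chat_completion.py` further below). A minimal, hypothetical client-side check of that fallback, assuming a locally running `llama-server` started without `--alias` and the third-party `requests` package:

```python
# Hypothetical sketch (not part of this commit): with model_alias left empty and no
# "model" in the request, the OAI-compatible "model" field falls back to
# DEFAULT_OAICOMPAT_MODEL ("gpt-3.5-turbo-0613" in utils.hpp).
# Assumes a llama-server on localhost:8080 started without --alias.
import requests

res = requests.post("http://localhost:8080/chat/completions", json={
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 8,
})
assert "gpt-3.5" in res.json()["model"]  # mirrors the check in test_chat_completion.py
```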

+ 5 - 3
examples/server/README.md

@@ -473,9 +473,11 @@ Notice that each `probs` is an array of length `n_probs`.
 - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
 - `model`: The path to the model loaded with `-m`
 - `prompt`: The provided `prompt`
-- `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token
-- `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered
-- `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided
+- `stop_type`: Indicating whether the completion has stopped. Possible values are:
+  - `none`: Generating (not stopped)
+  - `eos`: Stopped because it encountered the EOS token
+  - `limit`: Stopped because `n_predict` tokens were generated before stop words or EOS was encountered
+  - `word`: Stopped due to encountering a stopping word from `stop` JSON array provided
 - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
 - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
 - `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
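A minimal, hypothetical client-side sketch of how the new `stop_type` field can be consumed in place of the removed booleans (the host/port and the use of the `requests` package are assumptions, not part of this change):

```python
# Hypothetical sketch (not part of this commit): read the new `stop_type` field
# from a /completion response. Assumes a llama-server on localhost:8080.
import requests

res = requests.post("http://localhost:8080/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_predict": 8,
})
body = res.json()

# `stop_type` replaces the old stopped_eos / stopped_limit / stopped_word booleans
if body["stop_type"] == "word":
    print("stopped by stopping word:", body["stopping_word"])
elif body["stop_type"] == "eos":
    print("stopped by EOS token")
elif body["stop_type"] == "limit":
    print("stopped after n_predict tokens")
else:
    print("stop_type:", body["stop_type"])  # "none" means generation has not stopped
```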

File diff suppressed because it is too large
+ 683 - 134
examples/server/server.cpp


+ 6 - 0
examples/server/tests/README.md

@@ -44,4 +44,10 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
 DEBUG=1 ./tests.sh -s -v -x
 ```
 
+Hint: You can compile and run test in single command, useful for local developement:
+
+```shell
+cmake --build build -j --target llama-server && ./examples/server/tests/tests.sh
+```
+
 To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)

+ 4 - 0
examples/server/tests/tests.sh

@@ -1,5 +1,9 @@
 #!/bin/bash
 
+# make sure we are in the right directory
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+cd $SCRIPT_DIR
+
 set -eu
 
 if [ $# -lt 1 ]

+ 14 - 19
examples/server/tests/unit/test_chat_completion.py

@@ -12,13 +12,13 @@ def create_server():
 
 
 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
-        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
     ]
 )
-def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
@@ -30,29 +30,27 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
         ],
     })
     assert res.status_code == 200
+    assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
     choice = res.body["choices"][0]
     assert "assistant" == choice["message"]["role"]
     assert match_regex(re_content, choice["message"]["content"])
-    if truncated:
-        assert choice["finish_reason"] == "length"
-    else:
-        assert choice["finish_reason"] == "stop"
+    assert choice["finish_reason"] == finish_reason
 
 
 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
-        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
+        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
     ]
 )
-def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
+    server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL
     server.start()
     res = server.make_stream_request("POST", "/chat/completions", data={
-        "model": model,
         "max_tokens": max_tokens,
         "messages": [
             {"role": "system", "content": system_prompt},
@@ -63,16 +61,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, r
     content = ""
     for data in res:
         choice = data["choices"][0]
+        assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
         if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted
             assert "content" not in choice["delta"]
             assert match_regex(re_content, content)
-            # FIXME: not sure why this is incorrect in stream mode
-            # if truncated:
-            #   assert choice["finish_reason"] == "length"
-            # else:
-            #   assert choice["finish_reason"] == "stop"
+            assert choice["finish_reason"] == finish_reason
         else:
             assert choice["finish_reason"] is None
             content += choice["delta"]["content"]
@@ -93,7 +88,7 @@ def test_chat_completion_with_openai_library():
         temperature=0.8,
     )
     print(res)
-    assert res.choices[0].finish_reason == "stop"
+    assert res.choices[0].finish_reason == "length"
     assert res.choices[0].message.content is not None
     assert match_regex("(Suddenly)+", res.choices[0].message.content)
 

+ 39 - 0
examples/server/tests/unit/test_completion.py

@@ -51,6 +51,24 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
             content += data["content"]
 
 
+def test_completion_stream_vs_non_stream():
+    global server
+    server.start()
+    res_stream = server.make_stream_request("POST", "/completion", data={
+        "n_predict": 8,
+        "prompt": "I believe the meaning of life is",
+        "stream": True,
+    })
+    res_non_stream = server.make_request("POST", "/completion", data={
+        "n_predict": 8,
+        "prompt": "I believe the meaning of life is",
+    })
+    content_stream = ""
+    for data in res_stream:
+        content_stream += data["content"]
+    assert content_stream == res_non_stream.body["content"]
+
+
 @pytest.mark.parametrize("n_slots", [1, 2])
 def test_consistent_result_same_seed(n_slots: int):
     global server
@@ -221,3 +239,24 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
         assert len(res.body["content"]) > 10
         # FIXME: the result is not deterministic when using other slot than slot 0
         # assert match_regex(re_content, res.body["content"])
+
+
+def test_n_probs():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "n_probs": 10,
+        "temperature": 0.0,
+        "n_predict": 5,
+    })
+    assert res.status_code == 200
+    assert "completion_probabilities" in res.body
+    assert len(res.body["completion_probabilities"]) == 5
+    for tok in res.body["completion_probabilities"]:
+        assert "probs" in tok
+        assert len(tok["probs"]) == 10
+        for prob in tok["probs"]:
+            assert "prob" in prob
+            assert "tok_str" in prob
+            assert 0.0 <= prob["prob"] <= 1.0

+ 2 - 247
examples/server/utils.hpp

@@ -20,6 +20,7 @@
 #include <sstream>
 #include <string>
 #include <vector>
+#include <memory>
 
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
@@ -40,17 +41,6 @@ using json = nlohmann::ordered_json;
 #define QUE_ERR(fmt, ...) LOG_ERR("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 #define QUE_DBG(fmt, ...) LOG_DBG("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 
-// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
-enum error_type {
-    ERROR_TYPE_INVALID_REQUEST,
-    ERROR_TYPE_AUTHENTICATION,
-    ERROR_TYPE_SERVER,
-    ERROR_TYPE_NOT_FOUND,
-    ERROR_TYPE_PERMISSION,
-    ERROR_TYPE_UNAVAILABLE, // custom error
-    ERROR_TYPE_NOT_SUPPORTED, // custom error
-};
-
 template <typename T>
 static T json_value(const json & body, const std::string & key, const T & default_value) {
     // Fallback null to default value
@@ -485,48 +475,11 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
     return out;
 }
 
-struct completion_token_output {
-    llama_token tok;
-    std::string text_to_send;
-
-    struct token_prob {
-        llama_token tok;
-        float prob;
-    };
-
-    std::vector<token_prob> probs;
-};
-
-// convert a vector of completion_token_output to json
-static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
-    json out = json::array();
-
-    for (const auto & prob : probs) {
-        json probs_for_token = json::array();
-
-        for (const auto & p : prob.probs) {
-            const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
-            probs_for_token.push_back(json {
-                {"tok_str", tok_str},
-                {"prob",    p.prob},
-            });
-        }
-
-        const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
-        out.push_back(json {
-            {"content", tok_str},
-            {"probs",   probs_for_token},
-        });
-    }
-
-    return out;
-}
-
 static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
     const std::string str =
         std::string(event) + ": " +
         data.dump(-1, ' ', false, json::error_handler_t::replace) +
-        "\n\n"; // note: these newlines are important (not sure why though, if you know, add a comment to explain)
+        "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
 
     LOG_DBG("data stream, to_send: %s", str.c_str());
 
@@ -604,164 +557,6 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
-    bool stopped_word        = result.count("stopped_word") != 0;
-    bool stopped_eos         = json_value(result, "stopped_eos", false);
-    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-    int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
-    std::string content      = json_value(result, "content", std::string(""));
-
-    std::string finish_reason = "length";
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-
-    json choices =
-        streaming ? json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"delta", json::object()}}})
-                  : json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"message", json{{"content", content},
-                                                         {"role", "assistant"}}}}});
-
-    std::time_t t = std::time(0);
-
-    json res = json {
-        {"choices", choices},
-        {"created", t},
-        {"model",
-            json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-        {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
-        {"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens",     num_prompt_tokens},
-            {"total_tokens",      num_tokens_predicted + num_prompt_tokens}
-        }},
-        {"id", completion_id}
-    };
-
-    // extra fields for debugging purposes
-    if (verbose) {
-        res["__verbose"] = result;
-    }
-
-    if (result.contains("completion_probabilities")) {
-        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-    }
-
-    if (result.contains("timings")) {
-        res.push_back({"timings", json_value(result, "timings", json::object())});
-    }
-
-    return res;
-}
-
-// return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
-    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
-        return std::vector<json>({result});
-    }
-
-    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
-    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-
-    bool stopped_word   = json_value(result, "stopped_word",  false);
-    bool stopped_eos    = json_value(result, "stopped_eos",   false);
-    bool stopped_limit  = json_value(result, "stopped_limit", false);
-    std::string content = json_value(result, "content",       std::string(""));
-
-    std::string finish_reason;
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-    if (stopped_limit) {
-        finish_reason = "length";
-    }
-
-    std::time_t t = std::time(0);
-
-    json choices;
-
-    if (!finish_reason.empty()) {
-        choices = json::array({json{{"finish_reason", finish_reason},
-                                    {"index", 0},
-                                    {"delta", json::object()}}});
-    } else {
-        if (first) {
-            if (content.empty()) {
-                choices = json::array({json{{"finish_reason", nullptr},
-                                            {"index", 0},
-                                            {"delta", json{{"role", "assistant"}}}}});
-            } else {
-                // We have to send this as two updates to conform to openai behavior
-                json initial_ret = json{{"choices", json::array({json{
-                                        {"finish_reason", nullptr},
-                                        {"index", 0},
-                                        {"delta", json{
-                                            {"role", "assistant"}
-                                        }}}})},
-                            {"created", t},
-                            {"id", completion_id},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                json second_ret = json{
-                            {"choices", json::array({json{{"finish_reason", nullptr},
-                                                            {"index", 0},
-                                                            {"delta", json{
-                                                            {"content", content}}}
-                                                            }})},
-                            {"created", t},
-                            {"id", completion_id},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                return std::vector<json>({initial_ret, second_ret});
-            }
-        } else {
-            // Some idiosyncrasy in task processing logic makes several trailing calls
-            // with empty content, we ignore these at the calee site.
-            if (content.empty()) {
-                return std::vector<json>({json::object()});
-            }
-
-            choices = json::array({json{
-                {"finish_reason", nullptr},
-                {"index", 0},
-                {"delta",
-                json{
-                    {"content", content},
-                }},
-            }});
-        }
-    }
-
-    json ret = json {
-        {"choices", choices},
-        {"created", t},
-        {"id",      completion_id},
-        {"model",   modelname},
-        {"object",  "chat.completion.chunk"}
-    };
-
-    if (result.contains("timings")) {
-        ret.push_back({"timings", json_value(result, "timings", json::object())});
-    }
-
-    if (!finish_reason.empty()) {
-        int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-        int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
-        ret.push_back({"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens",     num_prompt_tokens},
-            {"total_tokens",      num_tokens_predicted + num_prompt_tokens}
-        }});
-    }
-
-    return std::vector<json>({ret});
-}
-
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     json data = json::array();
     int i = 0;
@@ -853,43 +648,3 @@ static json format_detokenized_response(const std::string & content) {
         {"content", content}
         {"content", content}
     };
     };
 }
 }
-
-static json format_error_response(const std::string & message, const enum error_type type) {
-    std::string type_str;
-    int code = 500;
-    switch (type) {
-        case ERROR_TYPE_INVALID_REQUEST:
-            type_str = "invalid_request_error";
-            code = 400;
-            break;
-        case ERROR_TYPE_AUTHENTICATION:
-            type_str = "authentication_error";
-            code = 401;
-            break;
-        case ERROR_TYPE_NOT_FOUND:
-            type_str = "not_found_error";
-            code = 404;
-            break;
-        case ERROR_TYPE_SERVER:
-            type_str = "server_error";
-            code = 500;
-            break;
-        case ERROR_TYPE_PERMISSION:
-            type_str = "permission_error";
-            code = 403;
-            break;
-        case ERROR_TYPE_NOT_SUPPORTED:
-            type_str = "not_supported_error";
-            code = 501;
-            break;
-        case ERROR_TYPE_UNAVAILABLE:
-            type_str = "unavailable_error";
-            code = 503;
-            break;
-    }
-    return json {
-        {"code", code},
-        {"message", message},
-        {"type", type_str},
-    };
-}
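For context on the clarified `server_sent_event` comment earlier in this file: server-sent events terminate each message with a blank line, which is why the server appends `\n\n` to every chunk. A small, hypothetical parser sketch illustrating that framing on the client side (names, buffering strategy, and the usage line are illustrative only, not part of this commit):

```python
# Hypothetical sketch (not part of this commit) of the SSE framing rule that the
# clarified comment in server_sent_event() refers to: each event ends with a blank
# line, i.e. "\n\n", so the server appends it to every chunk it emits.
import json

def iter_sse_events(byte_chunks):
    """Yield the JSON payload of each `data: ...` line from raw byte chunks."""
    buffer = b""
    for chunk in byte_chunks:
        buffer += chunk
        while b"\n\n" in buffer:  # blank line = end of one event
            raw_event, buffer = buffer.split(b"\n\n", 1)
            for line in raw_event.decode("utf-8").splitlines():
                if line.startswith("data: "):
                    yield json.loads(line[len("data: "):])

# usage (illustrative): wrap a streamed HTTP response body, e.g.
#   for event in iter_sse_events(resp.iter_content(chunk_size=None)): ...
```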

Some files were not shown because too many files changed in this diff