Kaynağa Gözat

server: remove default "gpt-3.5-turbo" model name (#17668)

* server: remove default "gpt-3.5-turbo" model name

* do not reflect back model name from request

* fix test
Xuan-Son Nguyen 1 ay önce
ebeveyn
işleme
5d6bd842ea

+ 8 - 3
tools/server/server-common.cpp

@@ -1263,7 +1263,11 @@ json convert_anthropic_to_oai(const json & body) {
     return oai_body;
 }
 
-json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) {
+json format_embeddings_response_oaicompat(
+        const json & request,
+        const std::string & model_name,
+        const json & embeddings,
+        bool use_base64) {
     json data = json::array();
     int32_t n_tokens = 0;
     int i = 0;
@@ -1293,7 +1297,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb
     }
 
     json res = json {
-        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"model", json_value(request, "model", model_name)},
         {"object", "list"},
         {"usage", json {
             {"prompt_tokens", n_tokens},
@@ -1307,6 +1311,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb
 
 json format_response_rerank(
         const json & request,
+        const std::string & model_name,
         const json & ranks,
         bool is_tei_format,
         std::vector<std::string> & texts,
@@ -1338,7 +1343,7 @@ json format_response_rerank(
     if (is_tei_format) return results;
 
     json res = json{
-        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"model", json_value(request, "model", model_name)},
         {"object", "list"},
         {"usage", json{
             {"prompt_tokens", n_tokens},

+ 6 - 3
tools/server/server-common.h

@@ -13,8 +13,6 @@
 #include <vector>
 #include <cinttypes>
 
-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
-
 const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
 
 using json = nlohmann::ordered_json;
@@ -298,11 +296,16 @@ json oaicompat_chat_params_parse(
 json convert_anthropic_to_oai(const json & body);
 
 // TODO: move it to server-task.cpp
-json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
+json format_embeddings_response_oaicompat(
+    const json & request,
+    const std::string & model_name,
+    const json & embeddings,
+    bool use_base64 = false);
 
 // TODO: move it to server-task.cpp
 json format_response_rerank(
         const json & request,
+        const std::string & model_name,
         const json & ranks,
         bool is_tei_format,
         std::vector<std::string> & texts,

+ 22 - 6
tools/server/server-context.cpp

@@ -17,6 +17,7 @@
 #include <cinttypes>
 #include <memory>
 #include <unordered_set>
+#include <filesystem>
 
 // fix problem with std::min and std::max
 #if defined(_WIN32)
@@ -518,6 +519,8 @@ struct server_context_impl {
     // Necessary similarity of prompt for slot selection
     float slot_prompt_similarity = 0.0f;
 
+    std::string model_name; // name of the loaded model, to be used by API
+
     common_chat_templates_ptr chat_templates;
     oaicompat_parser_options  oai_parser_opt;
 
@@ -758,6 +761,18 @@ struct server_context_impl {
         }
         SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
 
+        if (!params_base.model_alias.empty()) {
+            // user explicitly specified model name
+            model_name = params_base.model_alias;
+        } else if (!params_base.model.name.empty()) {
+            // use model name in registry format (for models in cache)
+            model_name = params_base.model.name;
+        } else {
+            // fallback: derive model name from file name
+            auto model_path = std::filesystem::path(params_base.model.path);
+            model_name = model_path.filename().string();
+        }
+
         // thinking is enabled if:
         // 1. It's not explicitly disabled (reasoning_budget == 0)
         // 2. The chat template supports it
@@ -2611,7 +2626,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             // OAI-compat
             task.params.res_type          = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
-            // oaicompat_model is already populated by params_from_json_cmpl
+            task.params.oaicompat_model   = ctx_server.model_name;
 
             tasks.push_back(std::move(task));
         }
@@ -2939,7 +2954,7 @@ void server_routes::init_routes() {
         json data = {
             { "default_generation_settings", default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
-            { "model_alias",                 ctx_server.params_base.model_alias },
+            { "model_alias",                 ctx_server.model_name },
             { "model_path",                  ctx_server.params_base.model.path },
             { "modalities",                  json {
                 {"vision", ctx_server.oai_parser_opt.allow_image},
@@ -3181,8 +3196,8 @@ void server_routes::init_routes() {
         json models = {
             {"models", {
                 {
-                    {"name", params.model_alias.empty() ? params.model.path : params.model_alias},
-                    {"model", params.model_alias.empty() ? params.model.path : params.model_alias},
+                    {"name", ctx_server.model_name},
+                    {"model", ctx_server.model_name},
                     {"modified_at", ""},
                     {"size", ""},
                     {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash
@@ -3204,7 +3219,7 @@ void server_routes::init_routes() {
             {"object", "list"},
             {"data", {
                 {
-                    {"id",       params.model_alias.empty() ? params.model.path : params.model_alias},
+                    {"id",       ctx_server.model_name},
                     {"object",   "model"},
                     {"created",  std::time(0)},
                     {"owned_by", "llamacpp"},
@@ -3351,6 +3366,7 @@ void server_routes::init_routes() {
         // write JSON response
         json root = format_response_rerank(
             body,
+            ctx_server.model_name,
             responses,
             is_tei_format,
             documents,
@@ -3613,7 +3629,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
 
     // write JSON response
     json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
-        ? format_embeddings_response_oaicompat(body, responses, use_base64)
+        ? format_embeddings_response_oaicompat(body, ctx_server.model_name, responses, use_base64)
         : json(responses);
     res->ok(root);
     return res;

+ 0 - 3
tools/server/server-task.cpp

@@ -450,9 +450,6 @@ task_params server_task::params_from_json_cmpl(
         }
     }
 
-    std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
-    params.oaicompat_model = json_value(data, "model", model_name);
-
     return params;
 }
 

+ 4 - 3
tools/server/tests/unit/test_chat_completion.py

@@ -41,7 +41,8 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
     assert res.status_code == 200
     assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
     assert res.body["system_fingerprint"].startswith("b")
-    assert res.body["model"] == model if model is not None else server.model_alias
+    # we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668
+    # assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
     choice = res.body["choices"][0]
@@ -59,7 +60,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
 )
 def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
-    server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL
+    server.model_alias = "llama-test-model"
     server.start()
     res = server.make_stream_request("POST", "/chat/completions", data={
         "max_tokens": max_tokens,
@@ -81,7 +82,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
             else:
                 assert "role" not in choice["delta"]
             assert data["system_fingerprint"].startswith("b")
-            assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+            assert data["model"] == "llama-test-model"
             if last_cmpl_id is None:
                 last_cmpl_id = data["id"]
             assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream