
server : refactor (#5882)

* server : refactoring (wip)

* server : remove llava/clip objects from build

* server : fix empty prompt handling + all slots idle logic

* server : normalize id vars

* server : code style

* server : simplify model chat template validation

* server : code style

* server : minor

* llama : llama_chat_apply_template support null buf

* server : do not process embedding requests when disabled

* server : reorganize structs and enums + naming fixes

* server : merge oai.hpp in utils.hpp

* server : refactor system prompt update at start

* server : disable cached prompts with self-extend

* server : do not process more than n_batch tokens per iter

* server: tests: embeddings use a real embeddings model (#5908)

* server, tests : bump batch to fit 1 embedding prompt

* server: tests: embeddings fix build type Debug is randomly failing (#5911)

* server: tests: embeddings, use different KV Cache size

* server: tests: embeddings, fixed prompt do not exceed n_batch, increase embedding timeout, reduce number of concurrent embeddings

* server: tests: embeddings, no need to wait for server idle as it can time out

* server: refactor: clean up http code (#5912)

* server : avoid n_available var

ggml-ci

* server: refactor: better http codes

* server : simplify json parsing + add comment about t_last

* server : rename server structs

* server : allow to override FQDN in tests

ggml-ci

* server : add comments

---------

Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com>
Georgi Gerganov 1 year ago
Parent
Commit
2002bc96bf

+ 2 - 1
.github/workflows/server.yml

@@ -58,7 +58,8 @@ jobs:
            cmake \
            python3-pip \
            wget \
-            psmisc
+            psmisc \
+            language-pack-en
 
      - name: Build
        id: cmake_build

+ 2 - 3
Makefile

@@ -724,10 +724,9 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h examples/llava/llava.h examples/llava/llava.cpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
-	$(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
-	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h %.hpp $< examples/llava/clip.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) -o $@ $(LDFLAGS) $(LWINSOCK2)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
 
 gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

+ 1 - 1
examples/server-embd.py

@@ -13,7 +13,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(i)*1024}
+        json= {"content": str(0)*1024}
     ) for i in range(n)])
 
     for response in responses:
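With every request now sending the identical "0"*1024 prompt, the returned vectors can be compared directly. Below is a minimal sketch (the two vectors are placeholders, not server output) of the cosine-similarity check that the new embeddings tests perform with numpy:

```python
import numpy as np

# Placeholder embeddings standing in for responses from the /embedding endpoint.
embedding1 = np.array([0.12, -0.34, 0.56])
embedding2 = np.array([0.12, -0.34, 0.56])

# Cosine similarity; identical prompts should yield (near-)identical embeddings.
similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
assert np.isclose(similarity, 1.0), f"embeddings diverge: {similarity:.10f}"
```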

+ 2 - 2
examples/server/CMakeLists.txt

@@ -1,12 +1,12 @@
 set(TARGET server)
 option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
-add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp httplib.h)
+add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
 )
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
 if (WIN32)
     TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
 endif()

+ 1 - 1
examples/server/README.md

@@ -436,7 +436,7 @@ Notice that each `probs` is an array of length `n_probs`.
         "next_token": {
             "has_next_token": true,
             "n_remain": -1,
-            "num_tokens_predicted": 0,
+            "n_decoded": 0,
             "stopped_eos": false,
             "stopped_limit": false,
             "stopped_word": false,

+ 0 - 225
examples/server/oai.hpp

@@ -1,225 +0,0 @@
-#pragma once
-
-#include <string>
-#include <vector>
-#include <set>
-#include <mutex>
-#include <condition_variable>
-#include <unordered_map>
-
-#include "json.hpp"
-#include "utils.hpp"
-
-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
-
-using json = nlohmann::json;
-
-inline static json oaicompat_completion_params_parse(
-    const struct llama_model * model,
-    const json &body, /* openai api json semantics */
-    const std::string &chat_template)
-{
-    json llama_params;
-
-    llama_params["__oaicompat"] = true;
-
-    // Map OpenAI parameters to llama.cpp parameters
-    //
-    // For parameters that are defined by the OpenAI documentation (e.g.
-    // temperature), we explicitly specify OpenAI's intended default; we
-    // need to do that because sometimes OpenAI disagrees with llama.cpp
-    //
-    // https://platform.openai.com/docs/api-reference/chat/create
-    llama_sampling_params default_sparams;
-    llama_params["model"]             = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"]            = format_chat(model, chat_template, body["messages"]);
-    llama_params["cache_prompt"]      = json_value(body, "cache_prompt", false);
-    llama_params["temperature"]       = json_value(body, "temperature", 0.0);
-    llama_params["top_k"]             = json_value(body, "top_k", default_sparams.top_k);
-    llama_params["top_p"]             = json_value(body, "top_p", 1.0);
-    llama_params["n_predict"]         = json_value(body, "max_tokens", -1);
-    llama_params["logit_bias"]        = json_value(body, "logit_bias",json::object());
-    llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
-    llama_params["presence_penalty"]  = json_value(body, "presence_penalty", 0.0);
-    llama_params["seed"]              = json_value(body, "seed", LLAMA_DEFAULT_SEED);
-    llama_params["stream"]            = json_value(body, "stream", false);
-    llama_params["mirostat"]          = json_value(body, "mirostat", default_sparams.mirostat);
-    llama_params["mirostat_tau"]      = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
-    llama_params["mirostat_eta"]      = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
-    llama_params["penalize_nl"]       = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama_params["typical_p"]         = json_value(body, "typical_p", default_sparams.typical_p);
-    llama_params["repeat_last_n"]     = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
-    llama_params["ignore_eos"]        = json_value(body, "ignore_eos", false);
-    llama_params["tfs_z"]             = json_value(body, "tfs_z", default_sparams.tfs_z);
-
-    if (body.count("grammar") != 0) {
-        llama_params["grammar"] = json_value(body, "grammar", json::object());
-    }
-
-    // Handle 'stop' field
-    if (body.contains("stop") && body["stop"].is_string()) {
-        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
-    } else {
-        llama_params["stop"] = json_value(body, "stop", json::array());
-    }
-
-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
-
-    return llama_params;
-}
-
-inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
-{
-    json result = response.result_json;
-
-    bool stopped_word        = result.count("stopped_word") != 0;
-    bool stopped_eos         = json_value(result, "stopped_eos", false);
-    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-    int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
-    std::string content      = json_value(result, "content", std::string(""));
-
-    std::string finish_reason = "length";
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-
-    json choices =
-        streaming ? json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"delta", json::object()}}})
-                  : json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"message", json{{"content", content},
-                                                         {"role", "assistant"}}}}});
-
-    std::time_t t = std::time(0);
-
-    json res =
-        json{{"choices", choices},
-            {"created", t},
-            {"model",
-                json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-            {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
-            {"usage",
-                json{{"completion_tokens", num_tokens_predicted},
-                     {"prompt_tokens",     num_prompt_tokens},
-                     {"total_tokens",      num_tokens_predicted + num_prompt_tokens}}},
-            {"id", gen_chatcmplid()}};
-
-    if (server_verbose) {
-        res["__verbose"] = result;
-    }
-
-    if (result.contains("completion_probabilities")) {
-        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-    }
-
-    return res;
-}
-
-// return value is vector as there is one case where we might need to generate two responses
-inline static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
-    json result = response.result_json;
-
-    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
-        return std::vector<json>({response.result_json});
-    }
-
-    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
-    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-
-    bool stopped_word   = json_value(result, "stopped_word", false);
-    bool stopped_eos    = json_value(result, "stopped_eos", false);
-    bool stopped_limit  = json_value(result, "stopped_limit", false);
-    std::string content = json_value(result, "content", std::string(""));
-
-    std::string finish_reason;
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-    if (stopped_limit) {
-        finish_reason = "length";
-    }
-
-    std::time_t t = std::time(0);
-
-    json choices;
-
-    if (!finish_reason.empty()) {
-        choices = json::array({json{{"finish_reason", finish_reason},
-                                    {"index", 0},
-                                    {"delta", json::object()}}});
-    } else {
-        if (first) {
-            if (content.empty()) {
-                choices = json::array({json{{"finish_reason", nullptr},
-                                            {"index", 0},
-                                            {"delta", json{{"role", "assistant"}}}}});
-            } else {
-                // We have to send this as two updates to conform to openai behavior
-                json initial_ret = json{{"choices", json::array({json{
-                                        {"finish_reason", nullptr},
-                                        {"index", 0},
-                                        {"delta", json{
-                                            {"role", "assistant"}
-                                        }}}})},
-                            {"created", t},
-                            {"id", gen_chatcmplid()},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                json second_ret = json{
-                            {"choices", json::array({json{{"finish_reason", nullptr},
-                                                            {"index", 0},
-                                                            {"delta", json{
-                                                            {"content", content}}}
-                                                            }})},
-                            {"created", t},
-                            {"id", gen_chatcmplid()},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                return std::vector<json>({initial_ret, second_ret});
-            }
-        } else {
-            // Some idiosyncrasy in task processing logic makes several trailing calls
-            // with empty content, we ignore these at the calee site.
-            if (content.empty()) {
-                return std::vector<json>({json::object()});
-            }
-
-            choices = json::array({json{
-                {"finish_reason", nullptr},
-                {"index", 0},
-                {"delta",
-                json{
-                    {"content", content},
-                }},
-            }});
-        }
-    }
-
-    json ret = json{{"choices", choices},
-                    {"created", t},
-                    {"id", gen_chatcmplid()},
-                    {"model", modelname},
-                    {"object", "chat.completion.chunk"}};
-
-    return std::vector<json>({ret});
-}
-
-inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings)
-{
-    json res =
-        json{
-            {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-            {"object", "list"},
-            {"usage",
-                json{{"prompt_tokens", 0},
-                     {"total_tokens", 0}}},
-            {"data", embeddings}
-        };
-    return res;
-}
-

The diff for this file is not shown because it is too large.
+ 576 - 424
examples/server/server.cpp


+ 94 - 0
examples/server/tests/features/embeddings.feature

@@ -0,0 +1,94 @@
+@llama.cpp
+@embeddings
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file bert-bge-small/ggml-model-f16.gguf from HF repo ggml-org/models
+    And   a model alias bert-bge-small
+    And   42 as server seed
+    And   2 slots
+    And   1024 as batch size
+    And   2048 KV cache size
+    And   embeddings extraction
+    Then  the server is starting
+    Then  the server is healthy
+
+  Scenario: Embedding
+    When embeddings are computed for:
+    """
+    What is the capital of Bulgaria ?
+    """
+    Then embeddings are generated
+
+  Scenario: OAI Embeddings compatibility
+    Given a model bert-bge-small
+    When an OAI compatible embeddings computation request for:
+    """
+    What is the capital of Spain ?
+    """
+    Then embeddings are generated
+
+  Scenario: OAI Embeddings compatibility with multiple inputs
+    Given a model bert-bge-small
+    Given a prompt:
+      """
+      In which country Paris is located ?
+      """
+    And a prompt:
+      """
+      Is Madrid the capital of Spain ?
+      """
+    When an OAI compatible embeddings computation request for multiple inputs
+    Then embeddings are generated
+
+  Scenario: Multi users embeddings
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write a very long poem.
+      """
+    And a prompt:
+      """
+      Write a very long joke.
+      """
+    Given concurrent embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated
+
+  Scenario: Multi users OAI compatibility embeddings
+    Given a prompt:
+      """
+      In which country Paris is located ?
+      """
+    And a prompt:
+      """
+      Is Madrid the capital of Spain ?
+      """
+    And a prompt:
+      """
+      What is the biggest US city ?
+      """
+    And a prompt:
+      """
+      What is the capital of Bulgaria ?
+      """
+    And   a model bert-bge-small
+    Given concurrent OAI embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated
+
+  Scenario: All embeddings should be the same
+    Given 10 fixed prompts
+    And   a model bert-bge-small
+    Given concurrent OAI embedding requests
+    Then all embeddings are the same
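These scenarios exercise both the native /embedding route (as used by examples/server-embd.py) and the OpenAI-compatible embeddings API. A rough sketch of equivalent manual requests, assuming the Background above (localhost:8080, alias bert-bge-small); the exact /v1/embeddings path is an assumption inferred from the tests pointing the openai client at f"{base_url}/v1":

```python
import requests

base_url = "http://localhost:8080"  # matches the Background of this feature

# Native endpoint, same request body as examples/server-embd.py.
native = requests.post(f"{base_url}/embedding",
                       json={"content": "What is the capital of Bulgaria ?"})
print(len(native.json()["embedding"]))

# OpenAI-compatible endpoint (route assumed, see note above).
oai = requests.post(f"{base_url}/v1/embeddings",
                    json={"input": "What is the capital of Bulgaria ?",
                          "model": "bert-bge-small"})
print(oai.json()["object"])  # "list", as asserted in steps.py
```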

+ 0 - 46
examples/server/tests/features/parallel.feature

@@ -9,7 +9,6 @@ Feature: Parallel
     And   512 as batch size
     And   64 KV cache size
     And   2 slots
-    And   embeddings extraction
     And   continuous batching
     Then  the server is starting
     Then  the server is healthy
@@ -99,48 +98,3 @@ Feature: Parallel
     Then the server is busy
     Then the server is idle
     Then all prompts are predicted
-
-  Scenario: Multi users embeddings
-    Given a prompt:
-      """
-      Write a very long story about AI.
-      """
-    And a prompt:
-      """
-      Write another very long music lyrics.
-      """
-    And a prompt:
-      """
-      Write a very long poem.
-      """
-    And a prompt:
-      """
-      Write a very long joke.
-      """
-    Given concurrent embedding requests
-    Then the server is busy
-    Then the server is idle
-    Then all embeddings are generated
-
-  Scenario: Multi users OAI compatibility embeddings
-    Given a prompt:
-      """
-      In which country Paris is located ?
-      """
-    And a prompt:
-      """
-      Is Madrid the capital of Spain ?
-      """
-    And a prompt:
-      """
-      What is the biggest US city ?
-      """
-    And a prompt:
-      """
-      What is the capital of Bulgaria ?
-      """
-    And   a model tinyllama-2
-    Given concurrent OAI embedding requests
-    Then the server is busy
-    Then the server is idle
-    Then all embeddings are generated

+ 0 - 28
examples/server/tests/features/server.feature

@@ -49,34 +49,6 @@ Feature: llama.cpp server
       | llama-2      | Book                        | What is the best book                | 8          | (Mom\|what)+           | 8           | disabled         |
       | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64         | (thanks\|happy\|bird)+ | 32          | enabled          |
 
-  Scenario: Embedding
-    When embeddings are computed for:
-    """
-    What is the capital of Bulgaria ?
-    """
-    Then embeddings are generated
-
-  Scenario: OAI Embeddings compatibility
-    Given a model tinyllama-2
-    When an OAI compatible embeddings computation request for:
-    """
-    What is the capital of Spain ?
-    """
-    Then embeddings are generated
-
-  Scenario: OAI Embeddings compatibility with multiple inputs
-    Given a model tinyllama-2
-    Given a prompt:
-      """
-      In which country Paris is located ?
-      """
-    And a prompt:
-      """
-      Is Madrid the capital of Spain ?
-      """
-    When an OAI compatible embeddings computation request for multiple inputs
-    Then embeddings are generated
-
   Scenario: Tokenize / Detokenize
     When tokenizing:
     """

+ 67 - 19
examples/server/tests/features/steps/steps.py

@@ -10,6 +10,7 @@ from contextlib import closing
 from re import RegexFlag
 
 import aiohttp
+import numpy as np
 import openai
 from behave import step
 from behave.api.async_step import async_run_until_complete
@@ -24,6 +25,9 @@ def step_server_config(context, server_fqdn, server_port):
     if 'PORT' in os.environ:
         context.server_port = int(os.environ['PORT'])
         print(f"$PORT set, overriding server port with to {context.server_port}")
+    if 'FQDN' in os.environ:
+        context.server_fqdn = os.environ['FQDN']
+        print(f"$FQDN set, overriding server fqdn with to {context.server_fqdn}")
 
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
 
@@ -34,6 +38,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.n_ga_w = None
     context.n_gpu_layer = None
     context.n_predict = None
+    context.n_prompts = 0
     context.n_server_predict = None
     context.n_slots = None
     context.prompt_prefix = None
@@ -202,6 +207,7 @@ def step_n_tokens_predicted(context, predicted_n):
 @step(u'a user prompt {user_prompt}')
 def step_user_prompt(context, user_prompt):
     context.prompts.append(user_prompt)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'a system prompt {system_prompt}')
@@ -290,6 +296,12 @@ def step_prompt_passkey(context):
     context.prompt_passkey = context.text
 
 
+@step(u'{n_prompts:d} fixed prompts')
+def step_fixed_prompts(context, n_prompts):
+    context.prompts.extend([str(0)*(context.n_batch if context.n_batch is not None else 512) for i in range(n_prompts)])
+    context.n_prompts = n_prompts
+
+
 @step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
 def step_prompt_passkey(context, passkey, i_pos):
     prompt = ""
@@ -301,6 +313,7 @@ def step_prompt_passkey(context, passkey, i_pos):
         passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
         print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
     context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'an OAI compatible chat completions request with {api_error} api error')
@@ -341,11 +354,13 @@ async def step_oai_chat_completions(context, api_error):
 @step(u'a prompt')
 def step_a_prompt(context):
     context.prompts.append(context.text)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'a prompt {prompt}')
 def step_a_prompt_prompt(context, prompt):
     context.prompts.append(prompt)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'concurrent completion requests')
@@ -430,25 +445,47 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
+    context.n_prompts = 1
     context.embeddings = await request_embedding(context.text, base_url=context.base_url)
 
 
+@step(u'all embeddings are the same')
+@async_run_until_complete
+async def step_all_embeddings_are_the_same(context):
+    n_embedding_requests = await gather_tasks_results(context)
+    assert n_embedding_requests > 0
+    embeddings = []
+    for i in range(n_embedding_requests):
+        embedding = context.tasks_result.pop().pop()
+        embeddings.append(embedding)
+        assert_embeddings(embedding)
+    n = len(embeddings)
+    for i in range(n-1):
+        for j in range(i+1, n):
+            embedding1 = np.array(embeddings[i])
+            embedding2 = np.array(embeddings[j])
+            if context.debug:
+                print(f"embedding1: {embedding1[-8:]}\n")
+                print(f"embedding2: {embedding2[-8:]}\n")
+            similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+            msg = f"Similarity between {i} and {j}: {similarity:.10f}"
+            if context.debug:
+                print(f"{msg}\n")
+            assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
+
 @step(u'embeddings are generated')
 def step_assert_embeddings(context):
-    if len(context.prompts) == 0:
-        assert_embeddings(context.embeddings)
-    else:
-        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
-                                                                 f"context.prompts={context.prompts}\n"
-                                                                 f"context.embeddings={context.embeddings}")
-        for embedding in context.embeddings:
-            context.prompts.pop()
-            assert_embeddings(embedding)
+    assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n"
+                                                             f"context.n_prompts={context.n_prompts}\n"
+                                                             f"context.embeddings={context.embeddings}")
+    for embedding in context.embeddings:
+        assert_embeddings(embedding)
 
 
 @step(u'an OAI compatible embeddings computation request for')
 @async_run_until_complete
 async def step_oai_compute_embeddings(context):
+    context.n_prompts = 1
     context.embeddings = await request_oai_embeddings(context.text,
                                                       base_url=context.base_url,
                                                       user_api_key=context.user_api_key,
@@ -462,6 +499,7 @@ async def step_oai_compute_embeddings_multiple_inputs(context):
                                                       base_url=context.base_url,
                                                       user_api_key=context.user_api_key,
                                                       model=context.model)
+    context.prompts.clear()
 
 
 @step(u'concurrent embedding requests')
@@ -488,9 +526,9 @@ async def step_concurrent_oai_embedding_requests(context):
 @async_run_until_complete()
 async def all_embeddings_are_generated(context):
     n_embedding_requests = await gather_tasks_results(context)
-    assert n_embedding_requests > 0
+    assert n_embedding_requests == context.n_prompts
     for i in range(n_embedding_requests):
-        assert_embeddings(context.tasks_result.pop())
+        assert_embeddings(context.tasks_result.pop().pop())
 
 
 @step(u'tokenizing')
@@ -588,11 +626,11 @@ def step_supported_models(context, i_model, param, preposition, param_value):
 
 
 async def concurrent_requests(context, f_completion, *args, **kwargs):
-    n_prompts = len(context.prompts)
+    context.n_prompts = len(context.prompts)
     if context.debug:
-        print(f"starting {n_prompts} concurrent completion requests...")
-    assert n_prompts > 0
-    for prompt_no in range(n_prompts):
+        print(f"starting {context.n_prompts} concurrent completion requests...")
+    assert context.n_prompts > 0
+    for prompt_no in range(context.n_prompts):
         shifted_args = [context.prompts.pop(), *args]
         context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
     await asyncio.sleep(0.1)
@@ -765,7 +803,7 @@ async def request_embedding(content, base_url=None):
                                 }) as response:
             assert response.status == 200
             response_json = await response.json()
-            return response_json['embedding']
+            return [response_json['embedding']]
 
 
 async def request_oai_embeddings(input,
@@ -775,6 +813,7 @@ async def request_oai_embeddings(input,
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     if async_client:
         origin = 'llama.cpp'
+        headers=[]
         if user_api_key is not None:
             headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
@@ -783,14 +822,21 @@
                                         "input": input,
                                         "model": model,
                                     },
-                                    headers=headers) as response:
+                                    headers=headers,
+                                    timeout=3600) as response:
                 assert response.status == 200, f"received status code not expected: {response.status}"
                 assert response.headers['Access-Control-Allow-Origin'] == origin
                 assert response.headers['Content-Type'] == "application/json; charset=utf-8"
                 response_json = await response.json()
                 assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
                 assert response_json['object'] == 'list'
-                return response_json['data']
+                if isinstance(input, collections.abc.Sequence):
+                    embeddings = []
+                    for an_oai_embeddings in response_json['data']:
+                        embeddings.append(an_oai_embeddings['embedding'])
+                else:
+                    embeddings = [response_json['data']['embedding']]
+                return embeddings
     else:
         openai.api_key = user_api_key
         openai.api_base = f'{base_url}/v1'
@@ -804,7 +850,7 @@ async def request_oai_embeddings(input,
             for an_oai_embeddings in oai_embeddings.data:
                 embeddings.append(an_oai_embeddings.embedding)
         else:
-            embeddings = oai_embeddings.data.embedding
+            embeddings = [oai_embeddings.data.embedding]
         return embeddings
 
 
@@ -899,6 +945,8 @@ def assert_embeddings(embeddings):
     assert len(embeddings) > 0
     embeddings_computed = False
     for emb in embeddings:
+        if not isinstance(emb, float):
+            assert False, f"Bad embeddings: {embeddings}"
         if emb != 0:
             embeddings_computed = True
     assert embeddings_computed, f"Embeddings: {embeddings}"
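The new FQDN override complements the existing PORT one, so the behave suite can point at a server that is not on localhost. A small sketch of how the resulting base URL is formed; the fallback values below are assumptions, the real ones come from each feature's Background step:

```python
import os

# Mirrors the override logic in step_server_config(): environment variables
# FQDN and PORT, when set, take precedence over the scenario configuration.
server_fqdn = os.environ.get('FQDN', 'localhost')   # default is an assumption
server_port = int(os.environ.get('PORT', '8080'))   # default is an assumption
base_url = f'http://{server_fqdn}:{server_port}'
print(base_url)
```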

+ 1 - 0
examples/server/tests/requirements.txt

@@ -1,5 +1,6 @@
 aiohttp~=3.9.3
 behave~=1.2.6
 huggingface_hub~=0.20.3
+numpy~=1.24.4
 openai~=0.25.0
 prometheus-client~=0.20.0

+ 307 - 396
examples/server/utils.hpp

@@ -1,15 +1,16 @@
 #pragma once
 
-#include <string>
-#include <vector>
-#include <set>
-#include <mutex>
-#include <condition_variable>
-#include <unordered_map>
+#include "llama.h"
+#include "common.h"
 
 #include "json.hpp"
 
-#include "../llava/clip.h"
+#include <string>
+#include <vector>
+#include <sstream>
+#include <random>
+
+#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
 using json = nlohmann::json;
 
@@ -37,83 +38,35 @@ extern bool server_log_json;
 #define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
 #define LOG_INFO(   MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
 
-enum server_state {
-    SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
-    SERVER_STATE_READY,          // Server is ready and model is loaded
-    SERVER_STATE_ERROR           // An error occurred, load_model failed
-};
-
-enum task_type {
-    TASK_TYPE_COMPLETION,
-    TASK_TYPE_CANCEL,
-    TASK_TYPE_NEXT_RESPONSE,
-    TASK_TYPE_METRICS
-};
-
-struct task_server {
-    int id = -1; // to be filled by llama_server_queue
-    int target_id;
-    task_type type;
-    json data;
-    bool infill_mode = false;
-    bool embedding_mode = false;
-    int multitask_id = -1;
-};
-
-struct task_result {
-    int id;
-    int multitask_id = -1;
-    bool stop;
-    bool error;
-    json result_json;
-};
-
-struct task_multi {
-    int id;
-    std::set<int> subtasks_remaining{};
-    std::vector<task_result> results{};
-};
-
-// completion token output with probabilities
-struct completion_token_output {
-    struct token_prob
-    {
-        llama_token tok;
-        float prob;
-    };
-
-    std::vector<token_prob> probs;
-    llama_token tok;
-    std::string text_to_send;
-};
-
-struct token_translator {
-    llama_context * ctx;
-    std::string operator()(llama_token tok)                    const { return llama_token_to_piece(ctx, tok); }
-    std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); }
-};
+template <typename T>
+static T json_value(const json &body, const std::string &key, const T &default_value) {
+    // Fallback null to default value
+    return body.contains(key) && !body.at(key).is_null()
+        ? body.value(key, default_value)
+        : default_value;
+}
 
 static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) {
     std::stringstream ss_tid;
     ss_tid << std::this_thread::get_id();
     json log = nlohmann::ordered_json{
-        {"tid", ss_tid.str()},
+        {"tid",       ss_tid.str()},
         {"timestamp", time(nullptr)},
     };
 
     if (server_log_json) {
-        log.merge_patch(
-                {
-                        {"level",     level},
-                        {"function",  function},
-                        {"line",      line},
-                        {"msg",       message},
-                });
+        log.merge_patch( {
+            {"level",    level},
+            {"function", function},
+            {"line",     line},
+            {"msg",      message},
+        });
+
         if (!extra.empty()) {
             log.merge_patch(extra);
         }
 
-        std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush;
+        printf("%s\n", log.dump(-1, ' ', false, json::error_handler_t::replace).c_str());
     } else {
         char buf[1024];
         snprintf(buf, 1024, "%4s [%24s] %s", level, function, message);
@@ -136,22 +89,13 @@ static inline void server_log(const char *level, const char *function, int line,
 }
 
 //
-// server utils
+// chat template utils
 //
 
-template <typename T>
-static T json_value(const json &body, const std::string &key, const T &default_value) {
-    // Fallback null to default value
-    return body.contains(key) && !body.at(key).is_null()
-        ? body.value(key, default_value)
-        : default_value;
-}
-
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 inline bool verify_custom_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    std::vector<char> buf(1);
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
@@ -163,7 +107,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     std::vector<llama_chat_message> chat(messages.size());
 
     for (size_t i = 0; i < messages.size(); ++i) {
-        auto &curr_msg = messages[i];
+        const auto & curr_msg = messages[i];
         str[i*2 + 0]    = json_value(curr_msg, "role",    std::string(""));
         str[i*2 + 1]    = json_value(curr_msg, "content", std::string(""));
         alloc_size     += str[i*2 + 1].length();
@@ -183,261 +127,13 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
     }
 
-    std::string formatted_chat(buf.data(), res);
+    const std::string formatted_chat(buf.data(), res);
+
     LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
 
     return formatted_chat;
 }
 
-//
-// work queue utils
-//
-
-struct llama_server_queue {
-    int id = 0;
-    std::mutex mutex_tasks;
-    bool running;
-    // queues
-    std::vector<task_server> queue_tasks;
-    std::vector<task_server> queue_tasks_deferred;
-    std::vector<task_multi> queue_multitasks;
-    std::condition_variable condition_tasks;
-    // callback functions
-    std::function<void(task_server&)> callback_new_task;
-    std::function<void(task_multi&)> callback_finish_multitask;
-    std::function<void(void)> callback_run_slots;
-
-    // Add a new task to the end of the queue
-    int post(task_server task) {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
-        if (task.id == -1) {
-            task.id = id++;
-            LOG_VERBOSE("new task id", {{"new_id", task.id}});
-        }
-        queue_tasks.push_back(std::move(task));
-        condition_tasks.notify_one();
-        return task.id;
-    }
-
-    // Add a new task, but defer until one slot is available
-    void defer(task_server task) {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
-        queue_tasks_deferred.push_back(std::move(task));
-    }
-
-    // Get the next id for creating anew task
-    int get_new_id() {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
-        int new_id = id++;
-        LOG_VERBOSE("new task id", {{"new_id", new_id}});
-        return new_id;
-    }
-
-    // Register function to process a new task
-    void on_new_task(std::function<void(task_server&)> callback) {
-        callback_new_task = callback;
-    }
-
-    // Register function to process a multitask when it is finished
-    void on_finish_multitask(std::function<void(task_multi&)> callback) {
-        callback_finish_multitask = callback;
-    }
-
-    // Register the function to be called when all slots data is ready to be processed
-    void on_run_slots(std::function<void(void)> callback) {
-        callback_run_slots = callback;
-    }
-
-    // Call when the state of one slot is changed
-    void notify_slot_changed() {
-        // move deferred tasks back to main loop
-        std::unique_lock<std::mutex> lock(mutex_tasks);
-        for (auto & task : queue_tasks_deferred) {
-            queue_tasks.push_back(std::move(task));
-        }
-        queue_tasks_deferred.clear();
-    }
-
-    // end the start_loop routine
-    void terminate() {
-        {
-            std::unique_lock<std::mutex> lock(mutex_tasks);
-            running = false;
-        }
-        condition_tasks.notify_all();
-    }
-
-    /**
-     * Main loop consists of these steps:
-     * - Wait until a new task arrives
-     * - Process the task (i.e. maybe copy data into slot)
-     * - Check if multitask is finished
-     * - Run all slots
-     */
-    void start_loop() {
-        running = true;
-        while (true) {
-            LOG_VERBOSE("new task may arrive", {});
-            {
-                while (true)
-                {
-                    std::unique_lock<std::mutex> lock(mutex_tasks);
-                    if (queue_tasks.empty()) {
-                        lock.unlock();
-                        break;
-                    }
-                    task_server task = queue_tasks.front();
-                    queue_tasks.erase(queue_tasks.begin());
-                    lock.unlock();
-                    LOG_VERBOSE("callback_new_task", {{"task_id", task.id}});
-                    callback_new_task(task);
-                }
-                LOG_VERBOSE("update_multitasks", {});
-                // check if we have any finished multitasks
-                auto queue_iterator = queue_multitasks.begin();
-                while (queue_iterator != queue_multitasks.end())
-                {
-                    if (queue_iterator->subtasks_remaining.empty())
-                    {
-                        // all subtasks done == multitask is done
-                        task_multi current_multitask = *queue_iterator;
-                        callback_finish_multitask(current_multitask);
-                        // remove this multitask
-                        queue_iterator = queue_multitasks.erase(queue_iterator);
-                    }
-                    else
-                    {
-                        ++queue_iterator;
-                    }
-                }
-                // all tasks in the current loop is processed, slots data is now ready
-                LOG_VERBOSE("callback_run_slots", {});
-                callback_run_slots();
-            }
-            LOG_VERBOSE("wait for new task", {});
-            // wait for new task
-            {
-                std::unique_lock<std::mutex> lock(mutex_tasks);
-                if (queue_tasks.empty()) {
-                    if (!running) {
-                        LOG_VERBOSE("ending start_loop", {});
-                        return;
-                    }
-                    condition_tasks.wait(lock, [&]{
-                        return (!queue_tasks.empty() || !running);
-                    });
-                }
-            }
-        }
-    }
-
-    //
-    // functions to manage multitasks
-    //
-
-    // add a multitask by specifying the id of all subtask (subtask is a task_server)
-    void add_multitask(int multitask_id, std::vector<int>& sub_ids)
-    {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        task_multi multi;
-        multi.id = multitask_id;
-        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
-        queue_multitasks.push_back(multi);
-    }
-
-    // updatethe remaining subtasks, while appending results to multitask
-    void update_multitask(int multitask_id, int subtask_id, task_result& result)
-    {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        for (auto& multitask : queue_multitasks)
-        {
-            if (multitask.id == multitask_id)
-            {
-                multitask.subtasks_remaining.erase(subtask_id);
-                multitask.results.push_back(result);
-            }
-        }
-    }
-};
-
-struct llama_server_response {
-    typedef std::function<void(int, int, task_result&)> callback_multitask_t;
-    callback_multitask_t callback_update_multitask;
-    // for keeping track of all tasks waiting for the result
-    std::set<int> waiting_task_ids;
-    // the main result queue
-    std::vector<task_result> queue_results;
-    std::mutex mutex_results;
-    std::condition_variable condition_results;
-
-    // add the task_id to the list of tasks waiting for response
-    void add_waiting_task_id(int task_id) {
-        LOG_VERBOSE("waiting for task id", {{"task_id", task_id}});
-        std::unique_lock<std::mutex> lock(mutex_results);
-        waiting_task_ids.insert(task_id);
-    }
-
-    // when the request is finished, we can remove task associated with it
-    void remove_waiting_task_id(int task_id) {
-        LOG_VERBOSE("remove waiting for task id", {{"task_id", task_id}});
-        std::unique_lock<std::mutex> lock(mutex_results);
-        waiting_task_ids.erase(task_id);
-    }
-
-    // This function blocks the thread until there is a response for this task_id
-    task_result recv(int task_id) {
-        while (true)
-        {
-            std::unique_lock<std::mutex> lock(mutex_results);
-            condition_results.wait(lock, [&]{
-                return !queue_results.empty();
-            });
-
-            for (int i = 0; i < (int) queue_results.size(); i++)
-            {
-                if (queue_results[i].id == task_id)
-                {
-                    assert(queue_results[i].multitask_id == -1);
-                    task_result res = queue_results[i];
-                    queue_results.erase(queue_results.begin() + i);
-                    return res;
-                }
-            }
-        }
-
-        // should never reach here
-    }
-
-    // Register the function to update multitask
-    void on_multitask_update(callback_multitask_t callback) {
-        callback_update_multitask = callback;
-    }
-
-    // Send a new result to a waiting task_id
-    void send(task_result result) {
-        std::unique_lock<std::mutex> lock(mutex_results);
-        LOG_VERBOSE("send new result", {{"task_id", result.id}});
-        for (auto& task_id : waiting_task_ids) {
-            // LOG_TEE("waiting task id %i \n", task_id);
-            // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
-            if (result.multitask_id == task_id)
-            {
-                LOG_VERBOSE("callback_update_multitask", {{"task_id", task_id}});
-                callback_update_multitask(task_id, result.id, result);
-                continue;
-            }
-
-            if (result.id == task_id)
-            {
-                LOG_VERBOSE("queue_results.push_back", {{"task_id", task_id}});
-                queue_results.push_back(result);
-                condition_results.notify_all();
-                return;
-            }
-        }
-    }
-};
-
 //
 // base64 utils (TODO: move to common in the future)
 //
@@ -447,13 +143,11 @@ static const std::string base64_chars =
              "abcdefghijklmnopqrstuvwxyz"
              "0123456789+/";
 
-static inline bool is_base64(uint8_t c)
-{
+static inline bool is_base64(uint8_t c) {
     return (isalnum(c) || (c == '+') || (c == '/'));
 }
 
-static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string)
-{
+static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) {
     int i = 0;
     int j = 0;
     int in_ = 0;
@@ -465,13 +159,10 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
 
     std::vector<uint8_t> ret;
 
-    while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
-    {
+    while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
         char_array_4[i++] = encoded_string[in_]; in_++;
-        if (i == 4)
-        {
-            for (i = 0; i <4; i++)
-            {
+        if (i == 4) {
+            for (i = 0; i < 4; i++) {
                 char_array_4[i] = base64_chars.find(char_array_4[i]);
             }
 
@@ -479,23 +170,20 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
             char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
             char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
 
-            for (i = 0; (i < 3); i++)
-            {
+            for (i = 0; (i < 3); i++) {
                 ret.push_back(char_array_3[i]);
             }
+
             i = 0;
         }
     }
 
-    if (i)
-    {
-        for (j = i; j <4; j++)
-        {
+    if (i) {
+        for (j = i; j < 4; j++) {
             char_array_4[j] = 0;
         }
 
-        for (j = 0; j <4; j++)
-        {
+        for (j = 0; j < 4; j++) {
             char_array_4[j] = base64_chars.find(char_array_4[j]);
         }
 
@@ -503,8 +191,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
         char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
         char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
 
-        for (j = 0; (j < i - 1); j++)
-        {
+        for (j = 0; j < i - 1; j++) {
             ret.push_back(char_array_3[j]);
         }
     }
@@ -516,8 +203,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
 // random string / id
 //
 
-static std::string random_string()
-{
+static std::string random_string() {
     static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
 
     std::random_device rd;
@@ -532,10 +218,10 @@ static std::string random_string()
     return result;
 }
 
-static std::string gen_chatcmplid()
-{
+static std::string gen_chatcmplid() {
     std::stringstream chatcmplid;
     chatcmplid << "chatcmpl-" << random_string();
+
     return chatcmplid.str();
 }
 
@@ -543,91 +229,316 @@ static std::string gen_chatcmplid()
 // other common utils
 //
 
-static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
-{
+static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
     size_t i;
-    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++)
-    {
-    }
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
     return i;
 }
 
 
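// Editor's note (illustrative sketch, not part of the patch): common_part() returns the
// length of the longest shared prefix of two token sequences, which is what the server
// needs when deciding how much of an already evaluated prompt can be reused. With
// arbitrary token ids:
//
//     std::vector<llama_token> cached   = { 1, 15043, 3186, 29991 };
//     std::vector<llama_token> incoming = { 1, 15043, 3186, 1919  };
//     size_t n_reuse = common_part(cached, incoming); // == 3, only the last token differs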
-static bool ends_with(const std::string &str, const std::string &suffix)
-{
-    return str.size() >= suffix.size() &&
-           0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
+static bool ends_with(const std::string & str, const std::string & suffix) {
+    return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }
 
-static size_t find_partial_stop_string(const std::string &stop,
-                                       const std::string &text)
-{
-    if (!text.empty() && !stop.empty())
-    {
+static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
+    if (!text.empty() && !stop.empty()) {
         const char text_last_char = text.back();
-        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--)
-        {
-            if (stop[char_index] == text_last_char)
-            {
+        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
+            if (stop[char_index] == text_last_char) {
                 const std::string current_partial = stop.substr(0, char_index + 1);
-                if (ends_with(text, current_partial))
-                {
+                if (ends_with(text, current_partial)) {
                     return text.size() - char_index - 1;
                 }
             }
         }
     }
+
     return std::string::npos;
 }
 
 
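// Editor's note (illustrative sketch, not part of the patch): find_partial_stop_string()
// returns the offset in `text` at which a prefix of `stop` starts matching the end of the
// generated text, or std::string::npos when there is no partial match; the caller can use
// that offset to hold the matching characters back until it is clear whether the full stop
// string follows. For example:
//
//     size_t pos = find_partial_stop_string("</s>", "Hello </"); // == 6, the index of '<'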
 // TODO: reuse llama_detokenize
 template <class Iter>
-static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
-{
+static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
     std::string ret;
-    for (; begin != end; ++begin)
-    {
+    for (; begin != end; ++begin) {
         ret += llama_token_to_piece(ctx, *begin);
     }
+
     return ret;
 }
 
 // format incomplete utf-8 multibyte character for output
-static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
-{
+static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
     std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
+
     // if the size is 1 and first bit is 1, meaning it's a partial character
     //   (size > 1 meaning it's already a known token)
-    if (out.size() == 1 && (out[0] & 0x80) == 0x80)
-    {
+    if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
         std::stringstream ss;
         ss << std::hex << (out[0] & 0xff);
         std::string res(ss.str());
         out = "byte: \\x" + res;
     }
+
     return out;
 }
 
+struct completion_token_output {
+    llama_token tok;
+    std::string text_to_send;
+
+    struct token_prob {
+        llama_token tok;
+        float prob;
+    };
+
+    std::vector<token_prob> probs;
+};
+
 // convert a vector of completion_token_output to json
-static json probs_vector_to_json(const llama_context *ctx, const std::vector<completion_token_output> &probs)
-{
+static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
     json out = json::array();
-    for (const auto &prob : probs)
-    {
+
+    for (const auto & prob : probs) {
         json probs_for_token = json::array();
-        for (const auto &p : prob.probs)
-        {
-            std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
-            probs_for_token.push_back(json
-            {
+
+        for (const auto & p : prob.probs) {
+            const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
+            probs_for_token.push_back(json {
                 {"tok_str", tok_str},
                 {"prob",    p.prob},
             });
         }
-        std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
-        out.push_back(json{
+
+        const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
+        out.push_back(json {
             {"content", tok_str},
             {"probs",   probs_for_token},
         });
     }
+
     return out;
 }
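// Editor's note (illustrative sketch, not part of the patch): the array built by
// probs_vector_to_json() has one entry per generated token, each with its candidate
// probabilities, roughly of the form
//
//     [
//       {
//         "content": "Hello",
//         "probs": [ { "tok_str": "Hello", "prob": 0.91 }, { "tok_str": "Hi", "prob": 0.04 } ]
//       }
//     ]
//
// where the strings and probabilities here are made-up values.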
+
+//
+// OAI utils
+//
+
+static json oaicompat_completion_params_parse(
+    const struct llama_model * model,
+    const json & body, /* openai api json semantics */
+    const std::string & chat_template) {
+    json llama_params;
+
+    llama_params["__oaicompat"] = true;
+
+    // Map OpenAI parameters to llama.cpp parameters
+    //
+    // For parameters that are defined by the OpenAI documentation (e.g.
+    // temperature), we explicitly specify OpenAI's intended default; we
+    // need to do that because sometimes OpenAI disagrees with llama.cpp
+    //
+    // https://platform.openai.com/docs/api-reference/chat/create
+    llama_sampling_params default_sparams;
+    llama_params["model"]             = json_value(body,   "model",             std::string("unknown"));
+    llama_params["prompt"]            = format_chat(model, chat_template,       body["messages"]);
+    llama_params["cache_prompt"]      = json_value(body,   "cache_prompt",      false);
+    llama_params["temperature"]       = json_value(body,   "temperature",       0.0);
+    llama_params["top_k"]             = json_value(body,   "top_k",             default_sparams.top_k);
+    llama_params["top_p"]             = json_value(body,   "top_p",             1.0);
+    llama_params["n_predict"]         = json_value(body,   "max_tokens",        -1);
+    llama_params["logit_bias"]        = json_value(body,   "logit_bias",        json::object());
+    llama_params["frequency_penalty"] = json_value(body,   "frequency_penalty", 0.0);
+    llama_params["presence_penalty"]  = json_value(body,   "presence_penalty",  0.0);
+    llama_params["seed"]              = json_value(body,   "seed",              LLAMA_DEFAULT_SEED);
+    llama_params["stream"]            = json_value(body,   "stream",            false);
+    llama_params["mirostat"]          = json_value(body,   "mirostat",          default_sparams.mirostat);
+    llama_params["mirostat_tau"]      = json_value(body,   "mirostat_tau",      default_sparams.mirostat_tau);
+    llama_params["mirostat_eta"]      = json_value(body,   "mirostat_eta",      default_sparams.mirostat_eta);
+    llama_params["penalize_nl"]       = json_value(body,   "penalize_nl",       default_sparams.penalize_nl);
+    llama_params["typical_p"]         = json_value(body,   "typical_p",         default_sparams.typical_p);
+    llama_params["repeat_last_n"]     = json_value(body,   "repeat_last_n",     default_sparams.penalty_last_n);
+    llama_params["ignore_eos"]        = json_value(body,   "ignore_eos",        false);
+    llama_params["tfs_z"]             = json_value(body,   "tfs_z",             default_sparams.tfs_z);
+
+    if (body.count("grammar") != 0) {
+        llama_params["grammar"] = json_value(body, "grammar", json::object());
+    }
+
+    // Handle 'stop' field
+    if (body.contains("stop") && body["stop"].is_string()) {
+        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
+    } else {
+        llama_params["stop"] = json_value(body, "stop", json::array());
+    }
+
+    // Ensure there is ChatML-specific end sequence among stop words
+    llama_params["stop"].push_back("<|im_end|>");
+
+    return llama_params;
+}
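// Editor's note (illustrative sketch, not part of the patch): given an OpenAI-style request
// body such as
//
//     { "model": "gpt-3.5-turbo",
//       "messages": [ { "role": "user", "content": "Hello" } ],
//       "max_tokens": 64, "temperature": 0.7, "stop": "###" }
//
// the parser above yields llama.cpp parameters with "prompt" set to the chat-template
// formatted messages, "n_predict" = 64, "temperature" = 0.7 and "stop" = ["###", "<|im_end|>"];
// the model name is arbitrary and is simply echoed back in the response.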
+
+static json format_final_response_oaicompat(const json & request, json result, bool streaming = false) {
+    bool stopped_word        = result.count("stopped_word") != 0;
+    bool stopped_eos         = json_value(result, "stopped_eos", false);
+    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
+    int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
+    std::string content      = json_value(result, "content", std::string(""));
+
+    std::string finish_reason = "length";
+    if (stopped_word || stopped_eos) {
+        finish_reason = "stop";
+    }
+
+    json choices =
+        streaming ? json::array({json{{"finish_reason", finish_reason},
+                                        {"index", 0},
+                                        {"delta", json::object()}}})
+                  : json::array({json{{"finish_reason", finish_reason},
+                                        {"index", 0},
+                                        {"message", json{{"content", content},
+                                                         {"role", "assistant"}}}}});
+
+    std::time_t t = std::time(0);
+
+    json res = json {
+        {"choices", choices},
+        {"created", t},
+        {"model",
+            json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
+        {"usage", json {
+            {"completion_tokens", num_tokens_predicted},
+            {"prompt_tokens",     num_prompt_tokens},
+            {"total_tokens",      num_tokens_predicted + num_prompt_tokens}
+        }},
+        {"id", gen_chatcmplid()}
+    };
+
+    if (server_verbose) {
+        res["__verbose"] = result;
+    }
+
+    if (result.contains("completion_probabilities")) {
+        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
+    }
+
+    return res;
+}
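// Editor's note (illustrative sketch, not part of the patch): for a non-streaming request the
// helper above produces an OpenAI-compatible chat completion object along the lines of
//
//     { "object": "chat.completion",
//       "choices": [ { "index": 0, "finish_reason": "stop",
//                      "message": { "role": "assistant", "content": "Hi there!" } } ],
//       "usage": { "prompt_tokens": 10, "completion_tokens": 3, "total_tokens": 13 },
//       "created": 1709500000, "id": "chatcmpl-..." },
//
// with "finish_reason" set to "stop" when generation ended on a stop word or EOS and to
// "length" otherwise; the content, counts and timestamp are made-up values.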
+
+// return value is vector as there is one case where we might need to generate two responses
+static std::vector<json> format_partial_response_oaicompat(json result) {
+    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
+        return std::vector<json>({result});
+    }
+
+    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
+    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+
+    bool stopped_word   = json_value(result, "stopped_word",  false);
+    bool stopped_eos    = json_value(result, "stopped_eos",   false);
+    bool stopped_limit  = json_value(result, "stopped_limit", false);
+    std::string content = json_value(result, "content",       std::string(""));
+
+    std::string finish_reason;
+    if (stopped_word || stopped_eos) {
+        finish_reason = "stop";
+    }
+    if (stopped_limit) {
+        finish_reason = "length";
+    }
+
+    std::time_t t = std::time(0);
+
+    json choices;
+
+    if (!finish_reason.empty()) {
+        choices = json::array({json{{"finish_reason", finish_reason},
+                                    {"index", 0},
+                                    {"delta", json::object()}}});
+    } else {
+        if (first) {
+            if (content.empty()) {
+                choices = json::array({json{{"finish_reason", nullptr},
+                                            {"index", 0},
+                                            {"delta", json{{"role", "assistant"}}}}});
+            } else {
+                // We have to send this as two updates to conform to openai behavior
+                json initial_ret = json{{"choices", json::array({json{
+                                        {"finish_reason", nullptr},
+                                        {"index", 0},
+                                        {"delta", json{
+                                            {"role", "assistant"}
+                                        }}}})},
+                            {"created", t},
+                            {"id", gen_chatcmplid()},
+                            {"model", modelname},
+                            {"object", "chat.completion.chunk"}};
+
+                json second_ret = json{
+                            {"choices", json::array({json{{"finish_reason", nullptr},
+                                                            {"index", 0},
+                                                            {"delta", json{
+                                                            {"content", content}}}
+                                                            }})},
+                            {"created", t},
+                            {"id", gen_chatcmplid()},
+                            {"model", modelname},
+                            {"object", "chat.completion.chunk"}};
+
+                return std::vector<json>({initial_ret, second_ret});
+            }
+        } else {
+            // Some idiosyncrasy in task processing logic makes several trailing calls
+            // with empty content, we ignore these at the calee site.
+            if (content.empty()) {
+                return std::vector<json>({json::object()});
+            }
+
+            choices = json::array({json{
+                {"finish_reason", nullptr},
+                {"index", 0},
+                {"delta",
+                json{
+                    {"content", content},
+                }},
+            }});
+        }
+    }
+
+    json ret = json {
+        {"choices", choices},
+        {"created", t},
+        {"id",      gen_chatcmplid()},
+        {"model",   modelname},
+        {"object",  "chat.completion.chunk"}
+    };
+
+    return std::vector<json>({ret});
+}
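// Editor's note (illustrative sketch, not part of the patch): when the very first token of a
// reply already carries content, the function above returns two chunks, one announcing the
// assistant role and one carrying the text, to match OpenAI's streaming behavior:
//
//     { "object": "chat.completion.chunk",
//       "choices": [ { "index": 0, "finish_reason": null, "delta": { "role": "assistant" } } ], ... }
//     { "object": "chat.completion.chunk",
//       "choices": [ { "index": 0, "finish_reason": null, "delta": { "content": "Hello" } } ], ... }
//
// where "Hello" stands in for the actual first piece of generated text and "..." elides the
// "created", "id" and "model" fields.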
+
+static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
+    json res = json {
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", "list"},
+        {"usage", json {
+            {"prompt_tokens", 0},
+            {"total_tokens", 0}
+        }},
+        {"data", embeddings}
+    };
+
+    return res;
+}
+
+static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
+    return json {
+        {"tokens", tokens}
+    };
+}
+
+static json format_detokenized_response(const std::string & content) {
+    return json {
+        {"content", content}
+    };
+}

+ 5 - 1
llama.cpp

@@ -13541,18 +13541,22 @@ LLAMA_API int32_t llama_chat_apply_template(
             curr_tmpl = std::string(model_template.data(), model_template.size());
         }
     }
+
     // format the chat to string
     std::vector<const llama_chat_message *> chat_vec;
     chat_vec.resize(n_msg);
     for (size_t i = 0; i < n_msg; i++) {
         chat_vec[i] = &chat[i];
     }
+
     std::string formatted_chat;
     int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }
-    strncpy(buf, formatted_chat.c_str(), length);
+    if (buf && length > 0) {
+        strncpy(buf, formatted_chat.c_str(), length);
+    }
     return res;
 }
 
 
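The llama.cpp change above makes it safe to call llama_chat_apply_template with a null or
zero-length buffer, so a caller can first query how much space the formatted chat needs and
only then allocate it. A minimal two-pass sketch (illustrative, assuming an in-scope
llama_model * model, the llama.h argument order (model, tmpl, chat, n_msg, add_ass, buf,
length), and that the return value is the length of the formatted chat):

    std::vector<llama_chat_message> chat = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello!"                       },
    };

    // first pass: null buffer, nothing is written, only the required length is returned
    const int32_t n = llama_chat_apply_template(model, nullptr, chat.data(), chat.size(), true, nullptr, 0);
    if (n >= 0) {
        std::vector<char> buf(n + 1);
        // second pass: format into a buffer of the right size
        // (tmpl == nullptr means "use the template stored in the model")
        llama_chat_apply_template(model, nullptr, chat.data(), chat.size(), true, buf.data(), (int32_t) buf.size());
        const std::string formatted(buf.data(), n);
    }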

Some files were not shown because too many files changed in this diff