Преглед изворни кода

server : refactor (#5882)

* server : refactoring (wip)

* server : remove llava/clip objects from build

* server : fix empty prompt handling + all slots idle logic

* server : normalize id vars

* server : code style

* server : simplify model chat template validation

* server : code style

* server : minor

* llama : llama_chat_apply_template support null buf

* server : do not process embedding requests when disabled

* server : reorganize structs and enums + naming fixes

* server : merge oai.hpp in utils.hpp

* server : refactor system prompt update at start

* server : disable cached prompts with self-extend

* server : do not process more than n_batch tokens per iter

* server: tests: embeddings use a real embeddings model (#5908)

* server, tests : bump batch to fit 1 embedding prompt

* server: tests: embeddings fix build type Debug is randomly failing (#5911)

* server: tests: embeddings, use different KV Cache size

* server: tests: embeddings, fixed prompt do not exceed n_batch, increase embedding timeout, reduce number of concurrent embeddings

* server: tests: embeddings, no need to wait for server idle as it can timout

* server: refactor: clean up http code (#5912)

* server : avoid n_available var

ggml-ci

* server: refactor: better http codes

* server : simplify json parsing + add comment about t_last

* server : rename server structs

* server : allow to override FQDN in tests

ggml-ci

* server : add comments

---------

Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com>
Georgi Gerganov пре 1 година
родитељ
комит
2002bc96bf

+ 2 - 1
.github/workflows/server.yml

@@ -58,7 +58,8 @@ jobs:
             cmake \
             python3-pip \
             wget \
-            psmisc
+            psmisc \
+            language-pack-en
 
       - name: Build
         id: cmake_build

+ 2 - 3
Makefile

@@ -724,10 +724,9 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h examples/llava/llava.h examples/llava/llava.cpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
-	$(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
-	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h %.hpp $< examples/llava/clip.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) -o $@ $(LDFLAGS) $(LWINSOCK2)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
 
 gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

+ 1 - 1
examples/server-embd.py

@@ -13,7 +13,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(i)*1024}
+        json= {"content": str(0)*1024}
     ) for i in range(n)])
 
     for response in responses:

+ 2 - 2
examples/server/CMakeLists.txt

@@ -1,12 +1,12 @@
 set(TARGET server)
 option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
-add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp httplib.h)
+add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
 )
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
 if (WIN32)
     TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
 endif()

+ 1 - 1
examples/server/README.md

@@ -436,7 +436,7 @@ Notice that each `probs` is an array of length `n_probs`.
         "next_token": {
             "has_next_token": true,
             "n_remain": -1,
-            "num_tokens_predicted": 0,
+            "n_decoded": 0,
             "stopped_eos": false,
             "stopped_limit": false,
             "stopped_word": false,

+ 0 - 225
examples/server/oai.hpp

@@ -1,225 +0,0 @@
-#pragma once
-
-#include <string>
-#include <vector>
-#include <set>
-#include <mutex>
-#include <condition_variable>
-#include <unordered_map>
-
-#include "json.hpp"
-#include "utils.hpp"
-
-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
-
-using json = nlohmann::json;
-
-inline static json oaicompat_completion_params_parse(
-    const struct llama_model * model,
-    const json &body, /* openai api json semantics */
-    const std::string &chat_template)
-{
-    json llama_params;
-
-    llama_params["__oaicompat"] = true;
-
-    // Map OpenAI parameters to llama.cpp parameters
-    //
-    // For parameters that are defined by the OpenAI documentation (e.g.
-    // temperature), we explicitly specify OpenAI's intended default; we
-    // need to do that because sometimes OpenAI disagrees with llama.cpp
-    //
-    // https://platform.openai.com/docs/api-reference/chat/create
-    llama_sampling_params default_sparams;
-    llama_params["model"]             = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"]            = format_chat(model, chat_template, body["messages"]);
-    llama_params["cache_prompt"]      = json_value(body, "cache_prompt", false);
-    llama_params["temperature"]       = json_value(body, "temperature", 0.0);
-    llama_params["top_k"]             = json_value(body, "top_k", default_sparams.top_k);
-    llama_params["top_p"]             = json_value(body, "top_p", 1.0);
-    llama_params["n_predict"]         = json_value(body, "max_tokens", -1);
-    llama_params["logit_bias"]        = json_value(body, "logit_bias",json::object());
-    llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
-    llama_params["presence_penalty"]  = json_value(body, "presence_penalty", 0.0);
-    llama_params["seed"]              = json_value(body, "seed", LLAMA_DEFAULT_SEED);
-    llama_params["stream"]            = json_value(body, "stream", false);
-    llama_params["mirostat"]          = json_value(body, "mirostat", default_sparams.mirostat);
-    llama_params["mirostat_tau"]      = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
-    llama_params["mirostat_eta"]      = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
-    llama_params["penalize_nl"]       = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama_params["typical_p"]         = json_value(body, "typical_p", default_sparams.typical_p);
-    llama_params["repeat_last_n"]     = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
-    llama_params["ignore_eos"]        = json_value(body, "ignore_eos", false);
-    llama_params["tfs_z"]             = json_value(body, "tfs_z", default_sparams.tfs_z);
-
-    if (body.count("grammar") != 0) {
-        llama_params["grammar"] = json_value(body, "grammar", json::object());
-    }
-
-    // Handle 'stop' field
-    if (body.contains("stop") && body["stop"].is_string()) {
-        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
-    } else {
-        llama_params["stop"] = json_value(body, "stop", json::array());
-    }
-
-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
-
-    return llama_params;
-}
-
-inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
-{
-    json result = response.result_json;
-
-    bool stopped_word        = result.count("stopped_word") != 0;
-    bool stopped_eos         = json_value(result, "stopped_eos", false);
-    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-    int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
-    std::string content      = json_value(result, "content", std::string(""));
-
-    std::string finish_reason = "length";
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-
-    json choices =
-        streaming ? json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"delta", json::object()}}})
-                  : json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"message", json{{"content", content},
-                                                         {"role", "assistant"}}}}});
-
-    std::time_t t = std::time(0);
-
-    json res =
-        json{{"choices", choices},
-            {"created", t},
-            {"model",
-                json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-            {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
-            {"usage",
-                json{{"completion_tokens", num_tokens_predicted},
-                     {"prompt_tokens",     num_prompt_tokens},
-                     {"total_tokens",      num_tokens_predicted + num_prompt_tokens}}},
-            {"id", gen_chatcmplid()}};
-
-    if (server_verbose) {
-        res["__verbose"] = result;
-    }
-
-    if (result.contains("completion_probabilities")) {
-        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-    }
-
-    return res;
-}
-
-// return value is vector as there is one case where we might need to generate two responses
-inline static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
-    json result = response.result_json;
-
-    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
-        return std::vector<json>({response.result_json});
-    }
-
-    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
-    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-
-    bool stopped_word   = json_value(result, "stopped_word", false);
-    bool stopped_eos    = json_value(result, "stopped_eos", false);
-    bool stopped_limit  = json_value(result, "stopped_limit", false);
-    std::string content = json_value(result, "content", std::string(""));
-
-    std::string finish_reason;
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-    if (stopped_limit) {
-        finish_reason = "length";
-    }
-
-    std::time_t t = std::time(0);
-
-    json choices;
-
-    if (!finish_reason.empty()) {
-        choices = json::array({json{{"finish_reason", finish_reason},
-                                    {"index", 0},
-                                    {"delta", json::object()}}});
-    } else {
-        if (first) {
-            if (content.empty()) {
-                choices = json::array({json{{"finish_reason", nullptr},
-                                            {"index", 0},
-                                            {"delta", json{{"role", "assistant"}}}}});
-            } else {
-                // We have to send this as two updates to conform to openai behavior
-                json initial_ret = json{{"choices", json::array({json{
-                                        {"finish_reason", nullptr},
-                                        {"index", 0},
-                                        {"delta", json{
-                                            {"role", "assistant"}
-                                        }}}})},
-                            {"created", t},
-                            {"id", gen_chatcmplid()},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                json second_ret = json{
-                            {"choices", json::array({json{{"finish_reason", nullptr},
-                                                            {"index", 0},
-                                                            {"delta", json{
-                                                            {"content", content}}}
-                                                            }})},
-                            {"created", t},
-                            {"id", gen_chatcmplid()},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                return std::vector<json>({initial_ret, second_ret});
-            }
-        } else {
-            // Some idiosyncrasy in task processing logic makes several trailing calls
-            // with empty content, we ignore these at the calee site.
-            if (content.empty()) {
-                return std::vector<json>({json::object()});
-            }
-
-            choices = json::array({json{
-                {"finish_reason", nullptr},
-                {"index", 0},
-                {"delta",
-                json{
-                    {"content", content},
-                }},
-            }});
-        }
-    }
-
-    json ret = json{{"choices", choices},
-                    {"created", t},
-                    {"id", gen_chatcmplid()},
-                    {"model", modelname},
-                    {"object", "chat.completion.chunk"}};
-
-    return std::vector<json>({ret});
-}
-
-inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings)
-{
-    json res =
-        json{
-            {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-            {"object", "list"},
-            {"usage",
-                json{{"prompt_tokens", 0},
-                     {"total_tokens", 0}}},
-            {"data", embeddings}
-        };
-    return res;
-}
-

Разлика између датотеке није приказан због своје велике величине
+ 576 - 424
examples/server/server.cpp


+ 94 - 0
examples/server/tests/features/embeddings.feature

@@ -0,0 +1,94 @@
+@llama.cpp
+@embeddings
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file bert-bge-small/ggml-model-f16.gguf from HF repo ggml-org/models
+    And   a model alias bert-bge-small
+    And   42 as server seed
+    And   2 slots
+    And   1024 as batch size
+    And   2048 KV cache size
+    And   embeddings extraction
+    Then  the server is starting
+    Then  the server is healthy
+
+  Scenario: Embedding
+    When embeddings are computed for:
+    """
+    What is the capital of Bulgaria ?
+    """
+    Then embeddings are generated
+
+  Scenario: OAI Embeddings compatibility
+    Given a model bert-bge-small
+    When an OAI compatible embeddings computation request for:
+    """
+    What is the capital of Spain ?
+    """
+    Then embeddings are generated
+
+  Scenario: OAI Embeddings compatibility with multiple inputs
+    Given a model bert-bge-small
+    Given a prompt:
+      """
+      In which country Paris is located ?
+      """
+    And a prompt:
+      """
+      Is Madrid the capital of Spain ?
+      """
+    When an OAI compatible embeddings computation request for multiple inputs
+    Then embeddings are generated
+
+  Scenario: Multi users embeddings
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write a very long poem.
+      """
+    And a prompt:
+      """
+      Write a very long joke.
+      """
+    Given concurrent embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated
+
+  Scenario: Multi users OAI compatibility embeddings
+    Given a prompt:
+      """
+      In which country Paris is located ?
+      """
+    And a prompt:
+      """
+      Is Madrid the capital of Spain ?
+      """
+    And a prompt:
+      """
+      What is the biggest US city ?
+      """
+    And a prompt:
+      """
+      What is the capital of Bulgaria ?
+      """
+    And   a model bert-bge-small
+    Given concurrent OAI embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated
+
+  Scenario: All embeddings should be the same
+    Given 10 fixed prompts
+    And   a model bert-bge-small
+    Given concurrent OAI embedding requests
+    Then all embeddings are the same

+ 0 - 46
examples/server/tests/features/parallel.feature

@@ -9,7 +9,6 @@ Feature: Parallel
     And   512 as batch size
     And   64 KV cache size
     And   2 slots
-    And   embeddings extraction
     And   continuous batching
     Then  the server is starting
     Then  the server is healthy
@@ -99,48 +98,3 @@ Feature: Parallel
     Then the server is busy
     Then the server is idle
     Then all prompts are predicted
-
-  Scenario: Multi users embeddings
-    Given a prompt:
-      """
-      Write a very long story about AI.
-      """
-    And a prompt:
-      """
-      Write another very long music lyrics.
-      """
-    And a prompt:
-      """
-      Write a very long poem.
-      """
-    And a prompt:
-      """
-      Write a very long joke.
-      """
-    Given concurrent embedding requests
-    Then the server is busy
-    Then the server is idle
-    Then all embeddings are generated
-
-  Scenario: Multi users OAI compatibility embeddings
-    Given a prompt:
-      """
-      In which country Paris is located ?
-      """
-    And a prompt:
-      """
-      Is Madrid the capital of Spain ?
-      """
-    And a prompt:
-      """
-      What is the biggest US city ?
-      """
-    And a prompt:
-      """
-      What is the capital of Bulgaria ?
-      """
-    And   a model tinyllama-2
-    Given concurrent OAI embedding requests
-    Then the server is busy
-    Then the server is idle
-    Then all embeddings are generated

+ 0 - 28
examples/server/tests/features/server.feature

@@ -49,34 +49,6 @@ Feature: llama.cpp server
       | llama-2      | Book                        | What is the best book                | 8          | (Mom\|what)+           | 8           | disabled         |
       | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64         | (thanks\|happy\|bird)+ | 32          | enabled          |
 
-  Scenario: Embedding
-    When embeddings are computed for:
-    """
-    What is the capital of Bulgaria ?
-    """
-    Then embeddings are generated
-
-  Scenario: OAI Embeddings compatibility
-    Given a model tinyllama-2
-    When an OAI compatible embeddings computation request for:
-    """
-    What is the capital of Spain ?
-    """
-    Then embeddings are generated
-
-  Scenario: OAI Embeddings compatibility with multiple inputs
-    Given a model tinyllama-2
-    Given a prompt:
-      """
-      In which country Paris is located ?
-      """
-    And a prompt:
-      """
-      Is Madrid the capital of Spain ?
-      """
-    When an OAI compatible embeddings computation request for multiple inputs
-    Then embeddings are generated
-
   Scenario: Tokenize / Detokenize
     When tokenizing:
     """

+ 67 - 19
examples/server/tests/features/steps/steps.py

@@ -10,6 +10,7 @@ from contextlib import closing
 from re import RegexFlag
 
 import aiohttp
+import numpy as np
 import openai
 from behave import step
 from behave.api.async_step import async_run_until_complete
@@ -24,6 +25,9 @@ def step_server_config(context, server_fqdn, server_port):
     if 'PORT' in os.environ:
         context.server_port = int(os.environ['PORT'])
         print(f"$PORT set, overriding server port with to {context.server_port}")
+    if 'FQDN' in os.environ:
+        context.server_fqdn = os.environ['FQDN']
+        print(f"$FQDN set, overriding server fqdn with to {context.server_fqdn}")
 
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
 
@@ -34,6 +38,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.n_ga_w = None
     context.n_gpu_layer = None
     context.n_predict = None
+    context.n_prompts = 0
     context.n_server_predict = None
     context.n_slots = None
     context.prompt_prefix = None
@@ -202,6 +207,7 @@ def step_n_tokens_predicted(context, predicted_n):
 @step(u'a user prompt {user_prompt}')
 def step_user_prompt(context, user_prompt):
     context.prompts.append(user_prompt)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'a system prompt {system_prompt}')
@@ -290,6 +296,12 @@ def step_prompt_passkey(context):
     context.prompt_passkey = context.text
 
 
+@step(u'{n_prompts:d} fixed prompts')
+def step_fixed_prompts(context, n_prompts):
+    context.prompts.extend([str(0)*(context.n_batch if context.n_batch is not None else 512) for i in range(n_prompts)])
+    context.n_prompts = n_prompts
+
+
 @step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
 def step_prompt_passkey(context, passkey, i_pos):
     prompt = ""
@@ -301,6 +313,7 @@ def step_prompt_passkey(context, passkey, i_pos):
         passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
         print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
     context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'an OAI compatible chat completions request with {api_error} api error')
@@ -341,11 +354,13 @@ async def step_oai_chat_completions(context, api_error):
 @step(u'a prompt')
 def step_a_prompt(context):
     context.prompts.append(context.text)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'a prompt {prompt}')
 def step_a_prompt_prompt(context, prompt):
     context.prompts.append(prompt)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'concurrent completion requests')
@@ -430,25 +445,47 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
+    context.n_prompts = 1
     context.embeddings = await request_embedding(context.text, base_url=context.base_url)
 
 
+@step(u'all embeddings are the same')
+@async_run_until_complete
+async def step_all_embeddings_are_the_same(context):
+    n_embedding_requests = await gather_tasks_results(context)
+    assert n_embedding_requests > 0
+    embeddings = []
+    for i in range(n_embedding_requests):
+        embedding = context.tasks_result.pop().pop()
+        embeddings.append(embedding)
+        assert_embeddings(embedding)
+    n = len(embeddings)
+    for i in range(n-1):
+        for j in range(i+1, n):
+            embedding1 = np.array(embeddings[i])
+            embedding2 = np.array(embeddings[j])
+            if context.debug:
+                print(f"embedding1: {embedding1[-8:]}\n")
+                print(f"embedding2: {embedding2[-8:]}\n")
+            similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+            msg = f"Similarity between {i} and {j}: {similarity:.10f}"
+            if context.debug:
+                print(f"{msg}\n")
+            assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
+
 @step(u'embeddings are generated')
 def step_assert_embeddings(context):
-    if len(context.prompts) == 0:
-        assert_embeddings(context.embeddings)
-    else:
-        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
-                                                                 f"context.prompts={context.prompts}\n"
-                                                                 f"context.embeddings={context.embeddings}")
-        for embedding in context.embeddings:
-            context.prompts.pop()
-            assert_embeddings(embedding)
+    assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n"
+                                                             f"context.n_prompts={context.n_prompts}\n"
+                                                             f"context.embeddings={context.embeddings}")
+    for embedding in context.embeddings:
+        assert_embeddings(embedding)
 
 
 @step(u'an OAI compatible embeddings computation request for')
 @async_run_until_complete
 async def step_oai_compute_embeddings(context):
+    context.n_prompts = 1
     context.embeddings = await request_oai_embeddings(context.text,
                                                       base_url=context.base_url,
                                                       user_api_key=context.user_api_key,
@@ -462,6 +499,7 @@ async def step_oai_compute_embeddings_multiple_inputs(context):
                                                       base_url=context.base_url,
                                                       user_api_key=context.user_api_key,
                                                       model=context.model)
+    context.prompts.clear()
 
 
 @step(u'concurrent embedding requests')
@@ -488,9 +526,9 @@ async def step_concurrent_oai_embedding_requests(context):
 @async_run_until_complete()
 async def all_embeddings_are_generated(context):
     n_embedding_requests = await gather_tasks_results(context)
-    assert n_embedding_requests > 0
+    assert n_embedding_requests == context.n_prompts
     for i in range(n_embedding_requests):
-        assert_embeddings(context.tasks_result.pop())
+        assert_embeddings(context.tasks_result.pop().pop())
 
 
 @step(u'tokenizing')
@@ -588,11 +626,11 @@ def step_supported_models(context, i_model, param, preposition, param_value):
 
 
 async def concurrent_requests(context, f_completion, *args, **kwargs):
-    n_prompts = len(context.prompts)
+    context.n_prompts = len(context.prompts)
     if context.debug:
-        print(f"starting {n_prompts} concurrent completion requests...")
-    assert n_prompts > 0
-    for prompt_no in range(n_prompts):
+        print(f"starting {context.n_prompts} concurrent completion requests...")
+    assert context.n_prompts > 0
+    for prompt_no in range(context.n_prompts):
         shifted_args = [context.prompts.pop(), *args]
         context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
     await asyncio.sleep(0.1)
@@ -765,7 +803,7 @@ async def request_embedding(content, base_url=None):
                                 }) as response:
             assert response.status == 200
             response_json = await response.json()
-            return response_json['embedding']
+            return [response_json['embedding']]
 
 
 async def request_oai_embeddings(input,
@@ -775,6 +813,7 @@ async def request_oai_embeddings(input,
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     if async_client:
         origin = 'llama.cpp'
+        headers=[]
         if user_api_key is not None:
             headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
@@ -783,14 +822,21 @@ async def request_oai_embeddings(input,
                                         "input": input,
                                         "model": model,
                                     },
-                                    headers=headers) as response:
+                                    headers=headers,
+                                    timeout=3600) as response:
                 assert response.status == 200, f"received status code not expected: {response.status}"
                 assert response.headers['Access-Control-Allow-Origin'] == origin
                 assert response.headers['Content-Type'] == "application/json; charset=utf-8"
                 response_json = await response.json()
                 assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
                 assert response_json['object'] == 'list'
-                return response_json['data']
+                if isinstance(input, collections.abc.Sequence):
+                    embeddings = []
+                    for an_oai_embeddings in response_json['data']:
+                        embeddings.append(an_oai_embeddings['embedding'])
+                else:
+                    embeddings = [response_json['data']['embedding']]
+                return embeddings
     else:
         openai.api_key = user_api_key
         openai.api_base = f'{base_url}/v1'
@@ -804,7 +850,7 @@ async def request_oai_embeddings(input,
             for an_oai_embeddings in oai_embeddings.data:
                 embeddings.append(an_oai_embeddings.embedding)
         else:
-            embeddings = oai_embeddings.data.embedding
+            embeddings = [oai_embeddings.data.embedding]
         return embeddings
 
 
@@ -899,6 +945,8 @@ def assert_embeddings(embeddings):
     assert len(embeddings) > 0
     embeddings_computed = False
     for emb in embeddings:
+        if not isinstance(emb, float):
+            assert False, f"Bad embeddings: {embeddings}"
         if emb != 0:
             embeddings_computed = True
     assert embeddings_computed, f"Embeddings: {embeddings}"

+ 1 - 0
examples/server/tests/requirements.txt

@@ -1,5 +1,6 @@
 aiohttp~=3.9.3
 behave~=1.2.6
 huggingface_hub~=0.20.3
+numpy~=1.24.4
 openai~=0.25.0
 prometheus-client~=0.20.0

+ 307 - 396
examples/server/utils.hpp

@@ -1,15 +1,16 @@
 #pragma once
 
-#include <string>
-#include <vector>
-#include <set>
-#include <mutex>
-#include <condition_variable>
-#include <unordered_map>
+#include "llama.h"
+#include "common.h"
 
 #include "json.hpp"
 
-#include "../llava/clip.h"
+#include <string>
+#include <vector>
+#include <sstream>
+#include <random>
+
+#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
 using json = nlohmann::json;
 
@@ -37,83 +38,35 @@ extern bool server_log_json;
 #define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
 #define LOG_INFO(   MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
 
-enum server_state {
-    SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
-    SERVER_STATE_READY,          // Server is ready and model is loaded
-    SERVER_STATE_ERROR           // An error occurred, load_model failed
-};
-
-enum task_type {
-    TASK_TYPE_COMPLETION,
-    TASK_TYPE_CANCEL,
-    TASK_TYPE_NEXT_RESPONSE,
-    TASK_TYPE_METRICS
-};
-
-struct task_server {
-    int id = -1; // to be filled by llama_server_queue
-    int target_id;
-    task_type type;
-    json data;
-    bool infill_mode = false;
-    bool embedding_mode = false;
-    int multitask_id = -1;
-};
-
-struct task_result {
-    int id;
-    int multitask_id = -1;
-    bool stop;
-    bool error;
-    json result_json;
-};
-
-struct task_multi {
-    int id;
-    std::set<int> subtasks_remaining{};
-    std::vector<task_result> results{};
-};
-
-// completion token output with probabilities
-struct completion_token_output {
-    struct token_prob
-    {
-        llama_token tok;
-        float prob;
-    };
-
-    std::vector<token_prob> probs;
-    llama_token tok;
-    std::string text_to_send;
-};
-
-struct token_translator {
-    llama_context * ctx;
-    std::string operator()(llama_token tok)                    const { return llama_token_to_piece(ctx, tok); }
-    std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); }
-};
+template <typename T>
+static T json_value(const json &body, const std::string &key, const T &default_value) {
+    // Fallback null to default value
+    return body.contains(key) && !body.at(key).is_null()
+        ? body.value(key, default_value)
+        : default_value;
+}
 
 static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) {
     std::stringstream ss_tid;
     ss_tid << std::this_thread::get_id();
     json log = nlohmann::ordered_json{
-        {"tid", ss_tid.str()},
+        {"tid",       ss_tid.str()},
         {"timestamp", time(nullptr)},
     };
 
     if (server_log_json) {
-        log.merge_patch(
-                {
-                        {"level",     level},
-                        {"function",  function},
-                        {"line",      line},
-                        {"msg",       message},
-                });
+        log.merge_patch( {
+            {"level",    level},
+            {"function", function},
+            {"line",     line},
+            {"msg",      message},
+        });
+
         if (!extra.empty()) {
             log.merge_patch(extra);
         }
 
-        std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush;
+        printf("%s\n", log.dump(-1, ' ', false, json::error_handler_t::replace).c_str());
     } else {
         char buf[1024];
         snprintf(buf, 1024, "%4s [%24s] %s", level, function, message);
@@ -136,22 +89,13 @@ static inline void server_log(const char *level, const char *function, int line,
 }
 
 //
-// server utils
+// chat template utils
 //
 
-template <typename T>
-static T json_value(const json &body, const std::string &key, const T &default_value) {
-    // Fallback null to default value
-    return body.contains(key) && !body.at(key).is_null()
-        ? body.value(key, default_value)
-        : default_value;
-}
-
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 inline bool verify_custom_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    std::vector<char> buf(1);
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
@@ -163,7 +107,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     std::vector<llama_chat_message> chat(messages.size());
 
     for (size_t i = 0; i < messages.size(); ++i) {
-        auto &curr_msg = messages[i];
+        const auto & curr_msg = messages[i];
         str[i*2 + 0]    = json_value(curr_msg, "role",    std::string(""));
         str[i*2 + 1]    = json_value(curr_msg, "content", std::string(""));
         alloc_size     += str[i*2 + 1].length();
@@ -183,261 +127,13 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
     }
 
-    std::string formatted_chat(buf.data(), res);
+    const std::string formatted_chat(buf.data(), res);
+
     LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
 
     return formatted_chat;
 }
 
-//
-// work queue utils
-//
-
-struct llama_server_queue {
-    int id = 0;
-    std::mutex mutex_tasks;
-    bool running;
-    // queues
-    std::vector<task_server> queue_tasks;
-    std::vector<task_server> queue_tasks_deferred;
-    std::vector<task_multi> queue_multitasks;
-    std::condition_variable condition_tasks;
-    // callback functions
-    std::function<void(task_server&)> callback_new_task;
-    std::function<void(task_multi&)> callback_finish_multitask;
-    std::function<void(void)> callback_run_slots;
-
-    // Add a new task to the end of the queue
-    int post(task_server task) {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
-        if (task.id == -1) {
-            task.id = id++;
-            LOG_VERBOSE("new task id", {{"new_id", task.id}});
-        }
-        queue_tasks.push_back(std::move(task));
-        condition_tasks.notify_one();
-        return task.id;
-    }
-
-    // Add a new task, but defer until one slot is available
-    void defer(task_server task) {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
-        queue_tasks_deferred.push_back(std::move(task));
-    }
-
-    // Get the next id for creating anew task
-    int get_new_id() {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
-        int new_id = id++;
-        LOG_VERBOSE("new task id", {{"new_id", new_id}});
-        return new_id;
-    }
-
-    // Register function to process a new task
-    void on_new_task(std::function<void(task_server&)> callback) {
-        callback_new_task = callback;
-    }
-
-    // Register function to process a multitask when it is finished
-    void on_finish_multitask(std::function<void(task_multi&)> callback) {
-        callback_finish_multitask = callback;
-    }
-
-    // Register the function to be called when all slots data is ready to be processed
-    void on_run_slots(std::function<void(void)> callback) {
-        callback_run_slots = callback;
-    }
-
-    // Call when the state of one slot is changed
-    void notify_slot_changed() {
-        // move deferred tasks back to main loop
-        std::unique_lock<std::mutex> lock(mutex_tasks);
-        for (auto & task : queue_tasks_deferred) {
-            queue_tasks.push_back(std::move(task));
-        }
-        queue_tasks_deferred.clear();
-    }
-
-    // end the start_loop routine
-    void terminate() {
-        {
-            std::unique_lock<std::mutex> lock(mutex_tasks);
-            running = false;
-        }
-        condition_tasks.notify_all();
-    }
-
-    /**
-     * Main loop consists of these steps:
-     * - Wait until a new task arrives
-     * - Process the task (i.e. maybe copy data into slot)
-     * - Check if multitask is finished
-     * - Run all slots
-     */
-    void start_loop() {
-        running = true;
-        while (true) {
-            LOG_VERBOSE("new task may arrive", {});
-            {
-                while (true)
-                {
-                    std::unique_lock<std::mutex> lock(mutex_tasks);
-                    if (queue_tasks.empty()) {
-                        lock.unlock();
-                        break;
-                    }
-                    task_server task = queue_tasks.front();
-                    queue_tasks.erase(queue_tasks.begin());
-                    lock.unlock();
-                    LOG_VERBOSE("callback_new_task", {{"task_id", task.id}});
-                    callback_new_task(task);
-                }
-                LOG_VERBOSE("update_multitasks", {});
-                // check if we have any finished multitasks
-                auto queue_iterator = queue_multitasks.begin();
-                while (queue_iterator != queue_multitasks.end())
-                {
-                    if (queue_iterator->subtasks_remaining.empty())
-                    {
-                        // all subtasks done == multitask is done
-                        task_multi current_multitask = *queue_iterator;
-                        callback_finish_multitask(current_multitask);
-                        // remove this multitask
-                        queue_iterator = queue_multitasks.erase(queue_iterator);
-                    }
-                    else
-                    {
-                        ++queue_iterator;
-                    }
-                }
-                // all tasks in the current loop is processed, slots data is now ready
-                LOG_VERBOSE("callback_run_slots", {});
-                callback_run_slots();
-            }
-            LOG_VERBOSE("wait for new task", {});
-            // wait for new task
-            {
-                std::unique_lock<std::mutex> lock(mutex_tasks);
-                if (queue_tasks.empty()) {
-                    if (!running) {
-                        LOG_VERBOSE("ending start_loop", {});
-                        return;
-                    }
-                    condition_tasks.wait(lock, [&]{
-                        return (!queue_tasks.empty() || !running);
-                    });
-                }
-            }
-        }
-    }
-
-    //
-    // functions to manage multitasks
-    //
-
-    // add a multitask by specifying the id of all subtask (subtask is a task_server)
-    void add_multitask(int multitask_id, std::vector<int>& sub_ids)
-    {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        task_multi multi;
-        multi.id = multitask_id;
-        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
-        queue_multitasks.push_back(multi);
-    }
-
-    // updatethe remaining subtasks, while appending results to multitask
-    void update_multitask(int multitask_id, int subtask_id, task_result& result)
-    {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        for (auto& multitask : queue_multitasks)
-        {
-            if (multitask.id == multitask_id)
-            {
-                multitask.subtasks_remaining.erase(subtask_id);
-                multitask.results.push_back(result);
-            }
-        }
-    }
-};
-
-struct llama_server_response {
-    typedef std::function<void(int, int, task_result&)> callback_multitask_t;
-    callback_multitask_t callback_update_multitask;
-    // for keeping track of all tasks waiting for the result
-    std::set<int> waiting_task_ids;
-    // the main result queue
-    std::vector<task_result> queue_results;
-    std::mutex mutex_results;
-    std::condition_variable condition_results;
-
-    // add the task_id to the list of tasks waiting for response
-    void add_waiting_task_id(int task_id) {
-        LOG_VERBOSE("waiting for task id", {{"task_id", task_id}});
-        std::unique_lock<std::mutex> lock(mutex_results);
-        waiting_task_ids.insert(task_id);
-    }
-
-    // when the request is finished, we can remove task associated with it
-    void remove_waiting_task_id(int task_id) {
-        LOG_VERBOSE("remove waiting for task id", {{"task_id", task_id}});
-        std::unique_lock<std::mutex> lock(mutex_results);
-        waiting_task_ids.erase(task_id);
-    }
-
-    // This function blocks the thread until there is a response for this task_id
-    task_result recv(int task_id) {
-        while (true)
-        {
-            std::unique_lock<std::mutex> lock(mutex_results);
-            condition_results.wait(lock, [&]{
-                return !queue_results.empty();
-            });
-
-            for (int i = 0; i < (int) queue_results.size(); i++)
-            {
-                if (queue_results[i].id == task_id)
-                {
-                    assert(queue_results[i].multitask_id == -1);
-                    task_result res = queue_results[i];
-                    queue_results.erase(queue_results.begin() + i);
-                    return res;
-                }
-            }
-        }
-
-        // should never reach here
-    }
-
-    // Register the function to update multitask
-    void on_multitask_update(callback_multitask_t callback) {
-        callback_update_multitask = callback;
-    }
-
-    // Send a new result to a waiting task_id
-    void send(task_result result) {
-        std::unique_lock<std::mutex> lock(mutex_results);
-        LOG_VERBOSE("send new result", {{"task_id", result.id}});
-        for (auto& task_id : waiting_task_ids) {
-            // LOG_TEE("waiting task id %i \n", task_id);
-            // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
-            if (result.multitask_id == task_id)
-            {
-                LOG_VERBOSE("callback_update_multitask", {{"task_id", task_id}});
-                callback_update_multitask(task_id, result.id, result);
-                continue;
-            }
-
-            if (result.id == task_id)
-            {
-                LOG_VERBOSE("queue_results.push_back", {{"task_id", task_id}});
-                queue_results.push_back(result);
-                condition_results.notify_all();
-                return;
-            }
-        }
-    }
-};
-
 //
 // base64 utils (TODO: move to common in the future)
 //
@@ -447,13 +143,11 @@ static const std::string base64_chars =
              "abcdefghijklmnopqrstuvwxyz"
              "0123456789+/";
 
-static inline bool is_base64(uint8_t c)
-{
+static inline bool is_base64(uint8_t c) {
     return (isalnum(c) || (c == '+') || (c == '/'));
 }
 
-static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string)
-{
+static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) {
     int i = 0;
     int j = 0;
     int in_ = 0;
@@ -465,13 +159,10 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
 
     std::vector<uint8_t> ret;
 
-    while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
-    {
+    while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
         char_array_4[i++] = encoded_string[in_]; in_++;
-        if (i == 4)
-        {
-            for (i = 0; i <4; i++)
-            {
+        if (i == 4) {
+            for (i = 0; i < 4; i++) {
                 char_array_4[i] = base64_chars.find(char_array_4[i]);
             }
 
@@ -479,23 +170,20 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
             char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
             char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
 
-            for (i = 0; (i < 3); i++)
-            {
+            for (i = 0; (i < 3); i++) {
                 ret.push_back(char_array_3[i]);
             }
+
             i = 0;
         }
     }
 
-    if (i)
-    {
-        for (j = i; j <4; j++)
-        {
+    if (i) {
+        for (j = i; j < 4; j++) {
             char_array_4[j] = 0;
         }
 
-        for (j = 0; j <4; j++)
-        {
+        for (j = 0; j < 4; j++) {
             char_array_4[j] = base64_chars.find(char_array_4[j]);
         }
 
@@ -503,8 +191,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
         char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
         char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
 
-        for (j = 0; (j < i - 1); j++)
-        {
+        for (j = 0; j < i - 1; j++) {
             ret.push_back(char_array_3[j]);
         }
     }
@@ -516,8 +203,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
 // random string / id
 //
 
-static std::string random_string()
-{
+static std::string random_string() {
     static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
 
     std::random_device rd;
@@ -532,10 +218,10 @@ static std::string random_string()
     return result;
 }
 
-static std::string gen_chatcmplid()
-{
+static std::string gen_chatcmplid() {
     std::stringstream chatcmplid;
     chatcmplid << "chatcmpl-" << random_string();
+
     return chatcmplid.str();
 }
 
@@ -543,91 +229,316 @@ static std::string gen_chatcmplid()
 // other common utils
 //
 
-static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
-{
+static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
     size_t i;
-    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++)
-    {
-    }
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
     return i;
 }
 
-static bool ends_with(const std::string &str, const std::string &suffix)
-{
-    return str.size() >= suffix.size() &&
-           0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
+static bool ends_with(const std::string & str, const std::string & suffix) {
+    return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }
 
-static size_t find_partial_stop_string(const std::string &stop,
-                                       const std::string &text)
-{
-    if (!text.empty() && !stop.empty())
-    {
+static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
+    if (!text.empty() && !stop.empty()) {
         const char text_last_char = text.back();
-        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--)
-        {
-            if (stop[char_index] == text_last_char)
-            {
+        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
+            if (stop[char_index] == text_last_char) {
                 const std::string current_partial = stop.substr(0, char_index + 1);
-                if (ends_with(text, current_partial))
-                {
+                if (ends_with(text, current_partial)) {
                     return text.size() - char_index - 1;
                 }
             }
         }
     }
+
     return std::string::npos;
 }
 
 // TODO: reuse llama_detokenize
 template <class Iter>
-static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
-{
+static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
     std::string ret;
-    for (; begin != end; ++begin)
-    {
+    for (; begin != end; ++begin) {
         ret += llama_token_to_piece(ctx, *begin);
     }
+
     return ret;
 }
 
 // format incomplete utf-8 multibyte character for output
-static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
-{
+static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
     std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
+
     // if the size is 1 and first bit is 1, meaning it's a partial character
     //   (size > 1 meaning it's already a known token)
-    if (out.size() == 1 && (out[0] & 0x80) == 0x80)
-    {
+    if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
         std::stringstream ss;
         ss << std::hex << (out[0] & 0xff);
         std::string res(ss.str());
         out = "byte: \\x" + res;
     }
+
     return out;
 }
 
+struct completion_token_output {
+    llama_token tok;
+    std::string text_to_send;
+
+    struct token_prob {
+        llama_token tok;
+        float prob;
+    };
+
+    std::vector<token_prob> probs;
+};
+
 // convert a vector of completion_token_output to json
-static json probs_vector_to_json(const llama_context *ctx, const std::vector<completion_token_output> &probs)
-{
+static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
     json out = json::array();
-    for (const auto &prob : probs)
-    {
+
+    for (const auto & prob : probs) {
         json probs_for_token = json::array();
-        for (const auto &p : prob.probs)
-        {
-            std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
-            probs_for_token.push_back(json
-            {
+
+        for (const auto & p : prob.probs) {
+            const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
+            probs_for_token.push_back(json {
                 {"tok_str", tok_str},
                 {"prob",    p.prob},
             });
         }
-        std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
-        out.push_back(json{
+
+        const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
+        out.push_back(json {
             {"content", tok_str},
             {"probs",   probs_for_token},
         });
     }
+
     return out;
 }
+
+//
+// OAI utils
+//
+
+static json oaicompat_completion_params_parse(
+    const struct llama_model * model,
+    const json & body, /* openai api json semantics */
+    const std::string & chat_template) {
+    json llama_params;
+
+    llama_params["__oaicompat"] = true;
+
+    // Map OpenAI parameters to llama.cpp parameters
+    //
+    // For parameters that are defined by the OpenAI documentation (e.g.
+    // temperature), we explicitly specify OpenAI's intended default; we
+    // need to do that because sometimes OpenAI disagrees with llama.cpp
+    //
+    // https://platform.openai.com/docs/api-reference/chat/create
+    llama_sampling_params default_sparams;
+    llama_params["model"]             = json_value(body,   "model",             std::string("unknown"));
+    llama_params["prompt"]            = format_chat(model, chat_template,       body["messages"]);
+    llama_params["cache_prompt"]      = json_value(body,   "cache_prompt",      false);
+    llama_params["temperature"]       = json_value(body,   "temperature",       0.0);
+    llama_params["top_k"]             = json_value(body,   "top_k",             default_sparams.top_k);
+    llama_params["top_p"]             = json_value(body,   "top_p",             1.0);
+    llama_params["n_predict"]         = json_value(body,   "max_tokens",        -1);
+    llama_params["logit_bias"]        = json_value(body,   "logit_bias",        json::object());
+    llama_params["frequency_penalty"] = json_value(body,   "frequency_penalty", 0.0);
+    llama_params["presence_penalty"]  = json_value(body,   "presence_penalty",  0.0);
+    llama_params["seed"]              = json_value(body,   "seed",              LLAMA_DEFAULT_SEED);
+    llama_params["stream"]            = json_value(body,   "stream",            false);
+    llama_params["mirostat"]          = json_value(body,   "mirostat",          default_sparams.mirostat);
+    llama_params["mirostat_tau"]      = json_value(body,   "mirostat_tau",      default_sparams.mirostat_tau);
+    llama_params["mirostat_eta"]      = json_value(body,   "mirostat_eta",      default_sparams.mirostat_eta);
+    llama_params["penalize_nl"]       = json_value(body,   "penalize_nl",       default_sparams.penalize_nl);
+    llama_params["typical_p"]         = json_value(body,   "typical_p",         default_sparams.typical_p);
+    llama_params["repeat_last_n"]     = json_value(body,   "repeat_last_n",     default_sparams.penalty_last_n);
+    llama_params["ignore_eos"]        = json_value(body,   "ignore_eos",        false);
+    llama_params["tfs_z"]             = json_value(body,   "tfs_z",             default_sparams.tfs_z);
+
+    if (body.count("grammar") != 0) {
+        llama_params["grammar"] = json_value(body, "grammar", json::object());
+    }
+
+    // Handle 'stop' field
+    if (body.contains("stop") && body["stop"].is_string()) {
+        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
+    } else {
+        llama_params["stop"] = json_value(body, "stop", json::array());
+    }
+
+    // Ensure there is ChatML-specific end sequence among stop words
+    llama_params["stop"].push_back("<|im_end|>");
+
+    return llama_params;
+}
+
+static json format_final_response_oaicompat(const json & request, json result, bool streaming = false) {
+    bool stopped_word        = result.count("stopped_word") != 0;
+    bool stopped_eos         = json_value(result, "stopped_eos", false);
+    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
+    int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
+    std::string content      = json_value(result, "content", std::string(""));
+
+    std::string finish_reason = "length";
+    if (stopped_word || stopped_eos) {
+        finish_reason = "stop";
+    }
+
+    json choices =
+        streaming ? json::array({json{{"finish_reason", finish_reason},
+                                        {"index", 0},
+                                        {"delta", json::object()}}})
+                  : json::array({json{{"finish_reason", finish_reason},
+                                        {"index", 0},
+                                        {"message", json{{"content", content},
+                                                         {"role", "assistant"}}}}});
+
+    std::time_t t = std::time(0);
+
+    json res = json {
+        {"choices", choices},
+        {"created", t},
+        {"model",
+            json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
+        {"usage", json {
+            {"completion_tokens", num_tokens_predicted},
+            {"prompt_tokens",     num_prompt_tokens},
+            {"total_tokens",      num_tokens_predicted + num_prompt_tokens}
+        }},
+        {"id", gen_chatcmplid()}
+    };
+
+    if (server_verbose) {
+        res["__verbose"] = result;
+    }
+
+    if (result.contains("completion_probabilities")) {
+        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
+    }
+
+    return res;
+}
+
+// return value is vector as there is one case where we might need to generate two responses
+static std::vector<json> format_partial_response_oaicompat(json result) {
+    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
+        return std::vector<json>({result});
+    }
+
+    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
+    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+
+    bool stopped_word   = json_value(result, "stopped_word",  false);
+    bool stopped_eos    = json_value(result, "stopped_eos",   false);
+    bool stopped_limit  = json_value(result, "stopped_limit", false);
+    std::string content = json_value(result, "content",       std::string(""));
+
+    std::string finish_reason;
+    if (stopped_word || stopped_eos) {
+        finish_reason = "stop";
+    }
+    if (stopped_limit) {
+        finish_reason = "length";
+    }
+
+    std::time_t t = std::time(0);
+
+    json choices;
+
+    if (!finish_reason.empty()) {
+        choices = json::array({json{{"finish_reason", finish_reason},
+                                    {"index", 0},
+                                    {"delta", json::object()}}});
+    } else {
+        if (first) {
+            if (content.empty()) {
+                choices = json::array({json{{"finish_reason", nullptr},
+                                            {"index", 0},
+                                            {"delta", json{{"role", "assistant"}}}}});
+            } else {
+                // We have to send this as two updates to conform to openai behavior
+                json initial_ret = json{{"choices", json::array({json{
+                                        {"finish_reason", nullptr},
+                                        {"index", 0},
+                                        {"delta", json{
+                                            {"role", "assistant"}
+                                        }}}})},
+                            {"created", t},
+                            {"id", gen_chatcmplid()},
+                            {"model", modelname},
+                            {"object", "chat.completion.chunk"}};
+
+                json second_ret = json{
+                            {"choices", json::array({json{{"finish_reason", nullptr},
+                                                            {"index", 0},
+                                                            {"delta", json{
+                                                            {"content", content}}}
+                                                            }})},
+                            {"created", t},
+                            {"id", gen_chatcmplid()},
+                            {"model", modelname},
+                            {"object", "chat.completion.chunk"}};
+
+                return std::vector<json>({initial_ret, second_ret});
+            }
+        } else {
+            // Some idiosyncrasy in task processing logic makes several trailing calls
+            // with empty content, we ignore these at the calee site.
+            if (content.empty()) {
+                return std::vector<json>({json::object()});
+            }
+
+            choices = json::array({json{
+                {"finish_reason", nullptr},
+                {"index", 0},
+                {"delta",
+                json{
+                    {"content", content},
+                }},
+            }});
+        }
+    }
+
+    json ret = json {
+        {"choices", choices},
+        {"created", t},
+        {"id",      gen_chatcmplid()},
+        {"model",   modelname},
+        {"object",  "chat.completion.chunk"}
+    };
+
+    return std::vector<json>({ret});
+}
+
+static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
+    json res = json {
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", "list"},
+        {"usage", json {
+            {"prompt_tokens", 0},
+            {"total_tokens", 0}
+        }},
+        {"data", embeddings}
+    };
+
+    return res;
+}
+
+static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
+    return json {
+        {"tokens", tokens}
+    };
+}
+
+static json format_detokenized_response(const std::string & content) {
+    return json {
+        {"content", content}
+    };
+}

+ 5 - 1
llama.cpp

@@ -13541,18 +13541,22 @@ LLAMA_API int32_t llama_chat_apply_template(
             curr_tmpl = std::string(model_template.data(), model_template.size());
         }
     }
+
     // format the chat to string
     std::vector<const llama_chat_message *> chat_vec;
     chat_vec.resize(n_msg);
     for (size_t i = 0; i < n_msg; i++) {
         chat_vec[i] = &chat[i];
     }
+
     std::string formatted_chat;
     int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }
-    strncpy(buf, formatted_chat.c_str(), length);
+    if (buf && length > 0) {
+        strncpy(buf, formatted_chat.c_str(), length);
+    }
     return res;
 }
 

Неке датотеке нису приказане због велике количине промена