Просмотр исходного кода

common : remove json.hpp from common.cpp (#12697)

* common : remove json.hpp from common.cpp

* fix comment
Xuan-Son Nguyen 9 месяцев назад
Родитель
Сommit
42eb248f46
4 измененных файлов с 34 добавлено и 38 удалено
  1. 0 28
      common/common.cpp
  2. 0 4
      common/common.h
  3. 6 5
      examples/server/server.cpp
  4. 28 1
      examples/server/utils.hpp

+ 0 - 28
common/common.cpp

@@ -7,9 +7,6 @@
 
 #include "common.h"
 #include "log.h"
-// Change JSON_ASSERT from assert() to GGML_ASSERT:
-#define JSON_ASSERT GGML_ASSERT
-#include "json.hpp"
 #include "llama.h"
 
 #include <algorithm>
@@ -56,8 +53,6 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-using json = nlohmann::ordered_json;
-
 //
 // CPU utils
 //
@@ -1545,26 +1540,3 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
 
     return result;
 }
-
-template <>
-json common_grammar_trigger::to_json() const {
-    json out {
-        {"type", (int) type},
-        {"value", value},
-    };
-    if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
-        out["token"] = (int) token;
-    }
-    return out;
-}
-
-template <>
-common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
-    common_grammar_trigger out;
-    out.type = (common_grammar_trigger_type) in.at("type").get<int>();
-    out.value = in.at("value").get<std::string>();
-    if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
-        out.token = (llama_token) in.at("token").get<int>();
-    }
-    return out;
-}

+ 0 - 4
common/common.h

@@ -121,10 +121,6 @@ struct common_grammar_trigger {
     common_grammar_trigger_type type;
     std::string value;
     llama_token token = LLAMA_TOKEN_NULL;
-
-    // T can only be nlohmann::ordered_json
-    template <class T> T to_json() const;
-    template <class T> static common_grammar_trigger from_json(const T & in);
 };
 
 // sampling parameters

+ 6 - 5
examples/server/server.cpp

@@ -133,7 +133,8 @@ struct slot_params {
 
         auto grammar_triggers = json::array();
         for (const auto & trigger : sampling.grammar_triggers) {
-            grammar_triggers.push_back(trigger.to_json<json>());
+            server_grammar_trigger ct(std::move(trigger));
+            grammar_triggers.push_back(ct.to_json());
         }
 
         return json {
@@ -372,9 +373,9 @@ struct server_task {
             const auto grammar_triggers = data.find("grammar_triggers");
             if (grammar_triggers != data.end()) {
                 for (const auto & t : *grammar_triggers) {
-                    auto ct = common_grammar_trigger::from_json(t);
-                    if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
-                        const auto & word = ct.value;
+                    server_grammar_trigger ct(t);
+                    if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
+                        const auto & word = ct.value.value;
                         auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
                         if (ids.size() == 1) {
                             auto token = ids[0];
@@ -392,7 +393,7 @@ struct server_task {
                             params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
                         }
                     } else {
-                        params.sampling.grammar_triggers.push_back(ct);
+                        params.sampling.grammar_triggers.push_back(std::move(ct.value));
                     }
                 }
             }

+ 28 - 1
examples/server/utils.hpp

@@ -58,6 +58,32 @@ static T json_value(const json & body, const std::string & key, const T & defaul
 
 const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
 
+// thin wrapper around common_grammar_trigger with (de)serialization functions
+struct server_grammar_trigger {
+    common_grammar_trigger value;
+
+    server_grammar_trigger() = default;
+    server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
+    server_grammar_trigger(const json & in) {
+        value.type = (common_grammar_trigger_type) in.at("type").get<int>();
+        value.value = in.at("value").get<std::string>();
+        if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
+            value.token = (llama_token) in.at("token").get<int>();
+        }
+    }
+
+    json to_json() const {
+        json out {
+            {"type", (int) value.type},
+            {"value", value.value},
+        };
+        if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
+            out["token"] = (int) value.token;
+        }
+        return out;
+    }
+};
+
 //
 // tokenizer and input processing utils
 //
@@ -627,7 +653,8 @@ static json oaicompat_completion_params_parse(
     llama_params["grammar_lazy"]     = chat_params.grammar_lazy;
     auto grammar_triggers = json::array();
     for (const auto & trigger : chat_params.grammar_triggers) {
-        grammar_triggers.push_back(trigger.to_json<json>());
+        server_grammar_trigger ct(trigger);
+        grammar_triggers.push_back(ct.to_json());
     }
     llama_params["grammar_triggers"] = grammar_triggers;
     llama_params["preserved_tokens"] = chat_params.preserved_tokens;