Răsfoiți Sursa

`main`: add --json-schema / -j flag (#6659)

* main: add --json-schema / -j

* json: move json-schema-to-grammar to common lib

* json: fix zig build
Olivier Chafik 1 an în urmă
părinte
comite
7593639ce3
7 a modificat fișierele cu 31 adăugiri și 16 ștergeri
  1. 3 3
      Makefile
  2. 7 7
      build.zig
  3. 1 3
      common/CMakeLists.txt
  4. 15 0
      common/common.cpp
  5. 3 1
      examples/main/README.md
  6. 1 1
      examples/server/CMakeLists.txt
  7. 1 1
      tests/CMakeLists.txt

+ 3 - 3
Makefile

@@ -688,7 +688,7 @@ llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
 COMMON_H_DEPS = common/common.h common/sampling.h common/log.h
-COMMON_DEPS   = common.o sampling.o grammar-parser.o build-info.o
+COMMON_DEPS   = common.o sampling.o grammar-parser.o build-info.o json-schema-to-grammar.o
 
 common.o: common/common.cpp $(COMMON_H_DEPS)
 	$(CXX) $(CXXFLAGS) -c $< -o $@
@@ -756,7 +756,7 @@ batched: examples/batched/batched.cpp                         ggml.o llama.o $(C
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
-batched-bench: examples/batched-bench/batched-bench.cpp       build-info.o ggml.o llama.o common.o $(OBJS)
+batched-bench: examples/batched-bench/batched-bench.cpp       build-info.o ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
@@ -788,7 +788,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp json-schema-to-grammar.o common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
 

+ 7 - 7
build.zig

@@ -128,14 +128,14 @@ pub fn build(b: *std.build.Builder) !void {
     const clip = make.obj("clip", "examples/llava/clip.cpp");
     const llava = make.obj("llava", "examples/llava/llava.cpp");
 
-    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, sampling, console, grammar_parser });
-    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo });
-    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo });
-    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo });
-    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, train });
-    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, train });
+    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
+    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
+    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
 
-    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, sampling, grammar_parser, json_schema_to_grammar, clip, llava });
+    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava });
     if (server.target.isWindows()) {
         server.linkSystemLibrary("ws2_32");
     }

+ 1 - 3
common/CMakeLists.txt

@@ -47,9 +47,6 @@ if (BUILD_SHARED_LIBS)
     set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
 endif()
 
-set(TARGET json-schema-to-grammar)
-add_library(${TARGET} OBJECT json-schema-to-grammar.cpp json-schema-to-grammar.h)
-
 set(TARGET common)
 
 add_library(${TARGET} STATIC
@@ -63,6 +60,7 @@ add_library(${TARGET} STATIC
     grammar-parser.h
     grammar-parser.cpp
     json.hpp
+    json-schema-to-grammar.cpp
     train.h
     train.cpp
     ngram-cache.h

+ 15 - 0
common/common.cpp

@@ -1,4 +1,6 @@
 #include "common.h"
+#include "json.hpp"
+#include "json-schema-to-grammar.h"
 #include "llama.h"
 
 #include <algorithm>
@@ -68,6 +70,8 @@
 #define LLAMA_CURL_MAX_HEADER_LENGTH 256
 #endif // LLAMA_USE_CURL
 
+using json = nlohmann::ordered_json;
+
 int32_t get_num_physical_cores() {
 #ifdef __linux__
     // enumerate the set of thread siblings, num entries is num cores
@@ -1148,6 +1152,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         );
         return true;
     }
+    if (arg == "-j" || arg == "--json-schema") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
+        return true;
+    }
     if (arg == "--override-kv") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1353,6 +1365,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     printf("  --grammar GRAMMAR     BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
     printf("  --grammar-file FNAME  file to read grammar from\n");
+    printf("  -j SCHEMA, --json-schema SCHEMA\n");
+    printf("                        JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
+    printf("                        For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
     printf("  --cfg-negative-prompt PROMPT\n");
     printf("                        negative prompt to use for guidance. (default: empty)\n");
     printf("  --cfg-negative-prompt-file FNAME\n");

+ 3 - 1
examples/main/README.md

@@ -304,10 +304,12 @@ These options help improve the performance and memory usage of the LLaMA models.
 
 -   `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs. **Note**: Restoring a cached prompt does not imply restoring the exact state of the session at the point it was saved. So even when specifying a specific seed, you are not guaranteed to get the same sequence of tokens as the original generation.
 
-### Grammars
+### Grammars & JSON schemas
 
 -   `--grammar GRAMMAR`, `--grammar-file FILE`: Specify a grammar (defined inline or in a file) to constrain model output to a specific format. For example, you could force the model to output JSON or to speak only in emojis. See the [GBNF guide](../../grammars/README.md) for details on the syntax.
 
+-   `--json-schema SCHEMA`: Specify a [JSON schema](https://json-schema.org/) to constrain model output to (e.g. `{}` for any JSON object, or `{"items": {"type": "string", "minLength": 10, "maxLength": 100}, "minItems": 10}` for a JSON array of strings with size constraints). If a schema uses external `$ref`s, you should use `--grammar "$( python examples/json_schema_to_grammar.py myschema.json )"` instead.
+
 ### Quantization
 
 For information about 4-bit quantization, which can significantly improve performance and reduce memory usage, please refer to llama.cpp's primary [README](../../README.md#prepare-and-quantize).

+ 1 - 1
examples/server/CMakeLists.txt

@@ -11,7 +11,7 @@ install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
 )
-target_link_libraries(${TARGET} PRIVATE common json-schema-to-grammar ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
 if (LLAMA_SERVER_SSL)
     find_package(OpenSSL REQUIRED)
     target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)

+ 1 - 1
tests/CMakeLists.txt

@@ -25,7 +25,7 @@ function(llama_test source)
 
     add_executable(${TEST_TARGET} ${source} get-model.cpp)
     install(TARGETS ${TEST_TARGET} RUNTIME)
-    target_link_libraries(${TEST_TARGET} PRIVATE common json-schema-to-grammar)
+    target_link_libraries(${TEST_TARGET} PRIVATE common)
     add_test(
         NAME ${TEST_TARGET}
         WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}