|
@@ -1,6 +1,7 @@
|
|
|
#include "common.h"
|
|
#include "common.h"
|
|
|
#include "llama.h"
|
|
#include "llama.h"
|
|
|
#include "build-info.h"
|
|
#include "build-info.h"
|
|
|
|
|
+#include "grammar-parser.h"
|
|
|
|
|
|
|
|
#ifndef NDEBUG
|
|
#ifndef NDEBUG
|
|
|
// crash the server in debug mode, otherwise send an http 500 error
|
|
// crash the server in debug mode, otherwise send an http 500 error
|
|
@@ -195,6 +196,8 @@ struct llama_server_context
|
|
|
llama_context *ctx = nullptr;
|
|
llama_context *ctx = nullptr;
|
|
|
gpt_params params;
|
|
gpt_params params;
|
|
|
|
|
|
|
|
|
|
+ llama_grammar *grammar = nullptr;
|
|
|
|
|
+
|
|
|
bool truncated = false;
|
|
bool truncated = false;
|
|
|
bool stopped_eos = false;
|
|
bool stopped_eos = false;
|
|
|
bool stopped_word = false;
|
|
bool stopped_word = false;
|
|
@@ -226,6 +229,7 @@ struct llama_server_context
|
|
|
void rewind()
|
|
void rewind()
|
|
|
{
|
|
{
|
|
|
params.antiprompt.clear();
|
|
params.antiprompt.clear();
|
|
|
|
|
+ params.grammar.clear();
|
|
|
num_prompt_tokens = 0;
|
|
num_prompt_tokens = 0;
|
|
|
num_tokens_predicted = 0;
|
|
num_tokens_predicted = 0;
|
|
|
generated_text = "";
|
|
generated_text = "";
|
|
@@ -237,6 +241,7 @@ struct llama_server_context
|
|
|
stopped_limit = false;
|
|
stopped_limit = false;
|
|
|
stopping_word = "";
|
|
stopping_word = "";
|
|
|
multibyte_pending = 0;
|
|
multibyte_pending = 0;
|
|
|
|
|
+ grammar = nullptr;
|
|
|
|
|
|
|
|
n_remain = 0;
|
|
n_remain = 0;
|
|
|
n_past = 0;
|
|
n_past = 0;
|
|
@@ -257,6 +262,33 @@ struct llama_server_context
|
|
|
return true;
|
|
return true;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ bool loadGrammar()
|
|
|
|
|
+ {
|
|
|
|
|
+ if (!params.grammar.empty()) {
|
|
|
|
|
+ grammar_parser::parse_state parsed_grammar;
|
|
|
|
|
+
|
|
|
|
|
+ parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
|
|
|
|
+ // will be empty (default) if there are parse errors
|
|
|
|
|
+ if (parsed_grammar.rules.empty()) {
|
|
|
|
|
+ LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ grammar_parser::print_grammar(stderr, parsed_grammar);
|
|
|
|
|
+
|
|
|
|
|
+ {
|
|
|
|
|
+ auto it = params.logit_bias.find(llama_token_eos());
|
|
|
|
|
+ if (it != params.logit_bias.end() && it->second == -INFINITY) {
|
|
|
|
|
+ LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
|
|
|
|
+ grammar = llama_grammar_init(
|
|
|
|
|
+ grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
|
|
|
|
+ }
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
void loadPrompt()
|
|
void loadPrompt()
|
|
|
{
|
|
{
|
|
|
params.prompt.insert(0, 1, ' '); // always add a first space
|
|
params.prompt.insert(0, 1, ' '); // always add a first space
|
|
@@ -420,6 +452,10 @@ struct llama_server_context
|
|
|
logits[llama_token_nl()] = nl_logit;
|
|
logits[llama_token_nl()] = nl_logit;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if (grammar != nullptr) {
|
|
|
|
|
+ llama_sample_grammar(ctx, &candidates_p, grammar);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
if (temp <= 0)
|
|
if (temp <= 0)
|
|
|
{
|
|
{
|
|
|
// Greedy sampling
|
|
// Greedy sampling
|
|
@@ -457,10 +493,15 @@ struct llama_server_context
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if (grammar != nullptr) {
|
|
|
|
|
+ llama_grammar_accept_token(ctx, grammar, result.tok);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
|
|
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
|
|
|
{
|
|
{
|
|
|
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
|
|
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
last_n_tokens.erase(last_n_tokens.begin());
|
|
last_n_tokens.erase(last_n_tokens.begin());
|
|
|
last_n_tokens.push_back(result.tok);
|
|
last_n_tokens.push_back(result.tok);
|
|
|
num_tokens_predicted++;
|
|
num_tokens_predicted++;
|
|
@@ -947,6 +988,7 @@ static json format_generation_settings(llama_server_context &llama)
|
|
|
{"stream", llama.stream},
|
|
{"stream", llama.stream},
|
|
|
{"logit_bias", llama.params.logit_bias},
|
|
{"logit_bias", llama.params.logit_bias},
|
|
|
{"n_probs", llama.params.n_probs},
|
|
{"n_probs", llama.params.n_probs},
|
|
|
|
|
+ {"grammar", llama.params.grammar},
|
|
|
};
|
|
};
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -1048,6 +1090,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
|
|
|
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
|
|
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
|
|
|
llama.params.seed = body.value("seed", default_params.seed);
|
|
llama.params.seed = body.value("seed", default_params.seed);
|
|
|
llama.params.prompt = body.value("prompt", default_params.prompt);
|
|
llama.params.prompt = body.value("prompt", default_params.prompt);
|
|
|
|
|
+ llama.params.grammar = body.value("grammar", default_params.grammar);
|
|
|
llama.params.n_probs = body.value("n_probs", default_params.n_probs);
|
|
llama.params.n_probs = body.value("n_probs", default_params.n_probs);
|
|
|
|
|
|
|
|
llama.params.logit_bias.clear();
|
|
llama.params.logit_bias.clear();
|
|
@@ -1179,6 +1222,12 @@ int main(int argc, char **argv)
|
|
|
|
|
|
|
|
parse_options_completion(json::parse(req.body), llama);
|
|
parse_options_completion(json::parse(req.body), llama);
|
|
|
|
|
|
|
|
|
|
+ if (!llama.loadGrammar())
|
|
|
|
|
+ {
|
|
|
|
|
+ res.status = 400;
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
llama.loadPrompt();
|
|
llama.loadPrompt();
|
|
|
llama.beginCompletion();
|
|
llama.beginCompletion();
|
|
|
|
|
|
|
@@ -1334,8 +1383,12 @@ int main(int argc, char **argv)
|
|
|
|
|
|
|
|
svr.set_error_handler([](const Request &, Response &res)
|
|
svr.set_error_handler([](const Request &, Response &res)
|
|
|
{
|
|
{
|
|
|
- res.set_content("File Not Found", "text/plain");
|
|
|
|
|
- res.status = 404; });
|
|
|
|
|
|
|
+ if (res.status == 400) {
|
|
|
|
|
+ res.set_content("Invalid request", "text/plain");
|
|
|
|
|
+ } else {
|
|
|
|
|
+ res.set_content("File Not Found", "text/plain");
|
|
|
|
|
+ res.status = 404;
|
|
|
|
|
+ } });
|
|
|
|
|
|
|
|
// set timeouts and change hostname and port
|
|
// set timeouts and change hostname and port
|
|
|
svr.set_read_timeout(sparams.read_timeout);
|
|
svr.set_read_timeout(sparams.read_timeout);
|
|
@@ -1363,6 +1416,9 @@ int main(int argc, char **argv)
|
|
|
return 1;
|
|
return 1;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if (llama.grammar != nullptr) {
|
|
|
|
|
+ llama_grammar_free(llama.grammar);
|
|
|
|
|
+ }
|
|
|
llama_backend_free();
|
|
llama_backend_free();
|
|
|
|
|
|
|
|
return 0;
|
|
return 0;
|