@@ -190,6 +190,7 @@ struct llama_server_context
     size_t n_past = 0;
     size_t n_remain = 0;
 
+    json prompt;
     std::vector<llama_token> embd;
     std::vector<llama_token> last_n_tokens;
 
@@ -267,6 +268,53 @@ struct llama_server_context
         return true;
     }
 
+    std::vector<llama_token> tokenize(json json_prompt, bool add_bos)
+    {
+        // If `add_bos` is true, we only add BOS when json_prompt is a string
+        // or when the first element of the json_prompt array is a string.
+        std::vector<llama_token> prompt_tokens;
+
+        if (json_prompt.is_array())
+        {
+            bool first = true;
+            for (const auto& p : json_prompt)
+            {
+                if (p.is_string())
+                {
+                    auto s = p.template get<std::string>();
+                    std::vector<llama_token> p_tokens; // tokens produced from this piece
+                    if (first)
+                    {
+                        s.insert(0, 1, ' '); // add a space if it's the first piece
+                        p_tokens = ::llama_tokenize(ctx, s, add_bos);
+                        first = false;
+                    }
+                    else
+                    {
+                        p_tokens = ::llama_tokenize(ctx, s, false);
+                    }
+                    prompt_tokens.insert(prompt_tokens.end(), p_tokens.begin(), p_tokens.end());
+                }
+                else
+                {
+                    if (first)
+                    {
+                        first = false;
+                    }
+                    prompt_tokens.push_back(p.template get<llama_token>()); // raw token id, used verbatim
+                }
+            }
+        }
+        else
+        {
+            auto s = json_prompt.template get<std::string>();
+            s.insert(0, 1, ' '); // always add a first space
+            prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
+        }
+
+        return prompt_tokens;
+    }
+
     bool loadGrammar()
     {
         if (!params.grammar.empty()) {
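
The helper above accepts either a plain JSON string or an array that freely mixes strings and raw token ids. A minimal usage sketch, assuming an initialized llama_server_context named `llama` and the `json` alias used throughout this file (the id 15043 is an arbitrary placeholder, not taken from this patch):

    json mixed = json::array({"Hello", 15043, " world"});
    // Strings are run through ::llama_tokenize(); bare numbers are pushed
    // back verbatim as llama_token ids. BOS is prepended only because the
    // first element is a string.
    std::vector<llama_token> toks = llama.tokenize(mixed, /* add_bos= */ true);
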
@@ -294,8 +342,8 @@ struct llama_server_context
 
     void loadPrompt()
    {
-        params.prompt.insert(0, 1, ' '); // always add a first space
-        std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
+        auto prompt_tokens = tokenize(prompt, true); // always add BOS
+
         num_prompt_tokens = prompt_tokens.size();
 
         if (params.n_keep < 0)
@@ -1016,7 +1064,7 @@ static json format_final_response(llama_server_context &llama, const std::string
         {"tokens_predicted", llama.num_tokens_predicted},
         {"tokens_evaluated", llama.num_prompt_tokens},
         {"generation_settings", format_generation_settings(llama)},
-        {"prompt", llama.params.prompt},
+        {"prompt", llama.prompt},
         {"truncated", llama.truncated},
         {"stopped_eos", llama.stopped_eos},
         {"stopped_word", llama.stopped_word},
@@ -1085,10 +1133,18 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
     llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
     llama.params.seed = json_value(body, "seed", default_params.seed);
-    llama.params.prompt = json_value(body, "prompt", default_params.prompt);
     llama.params.grammar = json_value(body, "grammar", default_params.grammar);
     llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
 
+    if (body.count("prompt") != 0)
+    {
+        llama.prompt = body["prompt"];
+    }
+    else
+    {
+        llama.prompt = "";
+    }
+
     llama.params.logit_bias.clear();
     if (json_value(body, "ignore_eos", false))
     {
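
Since `parse_options_completion` now stores the raw JSON value instead of coercing it to a string, a completion request can carry its prompt in either form. An illustrative sketch built with nlohmann::json as elsewhere in this file (the token ids are placeholders):

    json body_string = {{"prompt", "Once upon a time"}};
    json body_mixed  = {{"prompt", json::array({"Once upon", 263, 931})}};
    // Either value ends up in llama.prompt; loadPrompt() later hands it
    // to tokenize(), which resolves strings and passes ids through.
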
@@ -1345,8 +1401,11 @@ int main(int argc, char **argv)
                 auto lock = llama.lock();
 
                 const json body = json::parse(req.body);
-                const std::string content = json_value<std::string>(body, "content", "");
-                const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
+                std::vector<llama_token> tokens;
+                if (body.count("content") != 0)
+                {
+                    tokens = llama.tokenize(body["content"], false);
+                }
                 const json data = format_tokenizer_response(tokens);
                 return res.set_content(data.dump(), "application/json"); });
 
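
The tokenize endpoint follows the same pattern: "content" may be a string or a mixed array of strings and token ids, and when the field is absent the token list is simply left empty. A client-side sketch (the id is a placeholder):

    json req_body = {{"content", json::array({"Hello", 263})}};
    // "Hello" is tokenized; 263 is appended as-is.
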
@@ -1358,7 +1417,14 @@ int main(int argc, char **argv)
 
                 llama.rewind();
                 llama_reset_timings(llama.ctx);
-                llama.params.prompt = json_value<std::string>(body, "content", "");
+                if (body.count("content") != 0)
+                {
+                    llama.prompt = body["content"];
+                }
+                else
+                {
+                    llama.prompt = "";
+                }
                 llama.params.n_predict = 0;
                 llama.loadPrompt();
                 llama.beginCompletion();
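
The embedding endpoint mirrors it: "content" is stored as raw JSON when present, with the empty string as the fallback prompt. A possible request body, again with a placeholder id, for a client that has already tokenized part of its input:

    json req_body = {{"content", json::array({1, " Feature extraction test"})}};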