|
|
@@ -761,6 +761,42 @@ struct llama_server_context
|
|
|
slot->prompt = "";
|
|
|
}
|
|
|
|
|
|
+ slot->sparams.penalty_prompt_tokens.clear();
|
|
|
+ slot->sparams.use_penalty_prompt_tokens = false;
|
|
|
+ const auto &penalty_prompt = data.find("penalty_prompt");
|
|
|
+ if (penalty_prompt != data.end())
|
|
|
+ {
|
|
|
+ if (penalty_prompt->is_string())
|
|
|
+ {
|
|
|
+ const auto penalty_prompt_string = penalty_prompt->get<std::string>();
|
|
|
+ auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false);
|
|
|
+ slot->sparams.penalty_prompt_tokens.swap(penalty_tokens);
|
|
|
+ if (slot->params.n_predict > 0)
|
|
|
+ {
|
|
|
+ slot->sparams.penalty_prompt_tokens.reserve(slot->sparams.penalty_prompt_tokens.size() + slot->params.n_predict);
|
|
|
+ }
|
|
|
+ slot->sparams.use_penalty_prompt_tokens = true;
|
|
|
+ }
|
|
|
+ else if (penalty_prompt->is_array())
|
|
|
+ {
|
|
|
+ const auto n_tokens = penalty_prompt->size();
|
|
|
+ slot->sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot->params.n_predict));
|
|
|
+ const int n_vocab = llama_n_vocab(model);
|
|
|
+ for (const auto &penalty_token : *penalty_prompt)
|
|
|
+ {
|
|
|
+ if (penalty_token.is_number_integer())
|
|
|
+ {
|
|
|
+ const auto tok = penalty_token.get<llama_token>();
|
|
|
+ if (tok >= 0 && tok < n_vocab)
|
|
|
+ {
|
|
|
+ slot->sparams.penalty_prompt_tokens.push_back(tok);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ slot->sparams.use_penalty_prompt_tokens = true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
slot->sparams.logit_bias.clear();
|
|
|
|
|
|
if (json_value(data, "ignore_eos", false))
|
|
|
@@ -992,6 +1028,12 @@ struct llama_server_context
|
|
|
slot.generated_text += token_str;
|
|
|
slot.has_next_token = true;
|
|
|
|
|
|
+ if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1)
|
|
|
+ {
|
|
|
+ // we can change penalty_prompt_tokens because it is always created from scratch each request
|
|
|
+ slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
|
|
+ }
|
|
|
+
|
|
|
// check if there is incomplete UTF-8 character at the end
|
|
|
bool incomplete = false;
|
|
|
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
|
|
|
@@ -1183,6 +1225,8 @@ struct llama_server_context
|
|
|
{"repeat_penalty", slot.sparams.penalty_repeat},
|
|
|
{"presence_penalty", slot.sparams.penalty_present},
|
|
|
{"frequency_penalty", slot.sparams.penalty_freq},
|
|
|
+ {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
|
|
|
+ {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
|
|
|
{"mirostat", slot.sparams.mirostat},
|
|
|
{"mirostat_tau", slot.sparams.mirostat_tau},
|
|
|
{"mirostat_eta", slot.sparams.mirostat_eta},
|