@@ -626,18 +626,36 @@ struct llama_server_context
             const int n_vocab = llama_n_vocab(model);
             for (const auto &el : *logit_bias)
             {
-                if (el.is_array() && el.size() == 2 && el[0].is_number_integer())
+                if (el.is_array() && el.size() == 2)
                 {
-                    llama_token tok = el[0].get<llama_token>();
-                    if (tok >= 0 && tok < n_vocab)
+                    float bias;
+                    if (el[1].is_number())
                     {
-                        if (el[1].is_number())
+                        bias = el[1].get<float>();
+                    }
+                    else if (el[1].is_boolean() && !el[1].get<bool>())
+                    {
+                        bias = -INFINITY;
+                    }
+                    else
+                    {
+                        continue;
+                    }
+
+                    if (el[0].is_number_integer())
+                    {
+                        llama_token tok = el[0].get<llama_token>();
+                        if (tok >= 0 && tok < n_vocab)
                         {
-                            slot->sparams.logit_bias[tok] = el[1].get<float>();
+                            slot->sparams.logit_bias[tok] = bias;
                         }
-                        else if (el[1].is_boolean() && !el[1].get<bool>())
+                    }
+                    else if (el[0].is_string())
+                    {
+                        auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
+                        for (auto tok : toks)
                         {
-                            slot->sparams.logit_bias[tok] = -INFINITY;
+                            slot->sparams.logit_bias[tok] = bias;
                         }
                     }
                 }
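
With this change, each entry in the logit_bias array may pair either a token id or a string with a bias value, and a boolean false mutes the matched token(s) by applying a bias of -INFINITY. As an illustrative sketch (the field names follow the server's existing JSON API, but the token id and strings below are made up), a request body could now look like:

{
    "prompt": "Once upon a time",
    "logit_bias": [
        [15043, 1.5],
        ["Hello", -0.5],
        ["World", false]
    ]
}

Note that a string entry is run through llama_tokenize and the bias is applied to every resulting token, so a string that tokenizes into several pieces biases each of them.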