@@ -93,6 +93,7 @@ struct slot_params {

     std::vector<std::string> antiprompt;
     bool timings_per_token = false;
+    bool post_sampling_probs = false;
     bool ignore_eos = false;

     struct common_params_sampling sampling;
@@ -151,6 +152,7 @@ struct slot_params {
             {"speculative.n_min", speculative.n_min},
             {"speculative.p_min", speculative.p_min},
             {"timings_per_token", timings_per_token},
+            {"post_sampling_probs", post_sampling_probs},
         };
     }
 };
@@ -231,6 +233,7 @@ struct server_task {
         params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
         params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
         params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+        params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);

         params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
         params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
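Note: the new flag is opt-in per request and pairs with the existing n_probs option. A minimal request body exercising it might look like this (the prompt and values are illustrative):

    {
        "prompt": "Hello",
        "n_predict": 4,
        "n_probs": 2,
        "post_sampling_probs": true
    }

When the flag is absent or false, the server reports pre-sampling logprobs instead, as implemented by the serializers below.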
@@ -436,36 +439,67 @@ inline std::string stop_type_to_str(stop_type type) {

 struct completion_token_output {
     llama_token tok;
+    float prob;
     std::string text_to_send;
-    struct token_prob {
+    struct prob_info {
         llama_token tok;
-        std::string tok_str;
+        std::string txt;
         float prob;
     };
-    std::vector<token_prob> probs;
+    std::vector<prob_info> probs;

-    json to_json() const {
+    json to_json(bool post_sampling_probs) const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
+            std::string txt(p.txt);
+            txt.resize(validate_utf8(txt));
             probs_for_token.push_back(json {
-                {"tok_str", p.tok_str},
-                {"prob", p.prob},
+                {"id", p.tok},
+                {"token", txt},
+                {"bytes", str_to_bytes(p.txt)},
+                {
+                    post_sampling_probs ? "prob" : "logprob",
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
+                },
             });
         }
         return probs_for_token;
     }

-    static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
+    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
         json out = json::array();
-        for (const auto & prob : probs) {
-            const std::string tok_str = prob.text_to_send;
+        for (const auto & p : probs) {
+            std::string txt(p.text_to_send);
+            txt.resize(validate_utf8(txt));
             out.push_back(json {
-                {"content", tok_str},
-                {"probs", prob.to_json()},
+                {"id", p.tok},
+                {"token", txt},
+                {"bytes", str_to_bytes(p.text_to_send)},
+                {
+                    post_sampling_probs ? "prob" : "logprob",
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
+                },
+                {
+                    post_sampling_probs ? "top_probs" : "top_logprobs",
+                    p.to_json(post_sampling_probs)
+                },
             });
         }
         return out;
     }
+
+    static float logarithm(float x) {
+        // nlohmann::json converts -inf to null, so we need to prevent that
+        return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
+    }
+
+    static std::vector<unsigned char> str_to_bytes(const std::string & str) {
+        std::vector<unsigned char> bytes;
+        for (unsigned char c : str) {
+            bytes.push_back(c);
+        }
+        return bytes;
+    }
 };

 struct server_task_result_cmpl_final : server_task_result {
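With these serializers, each reported token becomes an object of the following shape (ids, text, and probabilities are illustrative). When post_sampling_probs is set:

    {
        "id": 15043,
        "token": "Hello",
        "bytes": [72, 101, 108, 108, 111],
        "prob": 0.93,
        "top_probs": [
            {"id": 15043, "token": "Hello", "bytes": [72, 101, 108, 108, 111], "prob": 0.93},
            {"id": 13347, "token": "Hi", "bytes": [72, 105], "prob": 0.04}
        ]
    }

In the default mode the keys are "logprob" and "top_logprobs" instead, and logarithm() clamps log(0) to the lowest finite float because nlohmann::json would otherwise serialize -inf as null.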
@@ -486,6 +520,7 @@ struct server_task_result_cmpl_final : server_task_result {
     std::string stopping_word;
     stop_type stop = STOP_TYPE_NONE;

+    bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;

     slot_params generation_params;
@@ -530,8 +565,8 @@ struct server_task_result_cmpl_final : server_task_result {
             {"tokens_cached", n_tokens_cached},
             {"timings", timings.to_json()},
         };
-        if (!probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        if (!stream && !probs_output.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
         }
         return res;
     }
@@ -542,19 +577,25 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }

-        json choices = json::array({json{
+        json choice = json{
             {"finish_reason", finish_reason},
             {"index", 0},
             {"message", json {
                 {"content", content},
                 {"role", "assistant"}
             }
-        }}});
+        }};
+
+        if (!stream && probs_output.size() > 0) {
+            choice["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
+            };
+        }

         std::time_t t = std::time(0);

         json res = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"model", oaicompat_model},
             {"object", "chat.completion"},
@@ -584,12 +625,14 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }

-        json choices = json::array({json{{"finish_reason", finish_reason},
-                                         {"index", 0},
-                                         {"delta", json::object()}}});
+        json choice = json{
+            {"finish_reason", finish_reason},
+            {"index", 0},
+            {"delta", json::object()}
+        };

         json ret = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"id", oaicompat_cmpl_id},
             {"model", oaicompat_model},
@@ -618,7 +661,8 @@ struct server_task_result_cmpl_partial : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;

-    std::vector<completion_token_output> probs_output;
+    bool post_sampling_probs;
+    completion_token_output prob_output;
     result_timings timings;

     // OAI-compat fields
@@ -655,8 +699,8 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (timings.prompt_n > 0) {
             res.push_back({"timings", timings.to_json()});
         }
-        if (!probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        if (!prob_output.probs.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
         }
         return res;
     }
@@ -708,6 +752,14 @@ struct server_task_result_cmpl_partial : server_task_result {
             }});
         }

+        GGML_ASSERT(choices.size() >= 1);
+
+        if (prob_output.probs.size() > 0) {
+            choices[0]["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+            };
+        }
+
         json ret = json {
             {"choices", choices},
             {"created", t},
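In the OAI-compatible streaming path this attaches the per-token data to the first choice of each chunk under "logprobs", mirroring OpenAI's field layout (plus llama.cpp's extra "id" field). A chunk's choice then looks roughly like this (values illustrative):

    {
        "index": 0,
        "delta": {"content": "Hello"},
        "finish_reason": null,
        "logprobs": {
            "content": [
                {"id": 15043, "token": "Hello", "bytes": [72, 101, 108, 108, 111], "logprob": -0.07, "top_logprobs": [ ... ]}
            ]
        }
    }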
@@ -1001,7 +1053,6 @@ struct server_slot {

     // stats
     size_t n_sent_text = 0; // number of sent text character
-    size_t n_sent_token_probs = 0;

     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -1023,7 +1074,6 @@ struct server_slot {
         stopping_word = "";
         n_past = 0;
         n_sent_text = 0;
-        n_sent_token_probs = 0;
         task_type = SERVER_TASK_TYPE_COMPLETION;

         generated_tokens.clear();
@@ -1764,7 +1814,7 @@ struct server_context {

     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
+        const std::string token_str = result.text_to_send;
         slot.sampled = result.tok;

         slot.generated_text += token_str;
@@ -1774,26 +1824,7 @@ struct server_context {
         slot.has_next_token = true;

         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

         // search stop word and delete it
         if (!incomplete) {
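validate_utf8 (added elsewhere in this patch) returns the length of the longest prefix of its argument that is valid UTF-8, so a return value shorter than the string flags an incomplete trailing character. A minimal self-contained sketch of that contract, checking only lead-byte lengths (the helper below is hypothetical, not the patch's implementation):

    #include <cstddef>
    #include <string>

    static size_t validate_utf8_sketch(const std::string & s) {
        size_t i = 0;
        while (i < s.size()) {
            const unsigned char c = s[i];
            const size_t len = (c & 0x80) == 0x00 ? 1   // 0xxxxxxx: ASCII
                             : (c & 0xE0) == 0xC0 ? 2   // 110xxxxx: 2-byte sequence
                             : (c & 0xF0) == 0xE0 ? 3   // 1110xxxx: 3-byte sequence
                             : (c & 0xF8) == 0xF0 ? 4   // 11110xxx: 4-byte sequence
                             : 1;                       // invalid lead byte: skip it
            if (i + len > s.size()) {
                return i; // sequence runs past the end of the buffer -> incomplete tail
            }
            i += len;
        }
        return i;
    }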
@@ -1923,6 +1954,55 @@ struct server_context {
         return slot.has_next_token; // continue
     }

+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
+        size_t n_probs = slot.params.sampling.n_probs;
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+        if (post_sampling) {
+            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+            const size_t max_probs = cur_p->size;
+
+            // set probability for sampled token
+            for (size_t i = 0; i < max_probs; i++) {
+                if (cur_p->data[i].id == result.tok) {
+                    result.prob = cur_p->data[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    common_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+            }
+        } else {
+            // TODO: optimize this with min-p optimization
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+
+            // set probability for sampled token
+            for (size_t i = 0; i < n_vocab; i++) {
+                // set probability for sampled token
+                if (cur[i].id == result.tok) {
+                    result.prob = cur[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(n_probs);
+            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+                result.probs.push_back({
+                    cur[i].id,
+                    common_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+            }
+        }
+    }
+
     void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         send_error(task.id, error, type);
     }
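The two branches differ in where the probabilities come from: with post_sampling they are read off the sampler's candidate list (after temperature, top-k, and the rest of the chain have been applied), while the default path recomputes probabilities from the raw logits of batch position idx over the full vocabulary. Presumably get_token_probabilities (added elsewhere in this patch) amounts to a sorted softmax over those logits; a self-contained sketch of that step, under that assumption:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    struct tok_prob { int32_t id; float logit; float p; };

    // Softmax over raw logits, sorted by descending probability -- a sketch of
    // what the helper is assumed to do, not the patch's implementation.
    static std::vector<tok_prob> softmax_sorted(std::vector<tok_prob> cur) {
        if (cur.empty()) {
            return cur;
        }
        std::sort(cur.begin(), cur.end(),
                  [](const tok_prob & a, const tok_prob & b) { return a.logit > b.logit; });
        const float max_logit = cur.front().logit;
        double sum = 0.0;
        for (auto & t : cur) {
            t.p = std::exp(t.logit - max_logit); // subtract max for numerical stability
            sum += t.p;
        }
        for (auto & t : cur) {
            t.p = (float) (t.p / sum);
        }
        return cur;
    }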
@@ -1950,8 +2030,9 @@ struct server_context {
         res->content = tkn.text_to_send;
         res->tokens = { tkn.tok };

-        res->n_decoded       = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens;
+        res->n_decoded           = slot.n_decoded;
+        res->n_prompt_tokens     = slot.n_prompt_tokens;
+        res->post_sampling_probs = slot.params.post_sampling_probs;

         res->verbose = slot.params.verbose;
         res->oaicompat = slot.params.oaicompat;
@@ -1961,17 +2042,7 @@ struct server_context {

         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
-            const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
-
-            const size_t probs_pos      = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                res->probs_output = std::vector<completion_token_output>(
-                        slot.generated_token_probs.begin() + probs_pos,
-                        slot.generated_token_probs.begin() + probs_stop_pos);
-            }
+            res->prob_output = tkn; // copy the token probs
         }

         // populate timings if this is final response or timings_per_token is enabled
@@ -1993,13 +2064,14 @@ struct server_context {
         res->timings = slot.get_timings();
         res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);

-        res->truncated       = slot.truncated;
-        res->n_decoded       = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens;
-        res->n_tokens_cached = slot.n_past;
-        res->has_new_line    = slot.has_new_line;
-        res->stopping_word   = slot.stopping_word;
-        res->stop            = slot.stop;
+        res->truncated           = slot.truncated;
+        res->n_decoded           = slot.n_decoded;
+        res->n_prompt_tokens     = slot.n_prompt_tokens;
+        res->n_tokens_cached     = slot.n_past;
+        res->has_new_line        = slot.has_new_line;
+        res->stopping_word       = slot.stopping_word;
+        res->stop                = slot.stop;
+        res->post_sampling_probs = slot.params.post_sampling_probs;

         res->verbose = slot.params.verbose;
         res->stream = slot.params.stream;
@@ -2796,7 +2868,9 @@ struct server_context {
                 continue; // continue loop of slots
             }

-            llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+            const int tok_idx = slot.i_batch - i;
+
+            llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);

             slot.i_batch = -1;

@@ -2815,17 +2889,12 @@ struct server_context {
             slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

             completion_token_output result;
-            result.tok = id;
+            result.tok          = id;
+            result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+            result.prob         = 1.0f; // TODO: set it here instead of doing inside populate_token_probs

-            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
-
-            for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
-                auto tok_id = cur_p->data[i].id;
-                result.probs.push_back({
-                    tok_id,
-                    tokens_to_output_formatted_string(ctx, tok_id),
-                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                });
+            if (slot.params.sampling.n_probs > 0) {
+                populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
             }

             if (!process_token(result, slot)) {
@@ -2909,7 +2978,11 @@ struct server_context {
             for (size_t i = 0; i < ids.size(); ++i) {
                 completion_token_output result;

-                result.tok = ids[i];
+                result.tok          = ids[i];
+                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.prob         = 1.0f; // set later
+
+                // TODO: set result.probs

                 if (!process_token(result, slot)) {
                     // release slot because of stop condition