|
|
@@ -1326,11 +1326,12 @@ private:
|
|
|
}
|
|
|
|
|
|
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
|
|
|
- const size_t n_probs = slot.task->params.sampling.n_probs;
|
|
|
+ const size_t n_probs_request = slot.task->params.sampling.n_probs;
|
|
|
|
|
|
if (post_sampling) {
|
|
|
const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
|
|
|
const size_t max_probs = cur_p->size;
|
|
|
+ const size_t n_probs = std::min(max_probs, n_probs_request);
|
|
|
|
|
|
// set probability for sampled token
|
|
|
for (size_t i = 0; i < max_probs; i++) {
|
|
|
@@ -1341,8 +1342,8 @@ private:
|
|
|
}
|
|
|
|
|
|
// set probability for top n_probs tokens
|
|
|
- result.probs.reserve(max_probs);
|
|
|
- for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
|
|
+ result.probs.reserve(n_probs);
|
|
|
+ for (size_t i = 0; i < n_probs; i++) {
|
|
|
result.probs.push_back({
|
|
|
cur_p->data[i].id,
|
|
|
common_token_to_piece(ctx, cur_p->data[i].id, special),
|
|
|
@@ -1352,9 +1353,11 @@ private:
|
|
|
} else {
|
|
|
// TODO: optimize this with min-p optimization
|
|
|
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
|
|
+ const size_t max_probs = cur.size();
|
|
|
+ const size_t n_probs = std::min(max_probs, n_probs_request);
|
|
|
|
|
|
// set probability for sampled token
|
|
|
- for (size_t i = 0; i < cur.size(); i++) {
|
|
|
+ for (size_t i = 0; i < max_probs; i++) {
|
|
|
// set probability for sampled token
|
|
|
if (cur[i].id == result.tok) {
|
|
|
result.prob = cur[i].p;
|
|
|
@@ -1364,7 +1367,7 @@ private:
|
|
|
|
|
|
// set probability for top n_probs tokens
|
|
|
result.probs.reserve(n_probs);
|
|
|
- for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
|
|
|
+ for (size_t i = 0; i < n_probs; i++) {
|
|
|
result.probs.push_back({
|
|
|
cur[i].id,
|
|
|
common_token_to_piece(ctx, cur[i].id, special),
|