@@ -28,6 +28,7 @@
 #include <atomic>
 #include <mutex>
 #include <sstream>
+#include <numeric>

 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
@@ -1475,109 +1476,402 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
 // sampling
 //

-static void sample_top_k(std::vector<std::pair<float, llama_vocab::id>> & logits_id, int top_k) {
-    // find the top k tokens
-    std::partial_sort(
-            logits_id.begin(),
-            logits_id.begin() + top_k, logits_id.end(),
-            [](const std::pair<float, llama_vocab::id> & a, const std::pair<float, llama_vocab::id> & b) {
-        return a.first > b.first;
-    });
+void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
+    assert(candidates->size > 0);
+
+    const int64_t t_start_sample_us = ggml_time_us();

-    logits_id.resize(top_k);
+    // Sort the logits in descending order
+    if (!candidates->sorted) {
+        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+            return a.logit > b.logit;
+        });
+        candidates->sorted = true;
+    }
+
+    float max_l = candidates->data[0].logit;
+    float cum_sum = 0.0f;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        float p = expf(candidates->data[i].logit - max_l);
+        candidates->data[i].p = p;
+        cum_sum += p;
+    }
+    for (size_t i = 0; i < candidates->size; ++i) {
+        candidates->data[i].p /= cum_sum;
+    }
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
 }
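For orientation, a minimal sketch (not part of the patch) of how a caller can turn raw logits into the candidate array these functions operate on; it assumes the llama_token_data / llama_token_data_array layout and the llama_n_vocab / llama_get_logits accessors declared in llama.h:

    // Build the candidate array from the logits of the last evaluated token.
    const int n_vocab = llama_n_vocab(ctx);
    const float * logits = llama_get_logits(ctx);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.push_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // llama_sample_softmax sorts by logit and fills the .p fields with normalized probabilities.
    llama_sample_softmax(ctx, &candidates_p);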

-static llama_vocab::id llama_sample_top_p_top_k(
-        llama_context & lctx,
-        const std::vector<llama_vocab::id> & last_n_tokens,
-        int top_k,
-        float top_p,
-        float temp,
-        float repeat_penalty) {
-    auto & rng = lctx.rng;
-
-    const int n_logits = lctx.model.hparams.n_vocab;
-
-    const auto & logits = lctx.logits;
-    const auto * plogits = logits.data() + logits.size() - n_logits;
-
-    if (temp <= 0) {
-        // select the token with the highest logit directly
-        float max_logit = plogits[0];
-        llama_vocab::id max_id = 0;
-
-        for (int i = 1; i < n_logits; ++i) {
-            if (plogits[i] > max_logit) {
-                max_logit = plogits[i];
-                max_id = i;
-            }
+void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    k = std::max(k, (int) min_keep);
+    k = std::min(k, (int) candidates->size);
+
+    // Sort scores in descending order
+    if (!candidates->sorted) {
+        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+            return a.logit > b.logit;
+        };
+        if (k == (int) candidates->size) {
+            std::sort(candidates->data, candidates->data + candidates->size, comp);
+        } else {
+            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
         }
-        return max_id;
+        candidates->sorted = true;
     }
+    candidates->size = k;

-    std::vector<std::pair<float, llama_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}

-    {
-        const float scale = 1.0f/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
-            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
-            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
-                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
-                if (plogits[i] < 0.0f) {
-                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
-                } else {
-                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
-                }
-            } else {
-                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
-            }
+void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    if (p >= 1.0f) {
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    llama_sample_softmax(ctx, candidates);
+
+    // Compute the cumulative probabilities
+    float cum_sum = 0.0f;
+    size_t last_idx = candidates->size;
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        cum_sum += candidates->data[i].p;
+
+        // Check if the running sum is greater than p and if we have kept at least min_keep tokens
+        if (cum_sum > p && i >= min_keep) {
+            last_idx = i;
+            break;
         }
     }

-    sample_top_k(logits_id, top_k > 0 ? std::min(top_k, n_logits) : n_logits);
+    // Resize the output vector to keep only the top-p tokens
+    candidates->size = last_idx;

-    // compute probs for the top k tokens
-    std::vector<float> probs;
-    probs.reserve(logits_id.size());
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
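The truncation samplers are meant to be chained on the same candidate array; a hedged usage sketch (the 40 / 0.95 / 0.80 values are illustrative, not defaults mandated by this change, and llama_sample_temperature / llama_sample_token appear further down in this diff):

    // Keep the 40 highest-logit tokens, then the 0.95 nucleus, then flatten with temperature and sample.
    llama_sample_top_k(ctx, &candidates_p, 40, 1);
    llama_sample_top_p(ctx, &candidates_p, 0.95f, 1);
    llama_sample_temperature(ctx, &candidates_p, 0.80f);
    llama_token id = llama_sample_token(ctx, &candidates_p);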

-    float maxl = logits_id[0].first;
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        const float p = expf(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
+void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
+    if (z >= 1.0f || candidates->size <= 2) {
+        return;
     }

-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    llama_sample_softmax(nullptr, candidates);
+
+    // Compute the first and second derivatives
+    std::vector<float> first_derivatives(candidates->size - 1);
+    std::vector<float> second_derivatives(candidates->size - 2);
+
+    for (size_t i = 0; i < first_derivatives.size(); ++i) {
+        first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
+    }
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
     }

-    if (top_p < 1.0) {
-        double cumsum = 0.0;
-        for (int i = 0; i < (int) probs.size(); i++) {
-            cumsum += probs[i];
-            if (cumsum >= top_p) {
-                probs.resize(i + 1);
-                logits_id.resize(i + 1);
-                break;
-            }
+    // Calculate absolute value of second derivatives
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        second_derivatives[i] = fabsf(second_derivatives[i]);
+    }
+
+    // Normalize the second derivatives
+    float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
+    for (float & value : second_derivatives) {
+        value /= second_derivatives_sum;
+    }
+
+    float cum_sum = 0.0f;
+    size_t last_idx = candidates->size;
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        cum_sum += second_derivatives[i];
+
+        // Check if the running sum is greater than z and if we have kept at least min_keep tokens
+        if (cum_sum > z && i >= min_keep) {
+            last_idx = i;
+            break;
         }
     }

-    //printf("\n");
-    //for (int i = 0; i < (int) 10; i++) {
-    //    printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
-    //}
-    //printf("\n\n");
-    //exit(0);
+    // Resize the output vector to keep only the tokens above the tail location
+    candidates->size = last_idx;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+
+void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    // Reference implementation:
+    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
+    if (p >= 1.0f) {
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Compute the softmax of logits and calculate entropy
+    llama_sample_softmax(nullptr, candidates);
+
+    float entropy = 0.0f;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        entropy += -candidates->data[i].p * logf(candidates->data[i].p);
+    }
+
+    // Compute the absolute difference between negative log probability and entropy for each candidate
+    std::vector<float> shifted_scores;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
+        shifted_scores.push_back(shifted_score);
+    }
+
+    // Sort tokens based on the shifted_scores and their corresponding indices
+    std::vector<size_t> indices(candidates->size);
+    std::iota(indices.begin(), indices.end(), 0);
+
+    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
+        return shifted_scores[a] < shifted_scores[b];
+    });
+
+    // Compute the cumulative probabilities
+    float cum_sum = 0.0f;
+    size_t last_idx = indices.size();
+
+    for (size_t i = 0; i < indices.size(); ++i) {
+        size_t idx = indices[i];
+        cum_sum += candidates->data[idx].p;
+
+        // Check if the running sum is greater than p and if we have kept at least min_keep tokens
+        if (cum_sum > p && i >= min_keep - 1) {
+            last_idx = i + 1;
+            break;
+        }
+    }
+
+    // Resize the output vector to keep only the locally typical tokens
+    std::vector<llama_token_data> new_candidates;
+    for (size_t i = 0; i < last_idx; ++i) {
+        size_t idx = indices[i];
+        new_candidates.push_back(candidates->data[idx]);
+    }
+
+    // Replace the data in candidates with the new_candidates data
+    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
+    candidates->size = new_candidates.size();
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
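Tail-free and locally typical sampling are alternative truncation steps that can stand in for (or precede) the top-k/top-p pair; a minimal sketch with placeholder thresholds (z = 0.95 and p = 0.90 are illustrative values only):

    // Cut the tail where the normalized |second derivative| mass exceeds z.
    llama_sample_tail_free(ctx, &candidates_p, 0.95f, 1);
    // Keep the tokens whose surprise is closest to the entropy of the distribution.
    llama_sample_typical(ctx, &candidates_p, 0.90f, 1);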
+
+void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    for (size_t i = 0; i < candidates_p->size; ++i) {
+        candidates_p->data[i].logit /= temp;
+    }
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty) {
+    if (last_tokens_size == 0 || penalty == 1.0f) {
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        auto token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id);
+        if (token_iter == last_tokens + last_tokens_size) {
+            continue;
+        }
+
+        // The academic publication that described this technique simply divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
+        // A common fix for this problem is to multiply by the penalty instead of dividing when the logit is negative.
+        if (candidates->data[i].logit <= 0) {
+            candidates->data[i].logit *= penalty;
+        } else {
+            candidates->data[i].logit /= penalty;
+        }
+    }
+
+    candidates->sorted = false;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
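A small worked example of why the sign check matters (numbers chosen purely for illustration):

    // penalty = 1.3f
    //   logit = +2.0f :  2.0f / 1.3f ~  1.54f  -> less likely, as intended
    //   logit = -2.0f : -2.0f / 1.3f ~ -1.54f  -> MORE likely, which is wrong,
    //                   so the code multiplies instead: -2.0f * 1.3f = -2.6f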
+
+void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
+    if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Create a frequency map to count occurrences of each token in last_tokens
+    std::unordered_map<llama_token, int> token_count;
+    for (size_t i = 0; i < last_tokens_size; ++i) {
+        token_count[last_tokens_p[i]]++;
+    }
+
+    // Apply frequency and presence penalties to the candidates
+    for (size_t i = 0; i < candidates->size; ++i) {
+        auto token_iter = token_count.find(candidates->data[i].id);
+        if (token_iter == token_count.end()) {
+            continue;
+        }
+
+        int count = token_iter->second;
+        candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
+    }
+
+    candidates->sorted = false;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
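In a generation loop both penalty functions are typically applied to the raw logits before any truncation, over a window of recently emitted tokens; a sketch assuming the caller keeps such a last_tokens vector (the window size and coefficients are illustrative, not prescribed by this change):

    // Penalize the most recent 64 generated tokens before truncation and sampling.
    const size_t penalty_window = std::min(last_tokens.size(), (size_t) 64);
    llama_token * window_begin = last_tokens.data() + last_tokens.size() - penalty_window;

    llama_sample_repetition_penalty(ctx, &candidates_p, window_begin, penalty_window, 1.10f);
    llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, window_begin, penalty_window, 0.10f, 0.10f);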
+
+
+llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
+    assert(ctx);
+    auto N = float(llama_n_vocab(ctx));
+    int64_t t_start_sample_us;
+    t_start_sample_us = ggml_time_us();
+
+    llama_sample_softmax(nullptr, candidates);
+
+    // Estimate s_hat using the most probable m tokens
+    float s_hat = 0.0;
+    float sum_ti_bi = 0.0;
+    float sum_ti_sq = 0.0;
+    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
+        float t_i = logf(float(i + 2) / float(i + 1));
+        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
+        sum_ti_bi += t_i * b_i;
+        sum_ti_sq += t_i * t_i;
+    }
+    s_hat = sum_ti_bi / sum_ti_sq;
+
+    // Compute k from the estimated s_hat and target surprise value
+    float epsilon_hat = s_hat - 1;
+    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
+
+    // Sample the next word X using top-k sampling
+    llama_sample_top_k(nullptr, candidates, int(k));
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+    llama_token X = llama_sample_token(ctx, candidates);
+    t_start_sample_us = ggml_time_us();
+
+    // Compute error as the difference between observed surprise and target surprise value
+    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return candidate.id == X;
+    }));
+    float observed_surprise = -log2f(candidates->data[X_idx].p);
+    float e = observed_surprise - tau;
+
+    // Update mu using the learning rate and error
+    *mu = *mu - eta * e;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+        ctx->n_sample++;
+    }
+    return X;
+}
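Mirostat carries state across sampling steps through mu; a minimal sketch of driving it from a generation loop (the tau / eta / m values are illustrative defaults, and mu is conventionally seeded with 2 * tau):

    // mu must persist between calls; 2 * tau is the customary starting point.
    const float mirostat_tau = 5.0f;   // target surprise
    const float mirostat_eta = 0.1f;   // learning rate
    static float mirostat_mu = 2.0f * mirostat_tau;

    llama_sample_temperature(ctx, &candidates_p, 0.80f);
    llama_token id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, 100, &mirostat_mu);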
+
+llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
+    assert(ctx);
+    int64_t t_start_sample_us;
+    t_start_sample_us = ggml_time_us();
+
+    llama_sample_softmax(ctx, candidates);
+
+    // Truncate the words with surprise values greater than mu
+    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return -log2f(candidate.p) > *mu;
+    }));
+
+    // Normalize the probabilities of the remaining words
+    llama_sample_softmax(ctx, candidates);
+
+    // Sample the next word X from the remaining words
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+    llama_token X = llama_sample_token(ctx, candidates);
+    t_start_sample_us = ggml_time_us();
+
+    // Compute error as the difference between observed surprise and target surprise value
+    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return candidate.id == X;
+    }));
+    float observed_surprise = -log2f(candidates->data[X_idx].p);
+    float e = observed_surprise - tau;
+
+    // Update mu using the learning rate and error
+    *mu = *mu - eta * e;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+    return X;
+}
+
+llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Find max element
+    auto max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit < b.logit;
+    });
+
+    llama_token result = max_iter->id;
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+        ctx->n_sample++;
+    }
+    return result;
+}
+
+llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+    assert(ctx);
+    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sample_softmax(nullptr, candidates);
+
+    std::vector<float> probs;
+    probs.reserve(candidates->size);
+    for (size_t i = 0; i < candidates->size; ++i) {
+        probs.push_back(candidates->data[i].p);
+    }

     std::discrete_distribution<> dist(probs.begin(), probs.end());
+    auto & rng = ctx->rng;
     int idx = dist(rng);

-    return logits_id[idx].second;
+    llama_token result = candidates->data[idx].id;
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    ctx->n_sample++;
+    return result;
 }

 //
@@ -2348,33 +2642,8 @@ llama_token llama_token_eos() {
     return 2;
 }

-llama_token llama_sample_top_p_top_k(
-        llama_context * ctx,
-        const llama_token * last_n_tokens_data,
-        int last_n_tokens_size,
-        int top_k,
-        float top_p,
-        float temp,
-        float repeat_penalty) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    llama_token result = 0;
-
-    // TODO: avoid this ...
-    const auto last_n_tokens = std::vector<llama_token>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
-
-    result = llama_sample_top_p_top_k(
-            *ctx,
-            last_n_tokens,
-            top_k,
-            top_p,
-            temp,
-            repeat_penalty);
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    ctx->n_sample++;
-
-    return result;
+llama_token llama_token_nl() {
+    return 13;
 }