|
|
@@ -52,6 +52,7 @@
|
|
|
#include <algorithm>
|
|
|
#include <array>
|
|
|
#include <cassert>
|
|
|
+#include <cfloat>
|
|
|
#include <cinttypes>
|
|
|
#include <climits>
|
|
|
#include <cmath>
|
|
|
@@ -8246,21 +8247,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- llama_sample_softmax(ctx, candidates);
|
|
|
-
|
|
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
|
|
|
|
- float scale = candidates->data[0].p; // scale by max prob
|
|
|
- size_t i = 1; // first token always matches
|
|
|
+ bool min_p_applied = false;
|
|
|
+
|
|
|
+ // if the candidates aren't sorted, try the unsorted implementation first
|
|
|
+ if (!candidates->sorted) {
|
|
|
+ std::vector<llama_token_data> filtered_tokens;
|
|
|
+
|
|
|
+ float max_logit = -FLT_MAX;
|
|
|
+ for (size_t i = 0; i < candidates->size; ++i) {
|
|
|
+ max_logit = std::max(max_logit, candidates->data[i].logit);
|
|
|
+ }
|
|
|
+ const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
|
|
|
+
|
|
|
+ for (size_t i = 0; i < candidates->size; ++i) {
|
|
|
+ if (candidates->data[i].logit >= min_logit) {
|
|
|
+ filtered_tokens.push_back(candidates->data[i]);
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- for (; i < candidates->size; ++i) {
|
|
|
- if (candidates->data[i].p < p * scale && i >= min_keep) {
|
|
|
- break; // prob too small
|
|
|
+ // if we have enough values the operation was a success
|
|
|
+ if (filtered_tokens.size() >= min_keep) {
|
|
|
+ memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
|
|
+ candidates->size = filtered_tokens.size();
|
|
|
+ min_p_applied = true;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // Resize the output vector to keep only the matching tokens
|
|
|
- candidates->size = i;
|
|
|
+ // if the candidates are sorted or the unsorted implementation failed, use this implementation
|
|
|
+ if (!min_p_applied) {
|
|
|
+ // Sort the logits in descending order
|
|
|
+ if (!candidates->sorted) {
|
|
|
+ std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
|
|
+ return a.logit > b.logit;
|
|
|
+ });
|
|
|
+ candidates->sorted = true;
|
|
|
+ }
|
|
|
+
|
|
|
+ const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
|
|
|
+ size_t i = 1; // first token always matches
|
|
|
+
|
|
|
+ for (; i < candidates->size; ++i) {
|
|
|
+ if (candidates->data[i].logit < min_logit && i >= min_keep) {
|
|
|
+ break; // prob too small
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Resize the output vector to keep only the matching tokens
|
|
|
+ candidates->size = i;
|
|
|
+ }
|
|
|
|
|
|
if (ctx) {
|
|
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|