@@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
 // penalties

 struct llama_sampler_penalties {
-    const int32_t n_vocab;
-    const llama_token special_eos_id;
-    const llama_token linefeed_id;
-
     const int32_t penalty_last_n;
     const float penalty_repeat;
     const float penalty_freq;
     const float penalty_present;

-    const bool penalize_nl;
-    const bool ignore_eos;
-
     ring_buffer<llama_token> prev;
+
+    // a frequency map to count token occurrences
+    std::unordered_map<llama_token, int> token_count;
 };

 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
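The struct change above is the heart of the refactor: the sampler drops the vocab-coupled fields (`n_vocab`, `special_eos_id`, `linefeed_id`, `penalize_nl`, `ignore_eos`) and instead pairs the existing ring buffer with a token frequency map that is kept up to date incrementally, as the next hunk shows. A minimal standalone sketch of that sliding-window counting technique, with `std::deque` standing in for the repo's `ring_buffer` (which overwrites its oldest element on `push_back` once full):

```cpp
#include <cassert>
#include <cstdint>
#include <deque>
#include <unordered_map>

using llama_token = int32_t;

struct window_counter {
    size_t capacity; // corresponds to penalty_last_n

    std::deque<llama_token> prev;
    std::unordered_map<llama_token, int> token_count;

    // O(1) amortized per accepted token, instead of rebuilding the
    // whole map from the last `capacity` tokens on every apply()
    void accept(llama_token token) {
        token_count[token]++;

        // if the window is full, evict the oldest token and drop
        // its count (erasing entries that reach zero)
        if (prev.size() >= capacity) {
            const llama_token old = prev.front();
            prev.pop_front();

            if (--token_count[old] == 0) {
                token_count.erase(old);
            }
        }

        prev.push_back(token);
    }
};

int main() {
    window_counter wc { /* capacity = */ 3 };
    for (llama_token t : {1, 2, 2, 3}) {
        wc.accept(t);
    }
    // window is now {2, 2, 3}; token 1 has been evicted
    assert(wc.token_count.count(1) == 0);
    assert(wc.token_count.at(2) == 2);
    return 0;
}
```

With the map maintained in `accept`, `apply` no longer has to recount the last `penalty_last_n` tokens on every call — exactly what the removed `TODO` in the next hunk had flagged.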
@@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
         return;
     }

-    ctx->prev.push_back(token);
-}
-
-static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+    ctx->token_count[token]++;

-    if (ctx->ignore_eos) {
-        assert(ctx->special_eos_id >= 0);
+    // if the ring buffer is full, remove the oldest token
+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        const auto old = ctx->prev.front();

-        // optimistically check if the candidates are not yet sorted/shuffled/truncated
-        if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
-            cur_p->data[ctx->special_eos_id].logit = -INFINITY;
-        } else {
-            // else, search for the special EOS token
-            for (size_t i = 0; i < cur_p->size; ++i) {
-                if (cur_p->data[i].id == ctx->special_eos_id) {
-                    cur_p->data[i].logit = -INFINITY;
-                    break;
-                }
-            }
+        ctx->token_count[old]--;
+        if (ctx->token_count[old] == 0) {
+            ctx->token_count.erase(old);
         }
     }

-    if ((ctx->penalty_last_n == 0) ||
-        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
-        return;
-    }
-
-    bool nl_found = false;
-    size_t nl_idx = 0;
-    float nl_logit = -INFINITY;
-    if (!ctx->penalize_nl) {
-        assert(ctx->linefeed_id >= 0);
+    ctx->prev.push_back(token);

-        // optimistically check if the candidates are not yet sorted/shuffled/truncated
-        if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
-            nl_found = true;
-            nl_idx = ctx->linefeed_id;
-            nl_logit = cur_p->data[ctx->linefeed_id].logit;
-        } else {
-            // else, search for the linefeed token
-            for (size_t i = 0; i < cur_p->size; ++i) {
-                if (cur_p->data[i].id == ctx->linefeed_id) {
-                    nl_found = true;
-                    nl_idx = i;
-                    nl_logit = cur_p->data[i].logit;
-                    break;
-                }
-            }
-        }
+#if 0
+    // sanity check
+    std::unordered_map<llama_token, int> tmp;
+    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+        tmp[ctx->prev.rat(i)]++;
     }

-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
+    assert(ctx->token_count == tmp);
+#endif
+}
+
+static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;

-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
+    if ((ctx->penalty_last_n == 0) ||
+        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
+        return;
     }

     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
         }

         const int count = token_iter->second;

+        assert(count > 0 && count <= ctx->penalty_last_n);
+
         // The academic publication that described this technique only divided by the penalty, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // The common fix for this problem is to multiply negative logits by the penalty instead of dividing.
         if (cur_p->data[i].logit <= 0) {
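The hunk above ends right where the per-candidate penalty arithmetic begins; the lines between this hunk and the next are elided by the diff. For reference, the three penalty terms combine roughly along these lines — a standalone sketch of the standard repeat/frequency/presence formulas, not a verbatim quote of the elided code:

```cpp
// Sketch: how one candidate's logit is penalized, given `count`,
// the token's occurrence count taken from ctx->token_count.
float penalize_logit(float logit, int count,
                     float penalty_repeat,    // > 1.0f discourages repeats
                     float penalty_freq,      // scales with the count
                     float penalty_present) { // flat, applies when count > 0
    // Multiply negative logits and divide positive ones, so the
    // penalized token always becomes less likely (see the comment
    // in the hunk above).
    if (logit <= 0.0f) {
        logit *= penalty_repeat;
    } else {
        logit /= penalty_repeat;
    }

    // OpenAI-style frequency and presence penalties.
    logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;

    return logit;
}
```

Note that only tokens present in the map reach this point (the `find` above `continue`s otherwise), so `count` is always positive — which is exactly what the new `assert` checks.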
@@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
     }

     cur_p->sorted = false;
-
-    if (!ctx->penalize_nl && nl_found) {
-        // restore the logit of the newline token if it was penalized
-        cur_p->data[nl_idx].logit = nl_logit;
-    }
 }

 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
     ctx->prev.clear();
+    ctx->token_count.clear();
 }

 static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
     auto * result = llama_sampler_init_penalties(
-        ctx->n_vocab,
-        ctx->special_eos_id,
-        ctx->linefeed_id,
         ctx->penalty_last_n,
         ctx->penalty_repeat,
         ctx->penalty_freq,
-        ctx->penalty_present,
-        ctx->penalize_nl,
-        ctx->ignore_eos);
+        ctx->penalty_present);

     // copy the state
     {
@@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
 };

 struct llama_sampler * llama_sampler_init_penalties(
-        int32_t n_vocab,
-        llama_token special_eos_id,
-        llama_token linefeed_id,
         int32_t penalty_last_n,
         float penalty_repeat,
         float penalty_freq,
-        float penalty_present,
-        bool penalize_nl,
-        bool ignore_eos) {
-    if (linefeed_id == LLAMA_TOKEN_NULL) {
-        penalize_nl = true;
-    }
-
-    if (special_eos_id == LLAMA_TOKEN_NULL) {
-        ignore_eos = false;
-    }
-
+        float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);

     return new llama_sampler {
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx = */ new llama_sampler_penalties {
-            /* .n_vocab = */ n_vocab,
-            /* .special_eos_id = */ special_eos_id,
-            /* .linefeed_id = */ linefeed_id,
             /* .penalty_last_n = */ penalty_last_n,
             /* .penalty_repeat = */ penalty_repeat,
             /* .penalty_freq = */ penalty_freq,
             /* .penalty_present = */ penalty_present,
-            /* .penalize_nl = */ penalize_nl,
-            /* .ignore_eos = */ ignore_eos,
             /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
+            /* .token_count = */ {},
         },
     };
 }
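With the vocab-related parameters gone, `llama_sampler_init_penalties` takes only the four penalty values; callers that relied on the removed `penalize_nl`/`ignore_eos` behavior need a different mechanism outside this sampler. A hypothetical call site updated for the new signature (the values are illustrative defaults, not taken from this patch):

```cpp
// hypothetical caller; parameter values are illustrative only
struct llama_sampler * smpl = llama_sampler_init_penalties(
    /* penalty_last_n  = */ 64,    // how many recent tokens to consider
    /* penalty_repeat  = */ 1.1f,  // > 1.0f discourages repetition
    /* penalty_freq    = */ 0.0f,  // frequency penalty disabled
    /* penalty_present = */ 0.0f); // presence penalty disabled
```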
@@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
         } else {
-            size_t word_len = word.size(), str_len = str.size();
+            size_t word_len = word.size();
+            size_t str_len = str.size();
             size_t pos = -1;
             while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
                 bool match = true;