@@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
    };
}

+// infill
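+//
+// fill-in-the-middle (infill) sampler: decides between continuing the infilled region with
+// regular text tokens and terminating it with an end-of-generation (EOG) token, based on how
+// the probability mass is split between the two groups of candidates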
+
+//#define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
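+// high-level flow of llama_sampler_infill_apply:
+//
+//   1. compute softmax probabilities over the current candidates
+//   2. if the EOG probability mass is large relative to the text mass -> keep only the EOG tokens
+//   3. otherwise, merge candidates that share a common prefix and filter out low-probability
+//      non-EOG candidates in two threshold passes, renormalizing after each pass
+//   4. if the first pass leaves no non-EOG candidates -> reduce to a single EOT token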
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
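+    // sort the candidates by logit and convert the logits into normalized probabilities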
+    llama_sampler_softmax_impl(cur_p);
+
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_txt_sum = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
+        }
+    }
+
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
+
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
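+    // if the EOG tokens carry enough probability mass relative to the text tokens
+    // (i.e. p_txt_sum < 3*n_candidates*p_eog_sum), end the infill region: keep only the EOG tokens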
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        float p_sum = 0.0f;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                p_sum += cur_p->data[i].p;
+
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        // normalize probs
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= p_sum;
+        }
+
+        return;
+    }
+
+    size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+    // combine tokens with common prefix
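+    // (when one candidate token is a prefix of another, the more probable of the two absorbs
+    //  the other's probability and the absorbed candidate is masked out with -INFINITY)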
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p = 0.0f;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p = 0.0f;
+                }
+
+                n_combined++;
+            }
+        }
+    }
+
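+    // first filtering pass: keep all EOG tokens and only the non-EOG tokens with p >= 0.2,
+    // counting how many non-EOG tokens survive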
+    size_t n_non_eog = 0;
+
+    size_t size_org = cur_p->size;
+
+    float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        // keep this token
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
+        cur_p->size = 1;
+        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
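+    // second filtering pass: using the renormalized probabilities, drop the non-EOG tokens
+    // that fall below the uniform threshold 1/(n_non_eog + 1)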
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+        },
+    };
+}
+
 // utils

uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {