|
@@ -89,10 +89,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
|
|
|
|
|
|
|
snprintf(result, sizeof(result),
|
|
snprintf(result, sizeof(result),
|
|
|
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
|
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
|
|
- "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n"
|
|
|
|
|
|
|
+ "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
|
|
|
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
|
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
|
|
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
|
|
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
|
|
|
- params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
|
|
|
|
|
|
|
+ params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
|
|
|
params.mirostat, params.mirostat_eta, params.mirostat_tau);
|
|
params.mirostat, params.mirostat_eta, params.mirostat_tau);
|
|
|
|
|
|
|
|
return std::string(result);
|
|
return std::string(result);
|
|
@@ -110,6 +110,7 @@ llama_token llama_sampling_sample(
|
|
|
const float temp = params.temp;
|
|
const float temp = params.temp;
|
|
|
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
|
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
|
|
const float top_p = params.top_p;
|
|
const float top_p = params.top_p;
|
|
|
|
|
+ const float min_p = params.min_p;
|
|
|
const float tfs_z = params.tfs_z;
|
|
const float tfs_z = params.tfs_z;
|
|
|
const float typical_p = params.typical_p;
|
|
const float typical_p = params.typical_p;
|
|
|
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
|
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
|
@@ -190,6 +191,7 @@ llama_token llama_sampling_sample(
|
|
|
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
|
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
|
|
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
|
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
|
|
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
|
|
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
|
|
|
|
|
+ llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
|
|
|
llama_sample_temp (ctx_main, &cur_p, temp);
|
|
llama_sample_temp (ctx_main, &cur_p, temp);
|
|
|
|
|
|
|
|
id = llama_sample_token(ctx_main, &cur_p);
|
|
id = llama_sample_token(ctx_main, &cur_p);
|