|
|
@@ -229,51 +229,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|
|
params.logit_bias.data()));
|
|
|
|
|
|
if (params.mirostat == 0) {
|
|
|
- if (params.top_n_sigma >= 0) {
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
|
|
- } else {
|
|
|
- for (const auto & cnstr : params.samplers) {
|
|
|
- switch (cnstr) {
|
|
|
- case COMMON_SAMPLER_TYPE_DRY:
|
|
|
- {
|
|
|
- std::vector<const char *> c_breakers;
|
|
|
- c_breakers.reserve(params.dry_sequence_breakers.size());
|
|
|
- for (const auto & str : params.dry_sequence_breakers) {
|
|
|
- c_breakers.push_back(str.c_str());
|
|
|
- }
|
|
|
-
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
|
|
+ for (const auto & cnstr : params.samplers) {
|
|
|
+ switch (cnstr) {
|
|
|
+ case COMMON_SAMPLER_TYPE_DRY:
|
|
|
+ {
|
|
|
+ std::vector<const char *> c_breakers;
|
|
|
+ c_breakers.reserve(params.dry_sequence_breakers.size());
|
|
|
+ for (const auto & str : params.dry_sequence_breakers) {
|
|
|
+ c_breakers.push_back(str.c_str());
|
|
|
}
|
|
|
- break;
|
|
|
- case COMMON_SAMPLER_TYPE_TOP_K:
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
|
|
- break;
|
|
|
- case COMMON_SAMPLER_TYPE_TOP_P:
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
|
|
- break;
|
|
|
- case COMMON_SAMPLER_TYPE_MIN_P:
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
|
|
- break;
|
|
|
- case COMMON_SAMPLER_TYPE_XTC:
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
|
|
- break;
|
|
|
- case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
|
|
- break;
|
|
|
- case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
|
|
- break;
|
|
|
- case COMMON_SAMPLER_TYPE_INFILL:
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
|
|
- break;
|
|
|
- case COMMON_SAMPLER_TYPE_PENALTIES:
|
|
|
- llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
|
|
- break;
|
|
|
- default:
|
|
|
- GGML_ASSERT(false && "unknown sampler type");
|
|
|
- }
|
|
|
+
|
|
|
+ llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ case COMMON_SAMPLER_TYPE_TOP_K:
|
|
|
+ llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
|
|
+ break;
|
|
|
+ case COMMON_SAMPLER_TYPE_TOP_P:
|
|
|
+ llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
|
|
+ break;
|
|
|
+ case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
|
|
+ llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
|
|
+ break;
|
|
|
+ case COMMON_SAMPLER_TYPE_MIN_P:
|
|
|
+ llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
|
|
+ break;
|
|
|
+ case COMMON_SAMPLER_TYPE_XTC:
|
|
|
+ llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
|
|
+ break;
|
|
|
+ case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
|
|
+ llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
|
|
+ break;
|
|
|
+ case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
|
|
+ llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
|
|
+ break;
|
|
|
+ case COMMON_SAMPLER_TYPE_INFILL:
|
|
|
+ llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
|
|
+ break;
|
|
|
+ case COMMON_SAMPLER_TYPE_PENALTIES:
|
|
|
+ llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ GGML_ASSERT(false && "unknown sampler type");
|
|
|
}
|
|
|
}
|
|
|
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
|
|
@@ -475,6 +472,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|
|
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
|
|
|
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
|
|
|
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
|
|
|
+ case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
|
|
|
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
|
|
|
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
|
|
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
|
|
@@ -490,6 +488,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|
|
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
|
|
|
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
|
|
|
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
|
|
|
+ case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
|
|
|
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
|
|
|
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
|
|
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
|
|
@@ -504,6 +503,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|
|
{ "dry", COMMON_SAMPLER_TYPE_DRY },
|
|
|
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
|
|
|
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
|
|
|
+ { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
|
|
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
|
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
|
|
|
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
|
@@ -517,6 +517,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|
|
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
|
|
|
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
|
|
|
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
|
|
|
+ { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
|
|
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
|
|
|
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
|
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
|
@@ -552,6 +553,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
|
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
|
|
|
+ { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
|
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|