|
@@ -167,11 +167,11 @@ std::string common_params_sampling::print() const {
|
|
|
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
|
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
|
|
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
|
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
|
|
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
|
|
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
|
|
|
- "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
|
|
|
|
|
|
+ "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f",
|
|
|
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
|
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
|
|
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
|
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
|
|
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
|
|
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
|
|
|
- mirostat, mirostat_eta, mirostat_tau);
|
|
|
|
|
|
|
+ mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay);
|
|
|
|
|
|
|
|
return std::string(result);
|
|
return std::string(result);
|
|
|
}
|
|
}
|
|
@@ -255,6 +255,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if (params.mirostat == 0) {
|
|
if (params.mirostat == 0) {
|
|
|
|
|
+
|
|
|
|
|
+ bool use_adaptive_p = false; // see below
|
|
|
|
|
+
|
|
|
for (const auto & cnstr : params.samplers) {
|
|
for (const auto & cnstr : params.samplers) {
|
|
|
switch (cnstr) {
|
|
switch (cnstr) {
|
|
|
case COMMON_SAMPLER_TYPE_DRY:
|
|
case COMMON_SAMPLER_TYPE_DRY:
|
|
@@ -264,43 +267,54 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|
|
for (const auto & str : params.dry_sequence_breakers) {
|
|
for (const auto & str : params.dry_sequence_breakers) {
|
|
|
c_breakers.push_back(str.c_str());
|
|
c_breakers.push_back(str.c_str());
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
|
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
|
|
}
|
|
}
|
|
|
break;
|
|
break;
|
|
|
case COMMON_SAMPLER_TYPE_TOP_K:
|
|
case COMMON_SAMPLER_TYPE_TOP_K:
|
|
|
- samplers.push_back(llama_sampler_init_top_k (params.top_k));
|
|
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_top_k(params.top_k));
|
|
|
break;
|
|
break;
|
|
|
case COMMON_SAMPLER_TYPE_TOP_P:
|
|
case COMMON_SAMPLER_TYPE_TOP_P:
|
|
|
- samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
|
|
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep));
|
|
|
break;
|
|
break;
|
|
|
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
|
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
|
|
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
|
|
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
|
|
|
break;
|
|
break;
|
|
|
case COMMON_SAMPLER_TYPE_MIN_P:
|
|
case COMMON_SAMPLER_TYPE_MIN_P:
|
|
|
- samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
|
|
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep));
|
|
|
break;
|
|
break;
|
|
|
case COMMON_SAMPLER_TYPE_XTC:
|
|
case COMMON_SAMPLER_TYPE_XTC:
|
|
|
- samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
|
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
|
|
break;
|
|
break;
|
|
|
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
|
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
|
|
- samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
|
|
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep));
|
|
|
break;
|
|
break;
|
|
|
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
|
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
|
|
- samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
|
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
|
|
break;
|
|
break;
|
|
|
case COMMON_SAMPLER_TYPE_INFILL:
|
|
case COMMON_SAMPLER_TYPE_INFILL:
|
|
|
- samplers.push_back(llama_sampler_init_infill (vocab));
|
|
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_infill(vocab));
|
|
|
break;
|
|
break;
|
|
|
case COMMON_SAMPLER_TYPE_PENALTIES:
|
|
case COMMON_SAMPLER_TYPE_PENALTIES:
|
|
|
- samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
|
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
|
|
|
|
+ break;
|
|
|
|
|
+ case COMMON_SAMPLER_TYPE_ADAPTIVE_P:
|
|
|
|
|
+ // the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects
|
|
|
|
|
+ // a single token, so we will add `dist` at the end of the chain by default,
|
|
|
|
|
+ // unless the user specifically included `adaptive-p`. we set this flag here
|
|
|
|
|
+ // so we know to add the sampler at the very end.
|
|
|
|
|
+ use_adaptive_p = true;
|
|
|
break;
|
|
break;
|
|
|
default:
|
|
default:
|
|
|
GGML_ASSERT(false && "unknown sampler type");
|
|
GGML_ASSERT(false && "unknown sampler type");
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- samplers.push_back(llama_sampler_init_dist(params.seed));
|
|
|
|
|
|
|
+ if (use_adaptive_p) {
|
|
|
|
|
+ // only if user explicitly included adaptive-p sampler
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed));
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // default: sample from distribution
|
|
|
|
|
+ samplers.push_back(llama_sampler_init_dist(params.seed));
|
|
|
|
|
+ }
|
|
|
} else if (params.mirostat == 1) {
|
|
} else if (params.mirostat == 1) {
|
|
|
samplers.push_back(llama_sampler_init_temp(params.temp));
|
|
samplers.push_back(llama_sampler_init_temp(params.temp));
|
|
|
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
|
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
|
@@ -625,6 +639,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|
|
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
|
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
|
|
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
|
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
|
|
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
|
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
|
|
|
|
+ case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return 'a';
|
|
|
default : return '?';
|
|
default : return '?';
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -641,6 +656,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|
|
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
|
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
|
|
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
|
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
|
|
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
|
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
|
|
|
|
+ case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return "adaptive_p";
|
|
|
default : return "";
|
|
default : return "";
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -657,6 +673,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|
|
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
|
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
|
|
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
|
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
|
|
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
|
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
|
|
|
|
+ { "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
// since samplers names are written multiple ways
|
|
// since samplers names are written multiple ways
|
|
@@ -672,6 +689,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|
|
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
|
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
|
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
|
|
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
|
|
|
+ { "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
std::vector<common_sampler_type> samplers;
|
|
std::vector<common_sampler_type> samplers;
|
|
@@ -708,6 +726,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
|
|
|
|
+ { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_ADAPTIVE_P), COMMON_SAMPLER_TYPE_ADAPTIVE_P },
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
std::vector<common_sampler_type> samplers;
|
|
std::vector<common_sampler_type> samplers;
|