|
|
@@ -99,6 +99,54 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
|
|
return std::string(result);
|
|
|
}
|
|
|
|
|
|
+std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
|
|
+ std::string result = "CFG -> Penalties ";
|
|
|
+ if (params.mirostat == 0) {
|
|
|
+ for (auto s : params.samplers_sequence) {
|
|
|
+ switch (s) {
|
|
|
+ case 'k': result += "-> top_k "; break;
|
|
|
+ case 'f': result += "-> tfs_z "; break;
|
|
|
+ case 'y': result += "-> typical_p "; break;
|
|
|
+ case 'p': result += "-> top_p "; break;
|
|
|
+ case 'm': result += "-> min_p "; break;
|
|
|
+ case 't': result += "-> temp "; break;
|
|
|
+ default : break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else result += "-> mirostat ";
|
|
|
+
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+// no reasons to expose this function in header
|
|
|
+void sampler_queue(
|
|
|
+ struct llama_context * ctx_main,
|
|
|
+ const llama_sampling_params & params,
|
|
|
+ llama_token_data_array & cur_p,
|
|
|
+ size_t & min_keep) {
|
|
|
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
|
|
+
|
|
|
+ const float temp = params.temp;
|
|
|
+ const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
|
|
+ const float top_p = params.top_p;
|
|
|
+ const float min_p = params.min_p;
|
|
|
+ const float tfs_z = params.tfs_z;
|
|
|
+ const float typical_p = params.typical_p;
|
|
|
+ const std::string & samplers_sequence = params.samplers_sequence;
|
|
|
+
|
|
|
+ for (auto s : samplers_sequence) {
|
|
|
+ switch (s){
|
|
|
+ case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
|
|
|
+ case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
|
|
|
+ case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
|
|
|
+ case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
|
|
|
+ case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
|
|
|
+ case 't': llama_sample_temp (ctx_main, &cur_p, temp); break;
|
|
|
+ default : break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
llama_token llama_sampling_sample(
|
|
|
struct llama_sampling_context * ctx_sampling,
|
|
|
struct llama_context * ctx_main,
|
|
|
@@ -109,11 +157,6 @@ llama_token llama_sampling_sample(
|
|
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
|
|
|
|
|
const float temp = params.temp;
|
|
|
- const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
|
|
- const float top_p = params.top_p;
|
|
|
- const float min_p = params.min_p;
|
|
|
- const float tfs_z = params.tfs_z;
|
|
|
- const float typical_p = params.typical_p;
|
|
|
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
|
|
const float penalty_repeat = params.penalty_repeat;
|
|
|
const float penalty_freq = params.penalty_freq;
|
|
|
@@ -188,12 +231,7 @@ llama_token llama_sampling_sample(
|
|
|
// temperature sampling
|
|
|
size_t min_keep = std::max(1, params.n_probs);
|
|
|
|
|
|
- llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
|
|
- llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
|
|
- llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
|
|
- llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
|
|
|
- llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
|
|
|
- llama_sample_temp (ctx_main, &cur_p, temp);
|
|
|
+ sampler_queue(ctx_main, params, cur_p, min_keep);
|
|
|
|
|
|
id = llama_sample_token(ctx_main, &cur_p);
|
|
|
|