|
@@ -149,11 +149,12 @@ static void sampler_queue(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-llama_token llama_sampling_sample(
|
|
|
|
|
|
|
+static llama_token llama_sampling_sample_impl(
|
|
|
struct llama_sampling_context * ctx_sampling,
|
|
struct llama_sampling_context * ctx_sampling,
|
|
|
struct llama_context * ctx_main,
|
|
struct llama_context * ctx_main,
|
|
|
struct llama_context * ctx_cfg,
|
|
struct llama_context * ctx_cfg,
|
|
|
- const int idx) {
|
|
|
|
|
|
|
+ const int idx,
|
|
|
|
|
+ bool is_resampling) { // Add a parameter to indicate if we are resampling
|
|
|
const llama_sampling_params & params = ctx_sampling->params;
|
|
const llama_sampling_params & params = ctx_sampling->params;
|
|
|
|
|
|
|
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
|
@@ -173,8 +174,17 @@ llama_token llama_sampling_sample(
|
|
|
|
|
|
|
|
llama_token id = 0;
|
|
llama_token id = 0;
|
|
|
|
|
|
|
|
|
|
+ // Get a pointer to the logits
|
|
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
|
|
|
|
|
|
|
|
|
+ // Declare original_logits at the beginning of the function scope
|
|
|
|
|
+ std::vector<float> original_logits;
|
|
|
|
|
+
|
|
|
|
|
+ if (!is_resampling) {
|
|
|
|
|
+ // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
|
|
|
|
|
+ original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
// apply params.logit_bias map
|
|
// apply params.logit_bias map
|
|
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
|
|
logits[it->first] += it->second;
|
|
logits[it->first] += it->second;
|
|
@@ -210,7 +220,8 @@ llama_token llama_sampling_sample(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if (ctx_sampling->grammar != NULL) {
|
|
|
|
|
|
|
+ // If we are in the resampling phase, apply grammar checks before sampling logic
|
|
|
|
|
+ if (is_resampling && ctx_sampling->grammar != NULL) {
|
|
|
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
|
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -252,9 +263,40 @@ llama_token llama_sampling_sample(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if (ctx_sampling->grammar != NULL && !is_resampling) {
|
|
|
|
|
+ // Create an array with a single token data element for the sampled id
|
|
|
|
|
+ llama_token_data single_token_data = {id, logits[id], 0.0f};
|
|
|
|
|
+ llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
|
|
|
|
|
+
|
|
|
|
|
+ // Apply grammar constraints to the single token
|
|
|
|
|
+ llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
|
|
|
|
|
+
|
|
|
|
|
+ // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
|
|
|
|
|
+ bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
|
|
|
|
+
|
|
|
|
|
+ // If the token is not valid according to the grammar, perform resampling
|
|
|
|
|
+ if (!is_valid) {
|
|
|
|
|
+ LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
|
|
|
|
|
+
|
|
|
|
|
+ // Restore logits from the copy
|
|
|
|
|
+ std::copy(original_logits.begin(), original_logits.end(), logits);
|
|
|
|
|
+
|
|
|
|
|
+ return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
return id;
|
|
return id;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+llama_token llama_sampling_sample(
|
|
|
|
|
+ struct llama_sampling_context * ctx_sampling,
|
|
|
|
|
+ struct llama_context * ctx_main,
|
|
|
|
|
+ struct llama_context * ctx_cfg,
|
|
|
|
|
+ const int idx) {
|
|
|
|
|
+ // Public entry point: forward to llama_sampling_sample_impl with is_resampling hard-coded to false (first pass); the impl calls itself again with true only when the sampled token fails the grammar check
|
|
|
|
|
+ return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
void llama_sampling_accept(
|
|
void llama_sampling_accept(
|
|
|
struct llama_sampling_context * ctx_sampling,
|
|
struct llama_sampling_context * ctx_sampling,
|
|
|
struct llama_context * ctx_main,
|
|
struct llama_context * ctx_main,
|