@@ -421,39 +421,6 @@ void llama_sampler_free(struct llama_sampler * smpl) {
     delete smpl;
 }
 
-llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
-    const auto * logits = llama_get_logits_ith(ctx, idx);
-
-    const llama_model * model = llama_get_model(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
-    const int n_vocab = llama_vocab_n_tokens(vocab);
-
-    // TODO: do not allocate each time
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = {
-        /* .data       = */ cur.data(),
-        /* .size       = */ cur.size(),
-        /* .selected   = */ -1,
-        /* .sorted     = */ false,
-    };
-
-    llama_sampler_apply(smpl, &cur_p);
-
-    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
-
-    auto token = cur_p.data[cur_p.selected].id;
-
-    llama_sampler_accept(smpl, token);
-
-    return token;
-}
-
 // sampler chain
 
 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
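The next hunk initializes a new `cur` member on `llama_sampler_chain`. The matching declaration lives in the struct definition, outside this diff; a rough sketch of what it would look like is below. Field order follows the designated-initializer comments in the hunk, and everything except `cur` is pre-existing:

    struct llama_sampler_chain {
        llama_sampler_chain_params params;

        std::vector<struct llama_sampler *> samplers;

        // reusable candidate buffer for llama_sampler_sample(), so sampling
        // through a chain does not allocate an n_vocab-sized vector per token
        std::vector<llama_token_data> cur;

        // timing
        mutable int64_t t_sample_us;
        mutable int32_t n_sample;
    };
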
@@ -527,12 +494,56 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
         /* .ctx   = */ new llama_sampler_chain {
             /* .params      = */ params,
             /* .samplers    = */ {},
+            /* .cur         = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
         }
     );
 }
 
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
+
+    // use pre-allocated buffer from chain if available, otherwise allocate locally
+    std::vector<llama_token_data> * cur_ptr;
+    std::vector<llama_token_data> cur_local;
+
+    if (smpl->iface == &llama_sampler_chain_i) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+        cur_ptr = &chain->cur;
+    } else {
+        cur_ptr = &cur_local;
+    }
+
+    auto & cur = *cur_ptr;
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    llama_token_data_array cur_p = {
+        /* .data       = */ cur.data(),
+        /* .size       = */ cur.size(),
+        /* .selected   = */ -1,
+        /* .sorted     = */ false,
+    };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+    auto token = cur_p.data[cur_p.selected].id;
+
+    llama_sampler_accept(smpl, token);
+
+    return token;
+}
+
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
     auto * p = (llama_sampler_chain *) chain->ctx;
     p->samplers.push_back(smpl);
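
For reference, a minimal sketch of the call path this change optimizes, using the public chain API from llama.h. Model/context setup and decoding are elided, and the particular samplers are illustrative, not part of this diff:

    // assumes an initialized llama_context * ctx and an int n_predict
    llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
    struct llama_sampler * smpl = llama_sampler_chain_init(sparams);

    // illustrative chain: top-k -> temperature -> final pick by seeded distribution
    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    for (int i = 0; i < n_predict; i++) {
        // with this change, each call reuses chain->cur instead of building
        // a fresh n_vocab-sized candidate vector per token
        llama_token token = llama_sampler_sample(smpl, ctx, -1); // -1 = last logits
        // ... feed token back via llama_decode(), stop on end-of-generation, etc.
    }

    llama_sampler_free(smpl);

Because the fast path is keyed on `smpl->iface == &llama_sampler_chain_i`, standalone samplers still take the `cur_local` fallback, so the public behavior of `llama_sampler_sample()` is unchanged; only the per-token allocation for chains goes away.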