common : refactor common_sampler + grammar logic changes (#17937)

* common : refactor common_sampler + grammar logic changes

* tests : increase max_tokens to get needed response

* batched : fix uninitialized samplers
Georgi Gerganov, 1 month ago
Parent commit: 254098a279
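
Summary of the API change: the old common_init_result aggregate (public llama_model_ptr / llama_context_ptr members) becomes an opaque pimpl-based object returned as a common_init_result_ptr, with accessors for the model, context, LoRA adapters and one per-sequence common_sampler created during initialization. A minimal caller-side sketch of the new API, assembled from the updated examples in this diff (error handling trimmed; params is assumed to be an already-populated common_params):

    // sketch only: assumes params was filled in beforehand (e.g. by the usual argument parsing)
    common_init_result_ptr llama_init = common_init_from_params(params);

    llama_model    * model = llama_init->model();    // NULL if the model failed to load
    llama_context  * ctx   = llama_init->context();  // NULL if context creation failed
    common_sampler * smpl  = llama_init->sampler(0); // per-sequence sampler, indexed by seq_id

    if (model == NULL || ctx == NULL) {
        // bail out; the common_init_result destructor releases whatever was created
    }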

+ 1 - 1
common/arg.cpp

@@ -1415,7 +1415,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.top_k = value;
             params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
         }
-    ).set_sparam());
+    ).set_sparam().set_env("LLAMA_ARG_TOP_K"));
     add_opt(common_arg(
         {"--top-p"}, "N",
         string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),

+ 125 - 65
common/common.cpp

@@ -1013,31 +1013,40 @@ bool tty_can_use_colors() {
 // Model utils
 //
 
-static inline void common_init_sampler_from_model(
+// TODO: move to common/sampling
+static void common_init_sampler_from_model(
     const llama_model * model,
     common_params_sampling & sparams) {
 
     const uint64_t config = sparams.user_sampling_config;
 
     auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
-        if (config & user_config) return;
+        if (config & user_config) {
+            return;
+        }
 
         char buf[64] = {0};
         if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
             char * end = nullptr;
             int32_t v = strtol(buf, &end, 10);
-            if (end && end != buf) dst = v;
+            if (end && end != buf) {
+                dst = v;
+            }
         }
     };
 
     auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
-        if (config & user_config) return;
+        if (config & user_config) {
+            return;
+        }
 
         char buf[128] = {0};
         if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
             char * end = nullptr;
             float v = strtof(buf, &end);
-            if (end && end != buf) dst = v;
+            if (end && end != buf) {
+                dst = v;
+            }
         }
     };
 
@@ -1065,31 +1074,122 @@ static inline void common_init_sampler_from_model(
     get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA),    sparams.mirostat_eta,    common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
 }
 
-struct common_init_result common_init_from_params(common_params & params) {
-    common_init_result iparams;
-    auto mparams = common_model_params_to_llama(params);
+struct common_init_result::impl {
+    impl() = default;
+    ~impl() = default;
+
+    llama_model_ptr   model;
+    llama_context_ptr context;
+
+    std::vector<llama_adapter_lora_ptr> lora;
+
+    std::vector<common_sampler_ptr> samplers;
+};
+
+common_init_result::common_init_result(common_params & params) :
+    pimpl(new impl{}) {
+    const auto mparams = common_model_params_to_llama(params);
 
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
     if (model == NULL) {
-        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
-            __func__, params.model.path.c_str());
-        return iparams;
+        return;
     }
 
-    common_init_sampler_from_model(model, params.sampling);
+    pimpl->model.reset(model);
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
+    // updates params.sampling
+    // TODO: fix naming
+    common_init_sampler_from_model(model, params.sampling);
+
     auto cparams = common_context_params_to_llama(params);
 
+    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sampling.ignore_eos = false;
+    }
+
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
+        }
+    }
+
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
+    //if (params.sampling.penalty_last_n == -1) {
+    //    LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //    params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    //}
+
+    //if (params.sampling.dry_penalty_last_n == -1) {
+    //    LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //    params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    //}
+
+    pimpl->samplers.resize(cparams.n_seq_max);
+
+    for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
+        pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
+    }
+
     llama_context * lctx = llama_init_from_model(model, cparams);
+    if (lctx == NULL) {
+        LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+                __func__, params.model.path.c_str());
+        return;
+    }
+
+    pimpl->context.reset(lctx);
+}
+
+llama_model * common_init_result::model() {
+    return pimpl->model.get();
+}
+
+llama_context * common_init_result::context() {
+    return pimpl->context.get();
+}
+
+common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
+    return pimpl->samplers[seq_id].get();
+}
+
+std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
+    return pimpl->lora;
+}
+
+void common_init_result::free_context() {
+    pimpl->context.reset();
+}
+
+common_init_result_ptr common_init_from_params(common_params & params) {
+    common_init_result_ptr res(new common_init_result(params));
+
+    llama_model * model = res->model();
+    if (model == NULL) {
+        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+            __func__, params.model.path.c_str());
+        return res;
+    }
+
+    llama_context * lctx = res->context();
     if (lctx == NULL) {
         LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
             __func__, params.model.path.c_str());
-        llama_model_free(model);
-        return iparams;
+        return res;
     }
 
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
         LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
         params.ctx_shift = false;
@@ -1101,10 +1201,7 @@ struct common_init_result common_init_from_params(common_params & params) {
 
         const auto cvec = common_control_vector_load(params.control_vectors);
         if (cvec.n_embd == -1) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
 
         int err = llama_apply_adapter_cvec(
@@ -1115,10 +1212,7 @@ struct common_init_result common_init_from_params(common_params & params) {
                 params.control_vector_layer_start,
                 params.control_vector_layer_end);
         if (err) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
     }
 
@@ -1142,10 +1236,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
 
         if (!ok) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
     }
 
@@ -1155,9 +1246,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
         if (lora == nullptr) {
             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
-            llama_free(lctx);
-            llama_model_free(model);
-            return iparams;
+            return res;
         }
 
         char buf[1024];
@@ -1166,43 +1255,13 @@ struct common_init_result common_init_from_params(common_params & params) {
         la.task_name = buf;
         llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
         la.prompt_prefix = buf;
-        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+        res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
     }
 
     if (!params.lora_init_without_apply) {
         common_set_adapter_lora(lctx, params.lora_adapters);
     }
 
-    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
-        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
-        params.sampling.ignore_eos = false;
-    }
-
-    // initialize once
-    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-        if (llama_vocab_is_eog(vocab, i)) {
-            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
-        }
-    }
-
-    if (params.sampling.ignore_eos) {
-        // add EOG biases to the active set of logit biases
-        params.sampling.logit_bias.insert(
-                params.sampling.logit_bias.end(),
-                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
-    }
-
-    if (params.sampling.penalty_last_n == -1) {
-        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.penalty_last_n = llama_n_ctx(lctx);
-    }
-
-    if (params.sampling.dry_penalty_last_n == -1) {
-        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
-    }
-
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
@@ -1241,12 +1300,11 @@ struct common_init_result common_init_from_params(common_params & params) {
         llama_set_warmup(lctx, false);
     }
 
-    iparams.model.reset(model);
-    iparams.context.reset(lctx);
-
-    return iparams;
+    return res;
 }
 
+common_init_result::~common_init_result() = default;
+
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -1255,7 +1313,9 @@ std::string get_model_endpoint() {
     std::string model_endpoint = "https://huggingface.co/";
     if (endpoint_env) {
         model_endpoint = endpoint_env;
-        if (model_endpoint.back() != '/') model_endpoint += '/';
+        if (model_endpoint.back() != '/') {
+            model_endpoint += '/';
+        }
     }
     return model_endpoint;
 }

+ 23 - 6
common/common.h

@@ -195,7 +195,6 @@ struct common_params_sampling {
 
     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"};     // default sequence breakers for DRY
 
-
     std::vector<enum common_sampler_type> samplers = {
         COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
@@ -216,6 +215,10 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
+    bool has_logit_bias() const {
+        return !logit_bias.empty();
+    }
+
     // print the parameters into a string
     std::string print() const;
 };
@@ -669,15 +672,29 @@ bool tty_can_use_colors();
 // Model utils
 //
 
-// note: defines object's lifetime
+struct common_sampler;
+
+// note: defines the model, context, samplers, ets. lifetimes
 struct common_init_result {
-    llama_model_ptr   model;
-    llama_context_ptr context;
+    common_init_result(common_params & params);
+    ~common_init_result();
 
-    std::vector<llama_adapter_lora_ptr> lora;
+    llama_model * model();
+    llama_context * context();
+    common_sampler * sampler(llama_seq_id seq_id);
+
+    std::vector<llama_adapter_lora_ptr> & lora();
+
+    void free_context();
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
 };
 };
 
 
-struct common_init_result     common_init_from_params(common_params & params);
+using common_init_result_ptr = std::unique_ptr<common_init_result>;
+
+common_init_result_ptr common_init_from_params(common_params & params);
 
 
 struct llama_model_params     common_model_params_to_llama  (      common_params & params);
 struct llama_model_params     common_model_params_to_llama  (      common_params & params);
 struct llama_context_params   common_context_params_to_llama(const common_params & params);
 struct llama_context_params   common_context_params_to_llama(const common_params & params);

+ 91 - 92
common/sampling.cpp

@@ -104,9 +104,10 @@ struct ring_buffer {
 struct common_sampler {
     common_params_sampling params;
 
-    struct llama_sampler * grmr;
     struct llama_sampler * chain;
 
+    bool grammar;
+
     ring_buffer<llama_token> prev;
 
     std::vector<llama_token_data> cur;
@@ -116,7 +117,6 @@ struct common_sampler {
     void reset() {
         prev.clear();
 
-        llama_sampler_reset(grmr);
         llama_sampler_reset(chain);
     }
 
@@ -167,10 +167,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     lparams.no_perf = params.no_perf;
 
-    struct llama_sampler * grmr;
+    llama_sampler * chain = llama_sampler_chain_init(lparams);
+
+    bool grammar = false;
+    std::vector<llama_sampler *> samplers;
+
     if (params.grammar.compare(0, 11, "%llguidance") == 0) {
 #ifdef LLAMA_USE_LLGUIDANCE
-        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
+        samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
+        grammar = true;
 #else
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
@@ -217,30 +222,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             trigger_patterns_c.push_back(regex.c_str());
         }
 
-        grmr = params.grammar_lazy
-             ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
-                                                        trigger_patterns_c.data(), trigger_patterns_c.size(),
-                                                        trigger_tokens.data(), trigger_tokens.size())
-             :      llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
-        if (!grmr) {
-            return nullptr;
+        if (!params.grammar.empty()) {
+             if (params.grammar_lazy) {
+                 samplers.push_back(
+                         llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                             trigger_patterns_c.data(), trigger_patterns_c.size(),
+                             trigger_tokens.data(),     trigger_tokens.size()));
+             } else {
+                 samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+             }
+
+             grammar = true;
         }
     }
 
-    auto * result = new common_sampler {
-        /* .params = */ params,
-        /* .grmr   = */ grmr,
-        /* .chain  = */ llama_sampler_chain_init(lparams),
-        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-        /* .cur    = */ {},
-        /* .cur_p  = */ {},
-    };
-
-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_logit_bias(
-                llama_vocab_n_tokens(vocab),
-                params.logit_bias.size(),
-                params.logit_bias.data()));
+    if (params.has_logit_bias()) {
+        samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
+    }
 
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
@@ -253,58 +251,70 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                             c_breakers.push_back(str.c_str());
                         }
 
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                        samplers.push_back(llama_sampler_init_dry    (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k       (params.top_k));
+                    samplers.push_back(llama_sampler_init_top_k      (params.top_k));
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p       (params.top_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_top_p      (params.top_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
+                    samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                     break;
                 case COMMON_SAMPLER_TYPE_MIN_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p       (params.min_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_min_p      (params.min_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_XTC:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc         (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    samplers.push_back(llama_sampler_init_xtc        (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                     break;
                 case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical     (params.typ_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_typical    (params.typ_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext    (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    samplers.push_back(llama_sampler_init_temp_ext   (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                     break;
                 case COMMON_SAMPLER_TYPE_INFILL:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill      (vocab));
+                    samplers.push_back(llama_sampler_init_infill     (vocab));
                     break;
                 case COMMON_SAMPLER_TYPE_PENALTIES:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties   (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    samplers.push_back(llama_sampler_init_penalties  (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                     break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
         }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+
+        samplers.push_back(llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
     } else if (params.mirostat == 2) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
     } else {
         GGML_ASSERT(false && "unknown mirostat version");
     }
 
+    for (auto * smpl : samplers) {
+        llama_sampler_chain_add(chain, smpl);
+    }
+
+    auto * result = new common_sampler {
+        /* .params  = */ params,
+        /* .chain   = */ chain,
+        /* .grammar = */ grammar,
+        /* .prev    = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur     = */ {},
+        /* .cur_p   = */ {},
+    };
+
     return result;
 }
 
 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
-        llama_sampler_free(gsmpl->grmr);
-
         llama_sampler_free(gsmpl->chain);
 
         delete gsmpl;
@@ -314,11 +324,24 @@ void common_sampler_free(struct common_sampler * gsmpl) {
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
     const auto tm = gsmpl->tm();
 
-    if (accept_grammar) {
-        llama_sampler_accept(gsmpl->grmr, token);
-    }
+    if (gsmpl->grammar) {
+        const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
 
-    llama_sampler_accept(gsmpl->chain, token);
+        for (int i = 0; i < n_smpl; i++) {
+            auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+
+            // the grammar sampler is always the first one
+            if (i == 0) {
+                if (accept_grammar) {
+                    llama_sampler_accept(smpl, token);
+                }
+            } else {
+                llama_sampler_accept(smpl, token);
+            }
+        }
+    } else {
+        llama_sampler_accept(gsmpl->chain, token);
+    }
 
     gsmpl->prev.push_back(token);
 }
@@ -329,12 +352,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
-        /* .params = */ gsmpl->params,
-        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
-        /* .prev   = */ gsmpl->prev,
-        /* .cur    = */ gsmpl->cur,
-        /* .cur_p  = */ gsmpl->cur_p,
+        /* .params  = */ gsmpl->params,
+        /* .chain   = */ llama_sampler_clone(gsmpl->chain),
+        /* .grammar = */ gsmpl->grammar,
+        /* .prev    = */ gsmpl->prev,
+        /* .cur     = */ gsmpl->cur,
+        /* .cur_p   = */ gsmpl->cur_p,
     };
 }
 
@@ -383,58 +406,33 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
     }
 }
 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
+    return gsmpl->chain;
+}
+
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
     llama_synchronize(ctx);
 
     // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
     const auto tm = gsmpl->tm();
 
-    gsmpl->set_logits(ctx, idx);
+    llama_token id = LLAMA_TOKEN_NULL;
 
-    auto & grmr  = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
-    if (grammar_first) {
-        llama_sampler_apply(grmr, &cur_p);
-    }
+    gsmpl->set_logits(ctx, idx);
 
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
-    const llama_token id = cur_p.data[cur_p.selected].id;
-
-    if (grammar_first) {
-        return id;
-    }
-
-    // check if it the sampled token fits the grammar
-    {
-        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
-
-        llama_sampler_apply(grmr, &single_token_data_array);
-
-        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
-        if (is_valid) {
-            return id;
-        }
-    }
-
-    // resampling:
-    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
-    gsmpl->set_logits(ctx, idx);
-
-    llama_sampler_apply(grmr,  &cur_p);
-    llama_sampler_apply(chain, &cur_p);
-
-    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
+    id = cur_p.data[cur_p.selected].id;
 
-    return cur_p.data[cur_p.selected].id;
+    return id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
@@ -442,7 +440,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
 
     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -454,7 +452,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }
 
     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -464,13 +462,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
 
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
@@ -515,7 +513,8 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) {
 
     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
+        result += std::string("-> ");
+        result += std::string(llama_sampler_name(smpl)) + " ";
     }
 
     return result;
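
Note on the grammar change above: the grammar sampler is now simply the first entry of the single llama_sampler chain, and the new common_sampler_get() accessor exposes that chain to callers that need the raw llama_sampler API. A small sketch (assuming gsmpl was created with common_sampler_init):

    llama_sampler * chain = common_sampler_get(gsmpl);

    // the same chain that common_sampler_sample() applies; e.g. print its performance counters
    llama_perf_sampler_print(chain);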

+ 11 - 6
common/sampling.h

@@ -48,6 +48,8 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
 // arguments can be nullptr to skip printing
 void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
 
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
+
 // extended sampling implementation:
 //
 // - set logits
@@ -55,10 +57,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 // - check if the token fits the grammar (if any)
 // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
 //
-// if grammar_first is true, the grammar is applied before the samplers (slower)
-// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
-//
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
 
 // generalized version of common_sampler_sample
 //
@@ -76,10 +75,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
 
 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
@@ -107,3 +106,9 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
 
 llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
                 const char * grammar_kind, const char * grammar_data);
+
+struct common_sampler_deleter {
+    void operator()(common_sampler * s) { common_sampler_free(s); }
+};
+
+typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;
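
The new common_sampler_deleter / common_sampler_ptr pair (used by common_init_result to hold its per-sequence samplers) also lets standalone samplers be managed with RAII instead of a manual common_sampler_free() call. A minimal sketch, assuming model, ctx and params come from the initialization shown earlier:

    common_sampler_ptr smpl(common_sampler_init(model, params.sampling));

    if (smpl) {
        const llama_token id = common_sampler_sample(smpl.get(), ctx, /*idx =*/ 0);
        common_sampler_accept(smpl.get(), id, /*accept_grammar =*/ true);
    }
    // common_sampler_free() runs automatically via common_sampler_deleter when smpl goes out of scope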

+ 1 - 1
common/speculative.cpp

@@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
     for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);
 
-        common_sampler_sample(smpl, ctx_dft, 0, true);
+        common_sampler_sample(smpl, ctx_dft, 0);
 
         const auto * cur_p = common_sampler_get_candidates(smpl, true);
 

+ 20 - 10
examples/batched/batched.cpp

@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "sampling.h"
 
 #include <algorithm>
 #include <cstdio>
@@ -64,17 +65,23 @@ int main(int argc, char ** argv) {
     ctx_params.n_ctx   = n_kv_req;
     ctx_params.n_batch = std::max(n_predict, n_parallel);
 
-    llama_context * ctx = llama_init_from_model(model, ctx_params);
-
     auto sparams = llama_sampler_chain_default_params();
     sparams.no_perf = false;
 
-    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    std::vector<llama_sampler *> samplers;
+
+    for (int32_t i = 0; i < n_parallel; ++i) {
+        llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
+        llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
+        llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
 
-    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
-    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
-    llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
-    llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
+        samplers.push_back(smpl);
+    }
+
+    llama_context * ctx = llama_init_from_model(model, ctx_params);
 
     if (ctx == NULL) {
         LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);
@@ -173,7 +180,7 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
+            const llama_token new_token_id = llama_sampler_sample(samplers[i], ctx, i_batch[i]);
 
             // is it an end of generation? -> mark the stream as finished
             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
@@ -229,14 +236,17 @@ int main(int argc, char ** argv) {
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
     LOG("\n");
-    llama_perf_sampler_print(smpl);
+    llama_perf_sampler_print(samplers[0]);
     llama_perf_context_print(ctx);
 
     fprintf(stderr, "\n");
 
     llama_batch_free(batch);
 
-    llama_sampler_free(smpl);
+    for (auto & sampler_config : samplers) {
+        llama_sampler_free(sampler_config);
+    }
+
     llama_free(ctx);
     llama_model_free(model);
 

+ 3 - 3
examples/embedding/embedding.cpp

@@ -131,10 +131,10 @@ int main(int argc, char ** argv) {
     llama_numa_init(params.numa);
 
     // load the model
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
-    llama_model * model = llama_init.model.get();
-    llama_context * ctx = llama_init.context.get();
+    auto * model = llama_init->model();
+    auto * ctx = llama_init->context();
 
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n", __func__);

+ 3 - 3
examples/eval-callback/eval-callback.cpp

@@ -202,10 +202,10 @@ int main(int argc, char ** argv) {
     params.warmup = false;
 
     // init
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
-    llama_model * model = llama_init.model.get();
-    llama_context * ctx = llama_init.context.get();
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
 
     if (model == nullptr || ctx == nullptr) {
         LOG_ERR("%s : failed to init\n", __func__);

+ 3 - 3
examples/lookahead/lookahead.cpp

@@ -55,10 +55,10 @@ int main(int argc, char ** argv) {
     llama_numa_init(params.numa);
 
     // load the target model
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
-    llama_model * model = llama_init.model.get();
-    llama_context * ctx = llama_init.context.get();
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
 
     auto * mem = llama_get_memory(ctx);
 

+ 4 - 4
examples/lookup/lookup-create.cpp

@@ -18,16 +18,16 @@ int main(int argc, char ** argv){
     llama_numa_init(params.numa);
 
     // load the model
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
-    llama_model_ptr & model = llama_init.model;
-    llama_context_ptr & ctx = llama_init.context;
+    auto * model = llama_init->model();
+    auto * ctx = llama_init->context();
 
     GGML_ASSERT(model != nullptr);
 
     // tokenize the prompt
     std::vector<llama_token> inp;
-    inp = common_tokenize(ctx.get(), params.prompt, true, true);
+    inp = common_tokenize(ctx, params.prompt, true, true);
     fprintf(stderr, "%s: tokenization done\n", __func__);
 
     common_ngram_cache ngram_cache;

+ 4 - 4
examples/lookup/lookup-stats.cpp

@@ -28,13 +28,13 @@ int main(int argc, char ** argv){
     llama_numa_init(params.numa);
 
     // load the model
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
-    llama_context_ptr & ctx = llama_init.context;
+    llama_context * ctx = llama_init->context();
 
     // tokenize the prompt
     std::vector<llama_token> inp;
-    inp = common_tokenize(ctx.get(), params.prompt, true, true);
+    inp = common_tokenize(ctx, params.prompt, true, true);
 
     common_ngram_cache ngram_cache_context;
     common_ngram_cache ngram_cache_dynamic;
@@ -65,7 +65,7 @@ int main(int argc, char ** argv){
     }
 
     const int n_input = inp.size();
-    const int n_ctx = llama_n_ctx(ctx.get());
+    const int n_ctx = llama_n_ctx(ctx);
 
     int n_drafted = 0;
     int n_accept  = 0;

+ 3 - 3
examples/lookup/lookup.cpp

@@ -29,10 +29,10 @@ int main(int argc, char ** argv){
     llama_numa_init(params.numa);
 
     // load the model
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
-    llama_model * model = llama_init.model.get();
-    llama_context * ctx = llama_init.context.get();
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 

+ 3 - 3
examples/parallel/parallel.cpp

@@ -192,10 +192,10 @@ int main(int argc, char ** argv) {
     llama_numa_init(params.numa);
 
     // load the target model
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
-    llama_model * model = llama_init.model.get();
-    llama_context * ctx = llama_init.context.get();
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
 
     auto * mem = llama_get_memory(ctx);
 

+ 3 - 3
examples/retrieval/retrieval.cpp

@@ -149,10 +149,10 @@ int main(int argc, char ** argv) {
     llama_numa_init(params.numa);
 
     // load the model
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
-    llama_model * model = llama_init.model.get();
-    llama_context * ctx = llama_init.context.get();
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
 
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n", __func__);

+ 3 - 3
examples/save-load-state/save-load-state.cpp

@@ -34,10 +34,10 @@ int main(int argc, char ** argv) {
     std::string result2;
 
     // init
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
-    llama_model * model = llama_init.model.get();
-    llama_context * ctx = llama_init.context.get();
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
 
     if (model == nullptr || ctx == nullptr) {
         fprintf(stderr, "%s : failed to init\n", __func__);

+ 6 - 6
examples/speculative-simple/speculative-simple.cpp

@@ -40,10 +40,10 @@ int main(int argc, char ** argv) {
     llama_context * ctx_dft = NULL;
 
     // load the target model
-    common_init_result llama_init_tgt = common_init_from_params(params);
+    auto llama_init_tgt = common_init_from_params(params);
 
-    model_tgt = llama_init_tgt.model.get();
-    ctx_tgt   = llama_init_tgt.context.get();
+    model_tgt = llama_init_tgt->model();
+    ctx_tgt   = llama_init_tgt->context();
 
     const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
 
@@ -61,10 +61,10 @@ int main(int argc, char ** argv) {
     params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
     params.tensor_buft_overrides     = params.speculative.tensor_buft_overrides;
 
-    common_init_result llama_init_dft = common_init_from_params(params);
+    auto llama_init_dft = common_init_from_params(params);
 
-    //model_dft = llama_init_dft.model.get();
-    ctx_dft   = llama_init_dft.context.get();
+    //model_dft = llama_init_dft->model();
+    ctx_dft   = llama_init_dft->context();
 
     if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
         LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());

+ 8 - 8
examples/speculative/speculative.cpp

@@ -71,10 +71,10 @@ int main(int argc, char ** argv) {
     llama_context * ctx_dft = NULL;
 
     // load the target model
-    common_init_result llama_init_tgt = common_init_from_params(params);
+    auto llama_init_tgt = common_init_from_params(params);
 
-    model_tgt = llama_init_tgt.model.get();
-    ctx_tgt   = llama_init_tgt.context.get();
+    model_tgt = llama_init_tgt->model();
+    ctx_tgt   = llama_init_tgt->context();
 
     // load the draft model
     params.devices = params.speculative.devices;
@@ -87,10 +87,10 @@ int main(int argc, char ** argv) {
     params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
     params.tensor_buft_overrides     = params.speculative.tensor_buft_overrides;
 
-    common_init_result llama_init_dft = common_init_from_params(params);
+    auto llama_init_dft = common_init_from_params(params);
 
-    model_dft = llama_init_dft.model.get();
-    ctx_dft   = llama_init_dft.context.get();
+    model_dft = llama_init_dft->model();
+    ctx_dft   = llama_init_dft->context();
 
     const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
     const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
@@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
                 bool accept = false;
                 if (params.sampling.temp > 0) {
                     // stochastic verification
-                    common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
+                    common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
 
                     auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
 
@@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
                     continue;
                 }
 
-                common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
+                common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
 
                 const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
 

+ 9 - 8
examples/training/finetune.cpp

@@ -39,9 +39,10 @@ int main(int argc, char ** argv) {
     llama_backend_init();
     llama_numa_init(params.numa);
     // load the model and apply lora adapter, if any
-    common_init_result   llama_init = common_init_from_params(params);
-    llama_model_ptr    & model      = llama_init.model;
-    llama_context_ptr  & ctx        = llama_init.context;
+    auto llama_init = common_init_from_params(params);
+
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
 
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n", __func__);
@@ -54,8 +55,8 @@ int main(int argc, char ** argv) {
         LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }
 
-    std::vector<llama_token> tokens  = common_tokenize(ctx.get(), params.prompt, true);
-    ggml_opt_dataset_t       dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2);
+    std::vector<llama_token> tokens  = common_tokenize(ctx, params.prompt, true);
+    ggml_opt_dataset_t       dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx) / 2);
 
     struct lr_opt & lr = params.lr;
     LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
@@ -70,7 +71,7 @@ int main(int argc, char ** argv) {
         /*get_opt_pars_ud =*/&params.lr,
         /*get_opt_pars_ud =*/&params.lr,
         /*optimizer_type  =*/params.optimizer,
         /*optimizer_type  =*/params.optimizer,
     };
     };
-    llama_opt_init(ctx.get(), model.get(), lopt_params);
+    llama_opt_init(ctx, model, lopt_params);
 
 
     const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
     const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
 
 
@@ -78,7 +79,7 @@ int main(int argc, char ** argv) {
     ggml_opt_result_t result_eval  = ggml_opt_result_init();
     ggml_opt_result_t result_eval  = ggml_opt_result_init();
 
 
     for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
     for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
-        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
+        llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                         ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
                         ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
         fprintf(stderr, "\n");
         fprintf(stderr, "\n");
 
 
@@ -88,7 +89,7 @@ int main(int argc, char ** argv) {
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_eval);
     ggml_opt_result_free(result_eval);
 
 
-    llama_model_save_to_file(model.get(), params.out_file.c_str());
+    llama_model_save_to_file(model, params.out_file.c_str());
 
 
     llama_backend_free();
     llama_backend_free();
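
Note: per this diff, common_init_from_params now returns a common_init_result_ptr whose model()/context() accessors hand out raw, non-owning pointers. A minimal sketch of the resulting calling pattern; the function name is illustrative, and the cleanup-on-destruction behaviour is an assumption:

    #include "common.h"
    #include "llama.h"

    int run_once(common_params & params) {
        llama_backend_init();

        auto llama_init = common_init_from_params(params); // returns common_init_result_ptr (per this diff)

        auto * model = llama_init->model();   // raw pointer, owned by llama_init
        auto * ctx   = llama_init->context(); // raw pointer, owned by llama_init

        if (model == nullptr || ctx == nullptr) {
            llama_backend_free();
            return 1;
        }

        // ... use model / ctx while llama_init stays in scope ...

        llama_backend_free();
        return 0; // assumption: model and context are released when llama_init is destroyed
    }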
 
 

+ 7 - 13
tools/completion/completion.cpp

@@ -141,13 +141,15 @@ int main(int argc, char ** argv) {
 
 
     // load the model and apply lora adapter, if any
     // load the model and apply lora adapter, if any
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
-    common_init_result llama_init = common_init_from_params(params);
 
 
-    model = llama_init.model.get();
-    ctx = llama_init.context.get();
+    auto llama_init = common_init_from_params(params);
 
 
-    if (model == NULL) {
-        LOG_ERR("%s: error: unable to load model\n", __func__);
+    ctx   = llama_init->context();
+    model = llama_init->model();
+    smpl  = llama_init->sampler(0);
+
+    if (ctx == NULL) {
+        LOG_ERR("%s: error: unable to create context\n", __func__);
         return 1;
         return 1;
     }
     }
 
 
@@ -474,12 +476,6 @@ int main(int argc, char ** argv) {
         }
         }
     }
     }
 
 
-    smpl = common_sampler_init(model, sparams);
-    if (!smpl) {
-        LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
-        return 1;
-    }
-
     LOG_INF("sampler seed: %u\n",     common_sampler_get_seed(smpl));
     LOG_INF("sampler seed: %u\n",     common_sampler_get_seed(smpl));
     LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
     LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
     LOG_INF("sampler chain: %s\n",    common_sampler_print(smpl).c_str());
     LOG_INF("sampler chain: %s\n",    common_sampler_print(smpl).c_str());
@@ -993,8 +989,6 @@ int main(int argc, char ** argv) {
     LOG("\n\n");
     LOG("\n\n");
     common_perf_print(ctx, smpl);
     common_perf_print(ctx, smpl);
 
 
-    common_sampler_free(smpl);
-
     llama_backend_free();
     llama_backend_free();
 
 
     ggml_threadpool_free_fn(threadpool);
     ggml_threadpool_free_fn(threadpool);
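
Note: completion.cpp no longer builds and frees its own sampler; it takes the one created during initialization. A minimal sketch under the assumptions that sampler(0) returns a non-owning pointer tied to the init result and that the index is per-sequence (neither is spelled out in this hunk); the function name is illustrative:

    #include "common.h"
    #include "log.h"
    #include "sampling.h"

    int generate(common_params & params) {
        auto llama_init = common_init_from_params(params);

        llama_context  * ctx  = llama_init->context();
        common_sampler * smpl = llama_init->sampler(0); // assumption: owned by llama_init

        if (ctx == nullptr || smpl == nullptr) {
            return 1;
        }

        LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl));

        // ... decode loop calling common_sampler_sample(smpl, ctx, idx) ...

        // no common_sampler_free(smpl) here - the init result is assumed to own it
        return 0;
    }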

+ 3 - 3
tools/cvector-generator/cvector-generator.cpp

@@ -419,10 +419,10 @@ int main(int argc, char ** argv) {
     llama_numa_init(params.numa);
     llama_numa_init(params.numa);
 
 
     // load the model to get hparams
     // load the model to get hparams
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
 
-    llama_model * model = llama_init.model.get();
-    llama_context * ctx = llama_init.context.get();
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
 
 
     // int n_ctx = llama_n_ctx(ctx);
     // int n_ctx = llama_n_ctx(ctx);
     int n_layers = llama_model_n_layer(model);
     int n_layers = llama_model_n_layer(model);

+ 3 - 3
tools/imatrix/imatrix.cpp

@@ -1265,10 +1265,10 @@ int main(int argc, char ** argv) {
     params.warmup = false;
     params.warmup = false;
 
 
     // init
     // init
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
 
-    llama_model * model = llama_init.model.get();
-    llama_context * ctx = llama_init.context.get();
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
 
 
     if (model == nullptr || ctx == nullptr) {
     if (model == nullptr || ctx == nullptr) {
         LOG_ERR("%s : failed to init\n", __func__);
         LOG_ERR("%s : failed to init\n", __func__);

+ 3 - 3
tools/mtmd/mtmd-cli.cpp

@@ -65,7 +65,7 @@ static void sigint_handler(int signo) {
 
 
 struct mtmd_cli_context {
 struct mtmd_cli_context {
     mtmd::context_ptr ctx_vision;
     mtmd::context_ptr ctx_vision;
-    common_init_result llama_init;
+    common_init_result_ptr llama_init;
 
 
     llama_model       * model;
     llama_model       * model;
     llama_context     * lctx;
     llama_context     * lctx;
@@ -89,8 +89,8 @@ struct mtmd_cli_context {
     llama_pos n_past = 0;
     llama_pos n_past = 0;
 
 
     mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
     mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
-        model = llama_init.model.get();
-        lctx = llama_init.context.get();
+        model = llama_init->model();
+        lctx = llama_init->context();
         vocab = llama_model_get_vocab(model);
         vocab = llama_model_get_vocab(model);
         smpl = common_sampler_init(model, params.sampling);
         smpl = common_sampler_init(model, params.sampling);
         n_threads = params.cpuparams.n_threads;
         n_threads = params.cpuparams.n_threads;
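
Note: because model() and context() return raw pointers, holders such as mtmd_cli_context keep the common_init_result_ptr itself as a member so the underlying objects stay alive. A condensed, illustrative sketch of that ownership layout (struct and field names other than the types shown in the diff are placeholders):

    #include "common.h"

    struct cli_state {
        common_init_result_ptr llama_init;   // owns model + context for the lifetime of the struct

        llama_model   * model = nullptr;     // non-owning views into llama_init
        llama_context * lctx  = nullptr;

        explicit cli_state(common_params & params) : llama_init(common_init_from_params(params)) {
            model = llama_init->model();
            lctx  = llama_init->context();
        }
    };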

+ 3 - 3
tools/perplexity/perplexity.cpp

@@ -2024,10 +2024,10 @@ int main(int argc, char ** argv) {
     llama_numa_init(params.numa);
     llama_numa_init(params.numa);
 
 
     // load the model and apply lora adapter, if any
     // load the model and apply lora adapter, if any
-    common_init_result llama_init = common_init_from_params(params);
+    auto llama_init = common_init_from_params(params);
 
 
-    llama_model * model = llama_init.model.get();
-    llama_context * ctx = llama_init.context.get();
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
 
 
     if (model == NULL) {
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n", __func__);
         LOG_ERR("%s: unable to load model\n", __func__);

+ 20 - 27
tools/server/server-context.cpp

@@ -153,7 +153,7 @@ struct server_slot {
     // sampling
     // sampling
     json json_schema;
     json json_schema;
 
 
-    struct common_sampler * smpl = nullptr;
+    common_sampler_ptr smpl;
 
 
     llama_token sampled; // in speculative mode, this is the last accepted token
     llama_token sampled; // in speculative mode, this is the last accepted token
     llama_tokens drafted;
     llama_tokens drafted;
@@ -510,8 +510,8 @@ struct server_context_impl {
     common_params params_base;
     common_params params_base;
 
 
     // note: keep these alive - they determine the lifetime of the model, context, etc.
     // note: keep these alive - they determine the lifetime of the model, context, etc.
-    common_init_result llama_init;
-    common_init_result llama_init_dft;
+    common_init_result_ptr llama_init;
+    common_init_result_ptr llama_init_dft;
 
 
     llama_model * model = nullptr;
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     llama_context * ctx = nullptr;
@@ -557,9 +557,6 @@ struct server_context_impl {
 
 
         // Clear any sampling context
         // Clear any sampling context
         for (server_slot & slot : slots) {
         for (server_slot & slot : slots) {
-            common_sampler_free(slot.smpl);
-            slot.smpl = nullptr;
-
             llama_free(slot.ctx_dft);
             llama_free(slot.ctx_dft);
             slot.ctx_dft = nullptr;
             slot.ctx_dft = nullptr;
 
 
@@ -580,8 +577,8 @@ struct server_context_impl {
 
 
         llama_init = common_init_from_params(params_base);
         llama_init = common_init_from_params(params_base);
 
 
-        model = llama_init.model.get();
-        ctx   = llama_init.context.get();
+        model = llama_init->model();
+        ctx   = llama_init->context();
 
 
         if (model == nullptr) {
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
             SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
@@ -613,25 +610,25 @@ struct server_context_impl {
 
 
             llama_init_dft = common_init_from_params(params_dft);
             llama_init_dft = common_init_from_params(params_dft);
 
 
-            model_dft = llama_init_dft.model.get();
+            model_dft = llama_init_dft->model();
 
 
             if (model_dft == nullptr) {
             if (model_dft == nullptr) {
                 SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
                 SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
                 return false;
                 return false;
             }
             }
 
 
-            vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
+            vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
             if (!vocab_dft_compatible) {
             if (!vocab_dft_compatible) {
                 SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
                 SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
             }
             }
 
 
-            const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
+            const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
 
 
             cparams_dft = common_context_params_to_llama(params_dft);
             cparams_dft = common_context_params_to_llama(params_dft);
             cparams_dft.n_batch = n_ctx_dft;
             cparams_dft.n_batch = n_ctx_dft;
 
 
             // the context is not needed - we will create one for each slot
             // the context is not needed - we will create one for each slot
-            llama_init_dft.context.reset();
+            llama_init_dft->free_context();
         }
         }
 
 
         chat_templates = common_chat_templates_init(model, params_base.chat_template);
         chat_templates = common_chat_templates_init(model, params_base.chat_template);
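
Note: free_context() is used here to drop only the initial draft context while the same init result keeps the draft model alive for the per-slot contexts created later. A condensed sketch of that pattern, assuming free_context() leaves the model untouched; the struct and method names are illustrative:

    #include "common.h"

    struct draft_state {
        common_init_result_ptr   init;             // long-lived: keeps the draft model alive
        llama_model            * model = nullptr;  // non-owning view
        llama_context_params     cparams = llama_context_default_params();

        bool load(common_params & params_dft) {
            init  = common_init_from_params(params_dft);
            model = init->model();
            if (model == nullptr) {
                return false;
            }

            cparams         = common_context_params_to_llama(params_dft);
            cparams.n_batch = llama_n_ctx(init->context());

            // per-slot contexts are created later, so the shared one can be dropped now;
            // assumption based on this hunk: free_context() does not affect the model
            init->free_context();
            return true;
        }
    };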
@@ -1051,18 +1048,15 @@ struct server_context_impl {
 
 
         // initialize samplers
         // initialize samplers
         {
         {
-            if (slot.smpl != nullptr) {
-                common_sampler_free(slot.smpl);
-            }
+            slot.smpl.reset(common_sampler_init(model, task.params.sampling));
 
 
-            slot.smpl = common_sampler_init(model, task.params.sampling);
             if (slot.smpl == nullptr) {
             if (slot.smpl == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
                 // for now, the only error that may happen here is invalid grammar
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 return false;
                 return false;
             }
             }
 
 
-            SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
+            SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
         }
         }
 
 
         // initialize draft batch
         // initialize draft batch
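
Note: with the slot sampler stored as a common_sampler_ptr, re-creating it per task becomes a reset() and the remaining C-style calls go through .get(). A minimal sketch, assuming common_sampler_ptr is a unique_ptr-style handle whose deleter ends up calling common_sampler_free; the struct and function names are illustrative:

    #include "common.h"
    #include "sampling.h"

    struct slot_state {
        common_sampler_ptr smpl; // assumption: freed automatically when the slot is destroyed
    };

    static bool start_task(slot_state & slot, const llama_model * model, const common_params_sampling & sparams) {
        // replaces the old free-then-init dance; reset() disposes of any previous sampler
        slot.smpl.reset(common_sampler_init(model, sparams));
        if (slot.smpl == nullptr) {
            return false; // e.g. invalid grammar
        }

        common_sampler_reset(slot.smpl.get()); // raw pointer for the existing C-style API
        return true;
    }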
@@ -1216,11 +1210,10 @@ struct server_context_impl {
     }
     }
 
 
     void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
     void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
-        size_t n_probs = slot.task->params.sampling.n_probs;
-        size_t n_vocab = llama_vocab_n_tokens(vocab);
+        const size_t n_probs = slot.task->params.sampling.n_probs;
 
 
         if (post_sampling) {
         if (post_sampling) {
-            const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
+            const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
             const size_t max_probs = cur_p->size;
             const size_t max_probs = cur_p->size;
 
 
             // set probability for sampled token
             // set probability for sampled token
@@ -1245,7 +1238,7 @@ struct server_context_impl {
             std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
             std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
 
 
             // set probability for sampled token
             // set probability for sampled token
-            for (size_t i = 0; i < n_vocab; i++) {
+            for (size_t i = 0; i < cur.size(); i++) {
                 // set probability for sampled token
                 // set probability for sampled token
                 if (cur[i].id == result.tok) {
                 if (cur[i].id == result.tok) {
                     result.prob = cur[i].p;
                     result.prob = cur[i].p;
@@ -1255,7 +1248,7 @@ struct server_context_impl {
 
 
             // set probability for top n_probs tokens
             // set probability for top n_probs tokens
             result.probs.reserve(n_probs);
             result.probs.reserve(n_probs);
-            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+            for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
                 result.probs.push_back({
                 result.probs.push_back({
                     cur[i].id,
                     cur[i].id,
                     common_token_to_piece(ctx, cur[i].id, special),
                     common_token_to_piece(ctx, cur[i].id, special),
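
Note: both loops above are now bounded by the size of the candidate vector returned for this batch index rather than by a separately queried n_vocab. A small self-contained sketch of that shape; the helper and its output type are illustrative:

    #include <algorithm>
    #include <cstddef>
    #include <utility>
    #include <vector>

    #include "llama.h"

    // `cur` stands for whatever get_token_probabilities(ctx, idx) returned for this position
    static std::vector<std::pair<llama_token, float>> top_probs(
            const std::vector<llama_token_data> & cur, llama_token sampled, size_t n_probs, float & prob_sampled) {
        prob_sampled = 0.0f;
        for (size_t i = 0; i < cur.size(); i++) {      // bounded by the vector, not by n_vocab
            if (cur[i].id == sampled) {
                prob_sampled = cur[i].p;
                break;
            }
        }

        std::vector<std::pair<llama_token, float>> out;
        const size_t n_top = std::min(cur.size(), n_probs);
        out.reserve(n_top);
        for (size_t i = 0; i < n_top; i++) {
            out.push_back({ cur[i].id, cur[i].p });
        }
        return out;
    }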
@@ -2301,13 +2294,13 @@ struct server_context_impl {
 
 
                         GGML_ASSERT(batch.n_tokens > 0);
                         GGML_ASSERT(batch.n_tokens > 0);
 
 
-                        common_sampler_reset(slot.smpl);
+                        common_sampler_reset(slot.smpl.get());
 
 
                         // Process all prompt tokens through sampler system
                         // Process all prompt tokens through sampler system
                         for (int i = 0; i < slot.task->n_tokens(); ++i) {
                         for (int i = 0; i < slot.task->n_tokens(); ++i) {
                             llama_token id = input_tokens[i];
                             llama_token id = input_tokens[i];
                             if (id != LLAMA_TOKEN_NULL) {
                             if (id != LLAMA_TOKEN_NULL) {
-                                common_sampler_accept(slot.smpl, id, false);
+                                common_sampler_accept(slot.smpl.get(), id, false);
                             }
                             }
                         }
                         }
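
Note: the hunk above pushes the prompt tokens through the sampler with the third argument of common_sampler_accept set to false, while the token sampled later (next hunk) is accepted with true. A minimal sketch of that warm-up loop, under the assumption that the flag controls whether grammar state is advanced; the helper name is illustrative:

    #include "common.h"
    #include "sampling.h"

    // feed the prompt into the sampler's history so penalties etc. see the full context,
    // without advancing the grammar (flag semantics assumed: accept_grammar)
    static void warm_up_sampler(common_sampler * smpl, const llama_tokens & input_tokens) {
        common_sampler_reset(smpl);

        for (llama_token id : input_tokens) {
            if (id != LLAMA_TOKEN_NULL) {
                common_sampler_accept(smpl, id, /*accept_grammar=*/false);
            }
        }
    }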
 
 
@@ -2525,11 +2518,11 @@ struct server_context_impl {
 
 
                 const int tok_idx = slot.i_batch - i;
                 const int tok_idx = slot.i_batch - i;
 
 
-                llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
+                llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
 
 
                 slot.i_batch = -1;
                 slot.i_batch = -1;
 
 
-                common_sampler_accept(slot.smpl, id, true);
+                common_sampler_accept(slot.smpl.get(), id, true);
 
 
                 slot.n_decoded += 1;
                 slot.n_decoded += 1;
 
 
@@ -2570,7 +2563,7 @@ struct server_context_impl {
                 size_t n_draft = slot.drafted.size();
                 size_t n_draft = slot.drafted.size();
 
 
                 // the accepted tokens from the speculation
                 // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
+                const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
                 slot.i_batch_dft.clear();
                 slot.i_batch_dft.clear();
                 slot.drafted.clear();
                 slot.drafted.clear();
 
 

+ 1 - 1
tools/server/tests/unit/test_compat_anthropic.py

@@ -684,7 +684,7 @@ def test_anthropic_streaming_content_block_indices():
     # Request that might produce both text and tool use
     # Request that might produce both text and tool use
     res = server.make_stream_request("POST", "/v1/messages", data={
     res = server.make_stream_request("POST", "/v1/messages", data={
         "model": "test",
         "model": "test",
-        "max_tokens": 200,
+        "max_tokens": 400,
         "stream": True,
         "stream": True,
         "tools": [{
         "tools": [{
             "name": "test_tool",
             "name": "test_tool",

+ 6 - 6
tools/tts/tts.cpp

@@ -568,10 +568,10 @@ int main(int argc, char ** argv) {
     llama_context * ctx_ttc = NULL;
     llama_context * ctx_ttc = NULL;
     llama_context * ctx_cts = NULL;
     llama_context * ctx_cts = NULL;
 
 
-    common_init_result llama_init_ttc = common_init_from_params(params);
+    auto llama_init_ttc = common_init_from_params(params);
 
 
-    model_ttc = llama_init_ttc.model.get();
-    ctx_ttc   = llama_init_ttc.context.get();
+    model_ttc = llama_init_ttc->model();
+    ctx_ttc   = llama_init_ttc->context();
 
 
     if (model_ttc == nullptr || ctx_ttc == nullptr) {
     if (model_ttc == nullptr || ctx_ttc == nullptr) {
         return ENOENT;
         return ENOENT;
@@ -583,10 +583,10 @@ int main(int argc, char ** argv) {
     params.embedding = true;
     params.embedding = true;
     params.n_ubatch = params.n_batch;
     params.n_ubatch = params.n_batch;
 
 
-    common_init_result llama_init_cts = common_init_from_params(params);
+    auto llama_init_cts = common_init_from_params(params);
 
 
-    model_cts = llama_init_cts.model.get();
-    ctx_cts   = llama_init_cts.context.get();
+    model_cts = llama_init_cts->model();
+    ctx_cts   = llama_init_cts->context();
 
 
     if (model_cts == nullptr || ctx_cts == nullptr) {
     if (model_cts == nullptr || ctx_cts == nullptr) {
         return ENOENT;
         return ENOENT;