Преглед изворни кода

feat: remove a sampler from a chain (#9445)

* feat: remove a sampler from a chain

* fix: return removed sampler

* fix: safer casting
Gilad S. пре 1 година
родитељ
комит
bd35cb0ae3
2 измењених фајлова са 17 додато и 1 уклоњено
  1. 3 0
      include/llama.h
  2. 14 1
      src/llama-sampling.cpp

+ 3 - 0
include/llama.h

@@ -1056,6 +1056,9 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
     LLAMA_API int                    llama_sampler_chain_n  (const struct llama_sampler * chain);
 
+    // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
+    LLAMA_API struct llama_sampler * llama_sampler_chain_remove(   struct llama_sampler * chain, int32_t i);
+
     // available samplers:
 
     LLAMA_API struct llama_sampler * llama_sampler_init_greedy     (void);

+ 14 - 1
src/llama-sampling.cpp

@@ -349,13 +349,26 @@ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler
 struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
     const auto * p = (const llama_sampler_chain *) chain->ctx;
 
-    if (i < 0 || i >= (int32_t) p->samplers.size()) {
+    if (i < 0 || (size_t) i >= p->samplers.size()) {
         return nullptr;
     }
 
     return p->samplers[i];
 }
 
+struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
+    auto * p = (llama_sampler_chain *) chain->ctx;
+
+    if (i < 0 || (size_t) i >= p->samplers.size()) {
+        return nullptr;
+    }
+
+    auto * result = p->samplers[i];
+    p->samplers.erase(p->samplers.begin() + i);
+
+    return result;
+}
+
 int llama_sampler_chain_n(const struct llama_sampler * chain) {
     const auto * p = (const llama_sampler_chain *) chain->ctx;