Browse Source

context : reserve new scheduler when graph topology changes (#18547)

* context : reserve new scheduler when graph topology changes

* cont : fix

* cont : fix reserve

* cont : reserve only when changes occur + timing

* context : add comments

* llama : reserve on sampler changes

* common : allow null common_sampler

* server : task declares needs (embd, logits, sampling)

* server : do not init sampler if not needed

* llama : fix need_reserve when unsetting a sampler

* server : consolidate slot reset/clear logic
Georgi Gerganov 1 week ago
parent
commit
39173bcacb

+ 0 - 1
common/common.cpp

@@ -1172,7 +1172,6 @@ common_init_result::common_init_result(common_params & params) :
         pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
     }
 
-    // TODO: temporarily gated behind a flag
     if (params.sampling.backend_sampling) {
         cparams.samplers   = pimpl->samplers_seq_config.data();
         cparams.n_samplers = pimpl->samplers_seq_config.size();

+ 19 - 5
common/sampling.cpp

@@ -334,15 +334,21 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
 }
 
 void common_sampler_free(struct common_sampler * gsmpl) {
-    if (gsmpl) {
-        llama_sampler_free(gsmpl->grmr);
-        llama_sampler_free(gsmpl->chain);
-
-        delete gsmpl;
+    if (!gsmpl) {
+        return;
     }
+
+    llama_sampler_free(gsmpl->grmr);
+    llama_sampler_free(gsmpl->chain);
+
+    delete gsmpl;
 }
 
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
+    if (!gsmpl) {
+        return;
+    }
+
     const auto tm = gsmpl->tm();
 
     if (gsmpl->grmr && accept_grammar) {
@@ -355,6 +361,10 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
 }
 
 void common_sampler_reset(struct common_sampler * gsmpl) {
+    if (!gsmpl) {
+        return;
+    }
+
     gsmpl->reset();
 }
 
@@ -415,6 +425,10 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 }
 
 struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
+    if (!gsmpl) {
+        return nullptr;
+    }
+
     return gsmpl->chain;
 }
 

+ 0 - 1
examples/batched/batched.cpp

@@ -81,7 +81,6 @@ int main(int argc, char ** argv) {
         sampler_configs.push_back({ i, smpl });
     }
 
-    // TODO: temporarily gated behind a flag
     if (params.sampling.backend_sampling) {
         ctx_params.samplers   = sampler_configs.data();
         ctx_params.n_samplers = sampler_configs.size();

+ 0 - 1
include/llama.h

@@ -1256,7 +1256,6 @@ extern "C" {
     // [EXPERIMENTAL]
     // attach a sampler to the context
     // note: prefer initializing the context with llama_context_params.samplers when possible
-    // note: changing the samplers of a context can cause graph reallocations and degraded performance
     LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
 
     // mirror of llama_sampler_i:

+ 226 - 142
src/llama-context.cpp

@@ -146,6 +146,7 @@ llama_context::llama_context(
     }
 
     cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
+    cparams.auto_fa    = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -155,6 +156,9 @@ llama_context::llama_context(
     cparams.op_offload = params.op_offload;
     cparams.kv_unified = params.kv_unified;
 
+    // intialized later
+    cparams.pipeline_parallel = false;
+
     {
         const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
         graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
@@ -302,16 +306,6 @@ llama_context::llama_context(
 
         LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
 
-        const uint32_t n_seqs = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        const size_t max_nodes = this->graph_max_nodes(n_tokens);
-
-        LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
-
-        gf_res_prev.reset(new llm_graph_result(max_nodes));
-        gf_res_reserve.reset(new llm_graph_result(max_nodes));
-
         // TODO: move these checks to ggml_backend_sched
         // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
         bool pipeline_parallel =
@@ -340,177 +334,217 @@ llama_context::llama_context(
             }
         }
 
-        sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
+        cparams.pipeline_parallel = pipeline_parallel;
 
-        if (pipeline_parallel) {
+        if (cparams.pipeline_parallel) {
             LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
         }
 
-        llama_memory_context_ptr mctx;
-        if (memory) {
-            LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
-            mctx = memory->init_full();
-            if (!mctx) {
-                throw std::runtime_error("failed to initialize memory module");
+        sched_reserve();
+
+        if (!cparams.flash_attn) {
+            if (ggml_is_quantized(params.type_v)) {
+                throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
             }
         }
+    }
 
-        cross.v_embd.clear();
-
-        // avoid reserving graphs with zero outputs - assume one output per sequence
-        n_outputs = n_seqs;
-
-        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+    // Initialize the full vocabulary token ids for backend samplers.
+    {
+        const int n_vocab = model.vocab.n_tokens();
 
-        // resolve automatic Flash Attention use
-        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
-            if (!gf) {
-                throw std::runtime_error("failed to split graph for Flash Attention check");
-            }
+        sampling.token_ids_full_vocab.resize(n_vocab);
+        for (int i = 0; i < n_vocab; ++i) {
+            sampling.token_ids_full_vocab[i] = i;
+        }
+    }
+}
 
-            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
-            bool fa_device_mismatch = false;
-            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
-                ggml_tensor * n = ggml_graph_node(gf, i);
-                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
-                    continue;
-                }
-                ggml_backend_dev_t device_fa = ggml_backend_get_device(
-                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
+llama_context::~llama_context() {
+    if (!model.hparams.no_alloc) {
+        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+            ggml_backend_t             backend = backend_ptrs[i];
+            ggml_backend_buffer_type_t buft    = backend_buft[i];
 
-                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
-                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
-                const int il = std::stoi(n->name + prefix_len);
-                ggml_backend_dev_t device_kv = model.dev_layer(il);
-                if (device_fa != device_kv) {
-                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
-                        "is assigned to device %s (usually due to missing support)\n",
-                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
-                    // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
-                    fa_device_mismatch = true;
-                    break;
-                }
-            }
-            if (fa_device_mismatch) {
-                cparams.flash_attn = false;
-                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
-                if (ggml_is_quantized(params.type_v)) {
-                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
-                }
+            const size_t size_exp = backend_buf_exp_size[i];
+            const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+            if (size_exp == size_act) {
+                LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
+                    __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
             } else {
-                cparams.flash_attn = true;
-                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+                LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
+                    __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
             }
         }
+    }
+    ggml_opt_free(opt_ctx);
+}
 
-        // reserve worst-case graph
-        int n_splits_pp = -1;
-        int n_nodes_pp  = -1;
+void llama_context::sched_reserve() {
+    if (!sched_need_reserve) {
+        return;
+    }
 
-        int n_splits_tg = -1;
-        int n_nodes_tg  = -1;
+    sched_need_reserve = false;
 
-        // reserve pp (prompt processing) graph first so that buffers are only allocated once
-        {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
-                model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
-            if (!gf) {
-                if (pipeline_parallel) {
-                    LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
-                    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
-                    gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
-                }
-                if (!gf) {
-                    throw std::runtime_error("failed to allocate compute pp buffers");
-                }
-            }
+    LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
+
+    synchronize();
+
+    const int64_t t_start_us = ggml_time_us();
+
+    const uint32_t n_seqs = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-            n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
-            n_nodes_pp  = ggml_graph_n_nodes(gf);
+    const size_t max_nodes = this->graph_max_nodes(n_tokens);
+
+    LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
+
+    gf_res_prev.reset(new llm_graph_result(max_nodes));
+    gf_res_reserve.reset(new llm_graph_result(max_nodes));
+
+    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload));
+
+    llama_memory_context_ptr mctx;
+    if (memory) {
+        LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
+        mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize memory module");
         }
+    }
 
-        // reserve with tg (token generation) graph to get the number of splits and nodes
-        {
-            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
-            if (!gf) {
-                throw std::runtime_error("failed to allocate compute tg buffers");
-            }
+    // avoid reserving graphs with zero outputs - assume one output per sequence
+    const int n_outputs = n_seqs;
 
-            n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
-            n_nodes_tg  = ggml_graph_n_nodes(gf);
+    LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
+    // resolve automatic Flash Attention use
+    if (cparams.auto_fa) {
+        auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
+        if (!gf) {
+            throw std::runtime_error("failed to split graph for Flash Attention check");
         }
 
-        // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-        {
-            // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
-            //
-            // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
-            //
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
-            if (!gf) {
-                throw std::runtime_error("failed to allocate compute pp buffers");
+        const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+        bool fa_device_mismatch = false;
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            ggml_tensor * n = ggml_graph_node(gf, i);
+            if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                continue;
+            }
+            ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+            // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+            GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+            const int il = std::stoi(n->name + prefix_len);
+            ggml_backend_dev_t device_kv = model.dev_layer(il);
+            if (device_fa != device_kv) {
+                LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                        "is assigned to device %s (usually due to missing support)\n",
+                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                fa_device_mismatch = true;
+                break;
             }
         }
+        if (fa_device_mismatch) {
+            cparams.flash_attn = false;
+            LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+        } else {
+            cparams.flash_attn = true;
+            LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+        }
 
-        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
-            ggml_backend_t             backend = backend_ptrs[i];
-            ggml_backend_buffer_type_t buft    = backend_buft[i];
-            if (!model.hparams.no_alloc) {
-                backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+        cparams.auto_fa = false;
+    }
+
+    // reserve worst-case graph
+    int n_splits_pp = -1;
+    int n_nodes_pp  = -1;
+
+    int n_splits_tg = -1;
+    int n_nodes_tg  = -1;
+
+    // reserve pp (prompt processing) graph first so that buffers are only allocated once
+    {
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
+                model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
+        if (!gf) {
+            if (cparams.pipeline_parallel) {
+                LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
+                cparams.pipeline_parallel = false;
+                sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
+                gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             }
-            if (backend_buf_exp_size[i] > 1) {
-                LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
-                        ggml_backend_buft_name(buft),
-                        backend_buf_exp_size[i] / 1024.0 / 1024.0);
+            if (!gf) {
+                throw std::runtime_error("failed to allocate compute pp buffers");
             }
         }
 
-        if (n_nodes_pp == n_nodes_tg) {
-            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, n_nodes_pp);
-        } else {
-            LLAMA_LOG_INFO("%s: graph nodes  = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
-        }
+        n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
+        n_nodes_pp  = ggml_graph_n_nodes(gf);
+    }
 
-        if (n_splits_pp == n_splits_tg) {
-            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
-        } else {
-            LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
+    // reserve with tg (token generation) graph to get the number of splits and nodes
+    {
+        auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
+        if (!gf) {
+            throw std::runtime_error("failed to allocate compute tg buffers");
         }
+
+        n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
+        n_nodes_tg  = ggml_graph_n_nodes(gf);
     }
 
-    // Initialize the full vocabulary token ids for backend samplers.
+    // reserve again with pp graph to avoid ggml-alloc reallocations during inference
     {
-        const int n_vocab = model.vocab.n_tokens();
+        // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
+        //
+        // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+        //
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
+        if (!gf) {
+            throw std::runtime_error("failed to allocate compute pp buffers");
+        }
+    }
 
-        sampling.token_ids_full_vocab.resize(n_vocab);
-        for (int i = 0; i < n_vocab; ++i) {
-            sampling.token_ids_full_vocab[i] = i;
+    for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+        ggml_backend_t             backend = backend_ptrs[i];
+        ggml_backend_buffer_type_t buft    = backend_buft[i];
+        if (!model.hparams.no_alloc) {
+            backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+        }
+        if (backend_buf_exp_size[i] > 1) {
+            LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
+                    ggml_backend_buft_name(buft),
+                    backend_buf_exp_size[i] / 1024.0 / 1024.0);
         }
     }
-}
 
-llama_context::~llama_context() {
-    if (!model.hparams.no_alloc) {
-        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
-            ggml_backend_t             backend = backend_ptrs[i];
-            ggml_backend_buffer_type_t buft    = backend_buft[i];
+    if (n_nodes_pp == n_nodes_tg) {
+        LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, n_nodes_pp);
+    } else {
+        LLAMA_LOG_INFO("%s: graph nodes  = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
+    }
 
-            const size_t size_exp = backend_buf_exp_size[i];
-            const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
-            if (size_exp == size_act) {
-                LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
-                    __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
-            } else {
-                LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
-                    __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
-            }
-        }
+    if (n_splits_pp == n_splits_tg) {
+        LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
+    } else {
+        LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
     }
-    ggml_opt_free(opt_ctx);
+
+    const int64_t t_end_us = ggml_time_us();
+
+    LLAMA_LOG_INFO("%s: reserve took %.2f ms\n", __func__, (t_end_us - t_start_us)/1000.0);
 }
 
 void llama_context::synchronize() {
+    if (!sched) {
+        return;
+    }
+
     ggml_backend_sched_synchronize(sched.get());
 
     // FIXME: if multiple single tokens are evaluated without a synchronization,
@@ -951,21 +985,40 @@ void llama_context::set_embeddings(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
     cparams.embeddings = value;
+
+    // TODO: not sure yet if we want to reserve here
+    //sched_need_reserve = true;
 }
 
 void llama_context::set_causal_attn(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
+    if (cparams.causal_attn == value) {
+        return;
+    }
+
     cparams.causal_attn = value;
+
+    sched_need_reserve = true;
 }
 
 void llama_context::set_warmup(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
+    if (cparams.warmup == value) {
+        return;
+    }
+
     cparams.warmup = value;
+
+    sched_need_reserve = true;
 }
 
 bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
+    if (!sampler && sampling.samplers.count(seq_id) == 0) {
+        return true;
+    }
+
     LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
 
     const bool can_offload =
@@ -985,12 +1038,18 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
 
         sampling.samplers[seq_id] = sampler;
 
+        sched_need_reserve = true;
+
         return true;
     }
 
     if (sampler && !can_offload) {
         LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
 
+        if (sampling.samplers.count(seq_id) > 0) {
+            sched_need_reserve = true;
+        }
+
         sampling.samplers.erase(seq_id);
 
         return false;
@@ -998,6 +1057,8 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
 
     sampling.samplers.erase(seq_id);
 
+    sched_need_reserve = true;
+
     return true;
 }
 
@@ -1006,16 +1067,27 @@ void llama_context::set_adapter_lora(
             float scale) {
     LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
 
+    if (auto it = loras.find(adapter); it != loras.end()) {
+        if (it->second == scale) {
+            return;
+        }
+    }
+
     loras[adapter] = scale;
+
+    sched_need_reserve = true;
 }
 
 bool llama_context::rm_adapter_lora(
             llama_adapter_lora * adapter) {
     LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
 
-    auto pos = loras.find(adapter);
-    if (pos != loras.end()) {
-        loras.erase(pos);
+    auto it = loras.find(adapter);
+    if (it != loras.end()) {
+        loras.erase(it);
+
+        sched_need_reserve = true;
+
         return true;
     }
 
@@ -1025,7 +1097,13 @@ bool llama_context::rm_adapter_lora(
 void llama_context::clear_adapter_lora() {
     LLAMA_LOG_DEBUG("%s: call\n", __func__);
 
+    if (loras.empty()) {
+        return;
+    }
+
     loras.clear();
+
+    sched_need_reserve = true;
 }
 
 bool llama_context::apply_adapter_cvec(
@@ -1036,6 +1114,8 @@ bool llama_context::apply_adapter_cvec(
                 int32_t   il_end) {
     LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
 
+    // TODO: should we reserve?
+
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
@@ -1138,6 +1218,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
+    sched_reserve();
+
     n_queued_tokens += n_tokens;
 
     // reserve output buffer
@@ -1177,7 +1259,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
     // extract logits
-   if (logits && t_logits) {
+    if (logits && t_logits) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
         GGML_ASSERT(backend_res != nullptr);
         GGML_ASSERT(logits != nullptr);
@@ -1451,6 +1533,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
     embd_seq.clear();
     output_swaps.clear();
 
+    sched_reserve();
+
     bool did_optimize = false;
 
     // handle any pending shifts/copies

+ 10 - 0
src/llama-context.h

@@ -40,6 +40,14 @@ struct llama_context {
 
     ~llama_context();
 
+    // reserve a new backend scheduler (if needed)
+    // for example, when:
+    //   - changing loras
+    //   - changing samplers
+    //   - changing attention type
+    //   - etc.
+    void sched_reserve();
+
     void synchronize();
 
     const llama_model   & get_model()   const;
@@ -314,6 +322,8 @@ private:
 
     ggml_backend_sched_ptr sched;
 
+    bool sched_need_reserve = true;
+
     ggml_backend_t backend_cpu = nullptr;
     std::vector<ggml_backend_ptr> backends;
 

+ 2 - 0
src/llama-cparams.h

@@ -30,10 +30,12 @@ struct llama_cparams {
     bool causal_attn;
     bool offload_kqv;
     bool flash_attn;
+    bool auto_fa;
     bool no_perf;
     bool warmup;
     bool op_offload;
     bool kv_unified;
+    bool pipeline_parallel;
 
     enum llama_pooling_type pooling_type;
 

+ 38 - 63
tools/server/server-context.cpp

@@ -45,26 +45,6 @@ enum server_state {
     SERVER_STATE_READY,          // Server is ready and model is loaded
 };
 
-static bool server_task_type_need_embd(server_task_type task_type) {
-    switch (task_type) {
-        case SERVER_TASK_TYPE_EMBEDDING:
-        case SERVER_TASK_TYPE_RERANK:
-            return true;
-        default:
-            return false;
-    }
-}
-
-static bool server_task_type_need_logits(server_task_type task_type) {
-    switch (task_type) {
-        case SERVER_TASK_TYPE_COMPLETION:
-        case SERVER_TASK_TYPE_INFILL:
-            return true;
-        default:
-            return false;
-    }
-}
-
 struct server_slot {
     int id;
 
@@ -147,6 +127,17 @@ struct server_slot {
         return res;
     }
 
+    void prompt_clear(bool allow_processing) {
+        if (!allow_processing) {
+            GGML_ASSERT(!is_processing());
+        }
+
+        SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
+
+        llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
+        prompt.tokens.clear();
+    }
+
     std::vector<common_adapter_lora_info> lora;
     int32_t alora_invocation_start = -1;
 
@@ -196,30 +187,24 @@ struct server_slot {
         n_draft_total = 0;
         n_draft_accepted = 0;
 
+        task_prev = std::move(task);
         task.reset();
-        task_prev.reset();
+
+        llama_set_sampler(ctx, id, nullptr);
 
         // clear alora start
         alora_invocation_start = -1;
     }
 
-    // remove cached prompt + tokens
-    void clear(bool allow_processing) {
-        if (!allow_processing) {
-            GGML_ASSERT(!is_processing());
-        }
-
-        SLT_INF(*this, "clearing slot with %zu tokens\n", prompt.tokens.size());
+    void init_sampler() const {
+        common_sampler_reset(smpl.get());
 
-        llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
-        prompt.tokens.clear();
-    }
+        if (!task->need_sampling()) {
+            return;
+        }
 
-    void init_sampler() const {
         const int64_t t_start = ggml_time_us();
 
-        common_sampler_reset(smpl.get());
-
         int n_text = 0;
 
         for (int i = 0; i < (int) prompt.tokens.size(); i++) {
@@ -235,25 +220,13 @@ struct server_slot {
                 (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
     }
 
-    // TODO: move to server_task
-    bool need_embd() const {
-        GGML_ASSERT(task);
-
-        return server_task_type_need_embd(task->type);
-    }
-
-    // TODO: move to server_task
-    bool need_logits() const {
-        GGML_ASSERT(task);
-
-        return server_task_type_need_logits(task->type);
-    }
-
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
     bool can_split() const {
+        GGML_ASSERT(task);
+
         return
-            !need_embd() ||
+            !task->need_embd() ||
             (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
     }
 
@@ -349,11 +322,10 @@ struct server_slot {
 
             // do not keep context of the child slots - the parent's context is enough
             if (is_child()) {
-                clear(false);
+                prompt_clear(false);
             }
 
-            task_prev = std::move(task);
-            task.reset();
+            reset();
 
             callback_on_release(id);
         }
@@ -801,6 +773,7 @@ private:
 
         slots.clear();
 
+        // initialize slots
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
@@ -1049,7 +1022,7 @@ private:
                 ret->prompt_save(*prompt_cache);
 
                 if (!ret->prompt_load(*prompt_cache, task.tokens)) {
-                    ret->clear(false);
+                    ret->prompt_clear(false);
                 }
 
                 prompt_cache->update();
@@ -1081,7 +1054,7 @@ private:
             if (slot.prompt.n_tokens() > 0) {
                 SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
 
-                slot.clear(false);
+                slot.prompt_clear(false);
 
                 res = true;
 
@@ -1107,8 +1080,6 @@ private:
     }
 
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
-        slot.reset();
-
         // process per-request lora adapters
         if (!task.params.lora.empty()) {
             auto task_loras = construct_lora_list(task.params.lora);
@@ -1182,7 +1153,7 @@ private:
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         // initialize samplers
-        {
+        if (task.need_sampling()) {
             slot.smpl.reset(common_sampler_init(model, task.params.sampling));
 
             if (slot.smpl == nullptr) {
@@ -1211,6 +1182,8 @@ private:
             }
 
             SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
+        } else {
+            slot.smpl.reset();
         }
 
         // initialize draft batch
@@ -1864,7 +1837,7 @@ private:
                     // Erase token cache
                     const size_t n_erased = slot->prompt.tokens.size();
 
-                    slot->clear(false);
+                    slot->prompt_clear(false);
 
                     auto res = std::make_unique<server_task_result_slot_erase>();
                     res->id       = task.id;
@@ -2161,7 +2134,7 @@ private:
                         }
 
                         // TODO: support memory-less logits computation
-                        if (slot.need_logits() && !llama_get_memory(ctx)) {
+                        if (slot.task->need_logits() && !llama_get_memory(ctx)) {
                             send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
                             slot.release();
                             continue;
@@ -2421,7 +2394,7 @@ private:
                     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
                         SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
 
-                        slot.clear(true);
+                        slot.prompt_clear(true);
 
                         // there is no common part left
                         slot.n_prompt_tokens_cache = 0;
@@ -2500,7 +2473,7 @@ private:
                             cur_tok,
                             slot.prompt.tokens.pos_next(),
                             { slot.id },
-                            slot.need_embd());
+                            slot.task->need_embd());
                         slot.prompt.tokens.push_back(cur_tok);
 
                         slot.n_prompt_tokens_processed++;
@@ -2590,7 +2563,7 @@ private:
                 slot_batched->lora[alora_disabled_id].scale = alora_scale;
             }
 
-            llama_set_embeddings(ctx, slot_batched->need_embd());
+            llama_set_embeddings(ctx, slot_batched->task->need_embd());
         }
 
         if (batch.n_tokens == 0) {
@@ -2648,7 +2621,7 @@ private:
 
                                 // note: it's complicated to keep track of how much of the current batch has been
                                 //       processed before the error occurred, so we simply clear the entire context
-                                slot.clear(false);
+                                slot.prompt_clear(false);
                             }
                         }
 
@@ -2727,6 +2700,8 @@ private:
                         continue; // continue loop of slots
                     }
 
+                    GGML_ASSERT(slot.task->need_sampling());
+
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
                 } else if (slot.state != SLOT_STATE_GENERATING) {

+ 30 - 0
tools/server/server-task.h

@@ -156,6 +156,36 @@ struct server_task {
         return tokens.size();
     }
 
+    bool need_embd() const {
+        switch (type) {
+            case SERVER_TASK_TYPE_EMBEDDING:
+            case SERVER_TASK_TYPE_RERANK:
+                return true;
+            default:
+                return false;
+        }
+    }
+
+    bool need_logits() const {
+        switch (type) {
+            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFILL:
+                return true;
+            default:
+                return false;
+        }
+    }
+
+    bool need_sampling() const {
+        switch (type) {
+            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFILL:
+                return true;
+            default:
+                return false;
+        }
+    }
+
     static task_params params_from_json_cmpl(
         const llama_vocab * vocab,
         const common_params & params_base,