|
|
@@ -98,6 +98,8 @@ struct slot_params {
|
|
|
int64_t t_max_prompt_ms = -1; // TODO: implement
|
|
|
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
|
|
|
|
|
+ std::vector<common_lora_adapter_container> lora;
|
|
|
+
|
|
|
std::vector<std::string> antiprompt;
|
|
|
std::vector<std::string> response_fields;
|
|
|
bool timings_per_token = false;
|
|
|
@@ -120,6 +122,11 @@ struct slot_params {
|
|
|
samplers.emplace_back(common_sampler_type_to_str(sampler));
|
|
|
}
|
|
|
|
|
|
+ json lora = json::array();
|
|
|
+ for (size_t i = 0; i < this->lora.size(); ++i) {
|
|
|
+ lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
|
|
|
+ }
|
|
|
+
|
|
|
return json {
|
|
|
{"n_predict", n_predict}, // Server configured n_predict
|
|
|
{"seed", sampling.seed},
|
|
|
@@ -160,6 +167,7 @@ struct slot_params {
|
|
|
{"speculative.p_min", speculative.p_min},
|
|
|
{"timings_per_token", timings_per_token},
|
|
|
{"post_sampling_probs", post_sampling_probs},
|
|
|
+ {"lora", lora},
|
|
|
};
|
|
|
}
|
|
|
};
|
|
|
@@ -189,12 +197,16 @@ struct server_task {
|
|
|
// used by SERVER_TASK_TYPE_METRICS
|
|
|
bool metrics_reset_bucket = false;
|
|
|
|
|
|
+ // used by SERVER_TASK_TYPE_SET_LORA
|
|
|
+ std::vector<common_lora_adapter_container> set_lora;
|
|
|
+
|
|
|
server_task(server_task_type type) : type(type) {}
|
|
|
|
|
|
static slot_params params_from_json_cmpl(
|
|
|
const llama_model * model,
|
|
|
const llama_context * ctx,
|
|
|
const common_params & params_base,
|
|
|
+ const std::vector<common_lora_adapter_container> & lora_base,
|
|
|
const json & data) {
|
|
|
slot_params params;
|
|
|
|
|
|
@@ -251,6 +263,16 @@ struct server_task {
|
|
|
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
|
|
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
|
|
|
|
|
+ if (data.contains("lora")) {
|
|
|
+ if (data.at("lora").is_array()) {
|
|
|
+ params.lora = parse_lora_request(lora_base, data.at("lora"));
|
|
|
+ } else {
|
|
|
+ throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ params.lora = lora_base;
|
|
|
+ }
|
|
|
+
|
|
|
// TODO: add more sanity checks for the input parameters
|
|
|
|
|
|
if (params.sampling.penalty_last_n < -1) {
|
|
|
@@ -1110,6 +1132,8 @@ struct server_slot {
|
|
|
|
|
|
common_speculative * spec = nullptr;
|
|
|
|
|
|
+ std::vector<common_lora_adapter_container> lora;
|
|
|
+
|
|
|
// the index relative to completion multi-task request
|
|
|
size_t index = 0;
|
|
|
|
|
|
@@ -1191,6 +1215,11 @@ struct server_slot {
|
|
|
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
|
|
|
}
|
|
|
|
|
|
+ bool can_batch_with(server_slot & other_slot) {
|
|
|
+ return is_non_causal() == other_slot.is_non_causal()
|
|
|
+ && are_lora_equal(lora, other_slot.lora);
|
|
|
+ }
|
|
|
+
|
|
|
bool has_budget(const common_params & global_params) {
|
|
|
if (params.n_predict == -1 && global_params.n_predict == -1) {
|
|
|
return true; // limitless
|
|
|
@@ -1600,7 +1629,7 @@ struct server_context {
|
|
|
|
|
|
llama_model * model = nullptr;
|
|
|
llama_context * ctx = nullptr;
|
|
|
- std::vector<common_lora_adapter_container> loras;
|
|
|
+ std::vector<common_lora_adapter_container> lora;
|
|
|
|
|
|
llama_model * model_dft = nullptr;
|
|
|
llama_context_params cparams_dft;
|
|
|
@@ -1667,7 +1696,7 @@ struct server_context {
|
|
|
|
|
|
model = llama_init.model;
|
|
|
ctx = llama_init.context;
|
|
|
- loras = llama_init.lora_adapters;
|
|
|
+ lora = llama_init.lora_adapters;
|
|
|
|
|
|
if (model == nullptr) {
|
|
|
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
|
|
|
@@ -1866,6 +1895,12 @@ struct server_context {
|
|
|
slot.params = std::move(task.params);
|
|
|
slot.prompt_tokens = std::move(task.prompt_tokens);
|
|
|
|
|
|
+ if (!are_lora_equal(task.params.lora, slot.lora)) {
|
|
|
+ // if lora is changed, we cannot reuse cached tokens
|
|
|
+ slot.cache_tokens.clear();
|
|
|
+ slot.lora = std::move(task.params.lora);
|
|
|
+ }
|
|
|
+
|
|
|
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
|
|
|
|
|
|
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
|
|
@@ -2557,7 +2592,7 @@ struct server_context {
|
|
|
} break;
|
|
|
case SERVER_TASK_TYPE_SET_LORA:
|
|
|
{
|
|
|
- common_lora_adapters_apply(ctx, loras);
|
|
|
+ lora = std::move(task.set_lora);
|
|
|
auto res = std::make_unique<server_task_result_apply_lora>();
|
|
|
res->id = task.id;
|
|
|
queue_results.send(std::move(res));
|
|
|
@@ -2634,12 +2669,22 @@ struct server_context {
|
|
|
// start populating the batch for this iteration
|
|
|
common_batch_clear(batch);
|
|
|
|
|
|
+ // track if given slot can be batched with slots already in the batch
|
|
|
+ server_slot * slot_batched = nullptr;
|
|
|
+
|
|
|
// frist, add sampled tokens from any ongoing sequences
|
|
|
for (auto & slot : slots) {
|
|
|
if (slot.state != SLOT_STATE_GENERATING) {
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
+ // check if we can batch this slot with the previous one
|
|
|
+ if (!slot_batched) {
|
|
|
+ slot_batched = &slot;
|
|
|
+ } else if (!slot_batched->can_batch_with(slot)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
slot.i_batch = batch.n_tokens;
|
|
|
|
|
|
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
|
|
@@ -2658,15 +2703,18 @@ struct server_context {
|
|
|
int32_t n_batch = llama_n_batch(ctx);
|
|
|
int32_t n_ubatch = llama_n_ubatch(ctx);
|
|
|
|
|
|
- // track if this is an embedding or non-embedding batch
|
|
|
- // if we've added sampled tokens above, we are in non-embedding mode
|
|
|
- // -1: none, 0: non-embedding, 1: embedding
|
|
|
- // TODO: make enum
|
|
|
- int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
|
|
-
|
|
|
// next, batch any pending prompts without exceeding n_batch
|
|
|
if (params_base.cont_batching || batch.n_tokens == 0) {
|
|
|
for (auto & slot : slots) {
|
|
|
+ // check if we can batch this slot with the previous one
|
|
|
+ if (slot.is_processing()) {
|
|
|
+ if (!slot_batched) {
|
|
|
+ slot_batched = &slot;
|
|
|
+ } else if (!slot_batched->can_batch_with(slot)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
// this slot still has a prompt to be processed
|
|
|
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
|
|
auto & prompt_tokens = slot.prompt_tokens;
|
|
|
@@ -2827,14 +2875,6 @@ struct server_context {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // check that we are in the right batch_type, if not defer the slot
|
|
|
- int slot_type = slot.is_non_causal();
|
|
|
- if (batch_type == -1) {
|
|
|
- batch_type = slot_type;
|
|
|
- } else if (batch_type != slot_type) {
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
// keep only the common part
|
|
|
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
|
|
|
// could not partially delete (likely using a non-Transformer model)
|
|
|
@@ -2902,8 +2942,12 @@ struct server_context {
|
|
|
|
|
|
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
|
|
|
|
|
- // make sure we're in the right embedding mode
|
|
|
- llama_set_embeddings(ctx, batch_type == 1);
|
|
|
+ if (slot_batched) {
|
|
|
+ // make sure we're in the right embedding mode
|
|
|
+ llama_set_embeddings(ctx, slot_batched->is_non_causal());
|
|
|
+ // apply lora, only need to do it once per batch
|
|
|
+ common_lora_adapters_apply(ctx, slot_batched->lora);
|
|
|
+ }
|
|
|
|
|
|
// process the created batch of tokens
|
|
|
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
|
|
@@ -3623,7 +3667,12 @@ int main(int argc, char ** argv) {
|
|
|
task.index = i;
|
|
|
|
|
|
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
|
|
- task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
|
|
|
+ task.params = server_task::params_from_json_cmpl(
|
|
|
+ ctx_server.model,
|
|
|
+ ctx_server.ctx,
|
|
|
+ ctx_server.params_base,
|
|
|
+ ctx_server.lora,
|
|
|
+ data);
|
|
|
task.id_selected_slot = json_value(data, "id_slot", -1);
|
|
|
|
|
|
// OAI-compat
|
|
|
@@ -4049,8 +4098,8 @@ int main(int argc, char ** argv) {
|
|
|
|
|
|
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
|
|
json result = json::array();
|
|
|
- for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
|
|
|
- auto & lora = ctx_server.loras[i];
|
|
|
+ for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
|
|
|
+ auto & lora = ctx_server.lora[i];
|
|
|
result.push_back({
|
|
|
{"id", i},
|
|
|
{"path", lora.path},
|
|
|
@@ -4062,27 +4111,14 @@ int main(int argc, char ** argv) {
|
|
|
};
|
|
|
|
|
|
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
|
|
- const std::vector<json> body = json::parse(req.body);
|
|
|
- int max_idx = ctx_server.loras.size();
|
|
|
-
|
|
|
- // clear existing value
|
|
|
- for (auto & lora : ctx_server.loras) {
|
|
|
- lora.scale = 0.0f;
|
|
|
- }
|
|
|
-
|
|
|
- // set value
|
|
|
- for (auto entry : body) {
|
|
|
- int id = entry.at("id");
|
|
|
- float scale = entry.at("scale");
|
|
|
- if (0 <= id && id < max_idx) {
|
|
|
- ctx_server.loras[id].scale = scale;
|
|
|
- } else {
|
|
|
- throw std::runtime_error("invalid adapter id");
|
|
|
- }
|
|
|
+ const json body = json::parse(req.body);
|
|
|
+ if (!body.is_array()) {
|
|
|
+ res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
|
|
+ return;
|
|
|
}
|
|
|
-
|
|
|
server_task task(SERVER_TASK_TYPE_SET_LORA);
|
|
|
task.id = ctx_server.queue_tasks.get_new_id();
|
|
|
+ task.set_lora = parse_lora_request(ctx_server.lora, body);
|
|
|
ctx_server.queue_results.add_waiting_task_id(task.id);
|
|
|
ctx_server.queue_tasks.post(task);
|
|
|
|