|
|
@@ -35,7 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
|
|
|
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
|
|
enum slot_state {
|
|
|
SLOT_STATE_IDLE,
|
|
|
- SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
|
|
|
+ SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
|
|
|
+ SLOT_STATE_STARTED, // after assigning a task and about to process prompt
|
|
|
SLOT_STATE_PROCESSING_PROMPT,
|
|
|
SLOT_STATE_DONE_PROMPT,
|
|
|
SLOT_STATE_GENERATING,
|
|
|
@@ -254,6 +255,15 @@ struct server_slot {
|
|
|
generated_token_probs.push_back(token);
|
|
|
}
|
|
|
|
|
|
+ // note: a slot can also be either a parent or a child
|
|
|
+ bool is_parent() const {
|
|
|
+ return is_processing() && task->n_children > 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ bool is_child() const {
|
|
|
+ return is_processing() && task->id_parent >= 0;
|
|
|
+ }
|
|
|
+
|
|
|
void release() {
|
|
|
if (is_processing()) {
|
|
|
GGML_ASSERT(task);
|
|
|
@@ -383,6 +393,17 @@ struct server_slot {
|
|
|
|
|
|
return res;
|
|
|
}
|
|
|
+
|
|
|
+ void copy_state_to(server_slot & other) const {
|
|
|
+ llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
|
|
|
+ llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
|
|
|
+ other.n_decoded = n_decoded;
|
|
|
+ other.n_remaining = n_remaining;
|
|
|
+ other.i_batch = i_batch;
|
|
|
+ other.n_prompt_tokens_cache = n_prompt_tokens_cache;
|
|
|
+ other.n_prompt_tokens_processed = n_prompt_tokens_processed;
|
|
|
+ other.prompt = prompt.clone();
|
|
|
+ }
|
|
|
};
|
|
|
|
|
|
|
|
|
@@ -1022,7 +1043,9 @@ struct server_context_impl {
|
|
|
|
|
|
slot.task = std::make_unique<const server_task>(std::move(task));
|
|
|
|
|
|
- slot.state = SLOT_STATE_STARTED;
|
|
|
+ slot.state = slot.is_child()
|
|
|
+ ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
|
|
|
+ : SLOT_STATE_STARTED;
|
|
|
|
|
|
SLT_INF(slot, "%s", "processing task\n");
|
|
|
|
|
|
@@ -1684,6 +1707,12 @@ struct server_context_impl {
|
|
|
GGML_ABORT("not supported by multimodal");
|
|
|
}
|
|
|
|
|
|
+ if (slot.is_parent() || slot.is_child()) {
|
|
|
+ send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
|
|
|
+ slot.release();
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
// Shift context
|
|
|
int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
|
|
|
|
|
|
@@ -2308,6 +2337,26 @@ struct server_context_impl {
|
|
|
n_batch = llama_n_batch(ctx);
|
|
|
|
|
|
for (auto & slot : slots) {
|
|
|
+ // may need to copy state to other slots
|
|
|
+ if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
|
|
|
+ std::vector<server_slot *> child_slots;
|
|
|
+ for (auto & other : slots) {
|
|
|
+ if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
|
|
|
+ child_slots.push_back(&other);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // we can only proceed if all child slots are having the correct tasks
|
|
|
+ if (child_slots.size() == slot.task->n_children) {
|
|
|
+ // copy state to the child slots
|
|
|
+ for (auto & child : child_slots) {
|
|
|
+ SLT_INF(slot, "copying state to child %d\n", child->id);
|
|
|
+ slot.copy_state_to(*child);
|
|
|
+ child->state = SLOT_STATE_DONE_PROMPT;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
// optionally send prompt processing progress
|
|
|
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
|
|
|
if (slot.task->params.stream && slot.task->params.return_progress) {
|
|
|
@@ -2593,11 +2642,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|
|
}
|
|
|
tasks.reserve(inputs.size());
|
|
|
states.reserve(inputs.size());
|
|
|
+ int idx = 0;
|
|
|
for (size_t i = 0; i < inputs.size(); i++) {
|
|
|
server_task task = server_task(type);
|
|
|
|
|
|
task.id = ctx_server.queue_tasks.get_new_id();
|
|
|
- task.index = i;
|
|
|
+ task.index = idx++;
|
|
|
|
|
|
task.tokens = std::move(inputs[i]);
|
|
|
task.params = server_task::params_from_json_cmpl(
|
|
|
@@ -2612,6 +2662,18 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|
|
task.params.oaicompat_model = ctx_server.model_name;
|
|
|
states.push_back(task.params.oaicompat_chat_syntax);
|
|
|
|
|
|
+ if (task.params.n_cmpl > 1) {
|
|
|
+ task.n_children = task.params.n_cmpl - 1;
|
|
|
+ for (size_t j = 0; j < task.n_children; j++) {
|
|
|
+ server_task child = task.create_child(
|
|
|
+ task.id,
|
|
|
+ ctx_server.queue_tasks.get_new_id(),
|
|
|
+ idx++);
|
|
|
+ states.push_back(child.params.oaicompat_chat_syntax);
|
|
|
+ tasks.push_back(std::move(child));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
tasks.push_back(std::move(task));
|
|
|
}
|
|
|
|
|
|
@@ -2638,8 +2700,21 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|
|
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
|
|
arr.push_back(res->to_json());
|
|
|
}
|
|
|
- // if single request, return single object instead of array
|
|
|
- res->ok(arr.size() == 1 ? arr[0] : arr);
|
|
|
+ GGML_ASSERT(!arr.empty() && "empty results");
|
|
|
+ if (arr.size() == 1) {
|
|
|
+ // if single request, return single object instead of array
|
|
|
+ res->ok(arr[0]);
|
|
|
+ } else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
|
|
|
+ // if multiple results in OAI format, we need to re-format them
|
|
|
+ json & choices = arr[0]["choices"];
|
|
|
+ for (size_t i = 1; i < arr.size(); i++) {
|
|
|
+ choices.push_back(std::move(arr[i]["choices"][0]));
|
|
|
+ }
|
|
|
+ res->ok(arr[0]);
|
|
|
+ } else {
|
|
|
+ // multi-results, non-OAI compat
|
|
|
+ res->ok(arr);
|
|
|
+ }
|
|
|
}
|
|
|
} else {
|
|
|
// in streaming mode, the first error must be treated as non-stream response
|