@@ -92,6 +92,7 @@ enum server_task_type {
 enum server_task_cmpl_type {
     SERVER_TASK_CMPL_TYPE_NORMAL,
     SERVER_TASK_CMPL_TYPE_EMBEDDING,
+    SERVER_TASK_CMPL_TYPE_RERANK,
     SERVER_TASK_CMPL_TYPE_INFILL,
 };

@@ -172,6 +173,7 @@ struct server_slot {
     std::vector<completion_token_output> generated_token_probs;

     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -954,8 +956,17 @@ struct server_context {
                 slot.prompt = *prompt;
             } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
                 slot.prompt = prompt->at(0);
+            } else if (prompt->is_array() && prompt->size() > 1) {
+                // array of strings
+                for (const auto & el : *prompt) {
+                    if (!el.is_string()) {
+                        send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                        return false;
+                    }
+                }
+                slot.prompt = *prompt;
             } else {
-                send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         }
@@ -1389,6 +1400,7 @@ struct server_context {

             res.data = json {
                 {"embedding", std::vector<float>(n_embd, 0.0f)},
+                {"index", slot.index},
             };

             continue;
@@ -1407,6 +1419,44 @@ struct server_context {
         queue_results.send(res);
     }

+    void send_rerank(const server_slot & slot, const llama_batch & batch) {
+        server_task_result res;
+        res.id = slot.id_task;
+        res.error = false;
+        res.stop = true;
+
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+                continue;
+            }
+
+            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            if (embd == NULL) {
+                embd = llama_get_embeddings_ith(ctx, i);
+            }
+
+            if (embd == NULL) {
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+
+                res.data = json {
+                    {"index", slot.index},
+                    {"score", -1e6},
+                };
+
+                continue;
+            }
+
+            res.data = json {
+                {"index", slot.index},
+                {"score", embd[0]},
+            };
+        }
+
+        SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
+
+        queue_results.send(res);
+    }
+
     //
     // Functions to create new task(s) and receive result(s)
     //
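Note (not part of the patch): the `score` sent above is the first component of the pooled sequence embedding, i.e. the model's raw ranking output. Common rerankers such as bge-reranker emit this as a logit, so a client that wants a 0..1 relevance value can apply a sigmoid on its side; whether that is appropriate depends on the model. A minimal client-side sketch under that assumption:

```cpp
// Illustrative client-side helper, not part of the server changes:
// maps a raw rerank logit (the "score" field) to a 0..1 relevance value.
#include <cmath>

static float rerank_score_to_relevance(float logit) {
    return 1.0f / (1.0f + std::exp(-logit));
}
```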
@@ -1442,13 +1492,27 @@ struct server_context {
|
|
|
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
|
|
|
else if (prompt.is_array()) {
|
|
|
std::vector<json> prompts = prompt;
|
|
|
- for (size_t i = 0; i < prompts.size(); i++) {
|
|
|
- const auto & e = prompts[i];
|
|
|
- if (e.is_string() || json_is_array_of_numbers(e)) {
|
|
|
- data["index"] = i;
|
|
|
- create_task(data, true, e);
|
|
|
- } else {
|
|
|
- throw std::runtime_error(error_msg);
|
|
|
+ if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
|
|
+ // prompts[0] is the question
|
|
|
+ // the rest are the answers/documents
|
|
|
+ SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
|
|
|
+ for (size_t i = 1; i < prompts.size(); i++) {
|
|
|
+ json qd;
|
|
|
+ qd.push_back(prompts[0]);
|
|
|
+ qd.push_back(prompts[i]);
|
|
|
+ data["index"] = i - 1;
|
|
|
+ create_task(data, true, qd);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
|
|
|
+ for (size_t i = 0; i < prompts.size(); i++) {
|
|
|
+ const auto & e = prompts[i];
|
|
|
+ if (e.is_string() || json_is_array_of_numbers(e)) {
|
|
|
+ data["index"] = i;
|
|
|
+ create_task(data, true, e);
|
|
|
+ } else {
|
|
|
+ throw std::runtime_error(error_msg);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -1492,7 +1556,9 @@ struct server_context {
                 return;
             }

-            size_t idx = result.data["index"];
+            const size_t idx = result.data["index"];
+            GGML_ASSERT(idx < results.size() && "index out of range");
+
             results[idx] = result;
         }
         result_handler(results);
@@ -1903,6 +1969,7 @@ struct server_context {
         // track if this is an embedding or non-embedding batch
         // if we've added sampled tokens above, we are in non-embedding mode
         // -1: none, 0: non-embedding, 1: embedding
+        // TODO: make enum
         int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

         // next, batch any pending prompts without exceeding n_batch
@@ -1951,6 +2018,29 @@ struct server_context {
                        }

                        prompt_tokens = embd_inp;
+                    } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        // require slot.prompt to be array of 2 strings
+                        if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                            SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                            slot.release();
+                            send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                            continue;
+                        }
+
+                        // prompt: <s>query</s><s>doc</s>
+                        prompt_tokens.clear();
+                        prompt_tokens.push_back(llama_token_bos(model));
+                        {
+                            const auto part = tokenize(slot.prompt[0], false);
+                            prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                        }
+                        prompt_tokens.push_back(llama_token_eos(model));
+                        prompt_tokens.push_back(llama_token_bos(model));
+                        {
+                            const auto part = tokenize(slot.prompt[1], false);
+                            prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                        }
+                        prompt_tokens.push_back(llama_token_eos(model));
                     } else {
                         prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                     }
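To make the `<s>query</s><s>doc</s>` layout above concrete, here is a self-contained sketch with made-up token IDs (BOS = 1, EOS = 2 and the query/document tokenizations are assumptions purely for illustration); the real code obtains them via `llama_token_bos`/`llama_token_eos` and `tokenize`:

```cpp
// Hypothetical token IDs, only to illustrate the rerank prompt layout.
#include <cstdio>
#include <vector>

int main() {
    const int bos = 1, eos = 2;                // assumed special tokens
    const std::vector<int> query = {5, 6};     // stands in for tokenize(slot.prompt[0], false)
    const std::vector<int> doc   = {7, 8, 9};  // stands in for tokenize(slot.prompt[1], false)

    std::vector<int> prompt_tokens;
    prompt_tokens.push_back(bos);
    prompt_tokens.insert(prompt_tokens.end(), query.begin(), query.end());
    prompt_tokens.push_back(eos);
    prompt_tokens.push_back(bos);
    prompt_tokens.insert(prompt_tokens.end(), doc.begin(), doc.end());
    prompt_tokens.push_back(eos);

    // prints: 1 5 6 2 1 7 8 9 2  ->  <s> query </s> <s> doc </s>
    for (int t : prompt_tokens) {
        std::printf("%d ", t);
    }
    std::printf("\n");
    return 0;
}
```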
@@ -1970,7 +2060,7 @@ struct server_context {
                        continue;
                    }

-                   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                        // this prompt is too large to process - discard it
                        if (slot.n_prompt_tokens > n_ubatch) {
                            slot.release();
@@ -2048,7 +2138,8 @@ struct server_context {
                    slot.n_prompt_tokens_processed = 0;
                }

-               if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+               // non-causal tasks require the entire prompt to fit in the physical batch
+               if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                    // cannot fit the prompt in the current batch - will try next iter
                    if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                        continue;
@@ -2056,7 +2147,10 @@ struct server_context {
                }

                // check that we are in the right batch_type, if not defer the slot
-               bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
+               const bool slot_type =
+                   slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+                   slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+
                if (batch_type == -1) {
                    batch_type = slot_type;
                } else if (batch_type != slot_type) {
@@ -2229,6 +2323,13 @@ struct server_context {
                        continue; // continue loop of slots
                    }

+                   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                       send_rerank(slot, batch_view);
+                       slot.release();
+                       slot.i_batch = -1;
+                       continue; // continue loop of slots
+                   }
+
                    // prompt evaluated for next-token prediction
                    slot.state = SLOT_STATE_GENERATING;
                } else if (slot.state != SLOT_STATE_GENERATING) {
@@ -2787,8 +2888,8 @@ int main(int argc, char ** argv) {
    };

    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
-       if (ctx_server.params.embedding) {
-           res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+       if (ctx_server.params.embedding || ctx_server.params.reranking) {
+           res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

@@ -2848,8 +2949,8 @@ int main(int argc, char ** argv) {

    // TODO: maybe merge this function with "handle_completions_generic"
    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
-       if (ctx_server.params.embedding) {
-           res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+       if (ctx_server.params.embedding || ctx_server.params.reranking) {
+           res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

@@ -2973,6 +3074,11 @@ int main(int argc, char ** argv) {
    };

    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+       // TODO: somehow clean up these checks in the future
+       if (!ctx_server.params.embedding || ctx_server.params.reranking) {
+           res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+           return;
+       }
        const json body = json::parse(req.body);
        bool is_openai = false;

@@ -3023,6 +3129,79 @@ int main(int argc, char ** argv) {
        res_ok(res, root);
    };

+   const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+       if (!ctx_server.params.reranking) {
+           res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+           return;
+       }
+       const json body = json::parse(req.body);
+
+       // TODO: implement
+       //int top_n = 1;
+       //if (body.count("top_n") != 1) {
+       //    top_n = body.at("top_n");
+       //} else {
+       //    res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+       //    return;
+       //}
+
+       json query;
+       if (body.count("query") == 1) {
+           query = body.at("query");
+           if (!query.is_string()) {
+               res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+               return;
+           }
+       } else {
+           res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+           return;
+       }
+
+       std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+       if (documents.empty()) {
+           res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+           return;
+       }
+
+       // construct prompt object: array of ["query", "doc0", "doc1", ...]
+       json prompt;
+       prompt.push_back(query);
+       for (const auto & doc : documents) {
+           prompt.push_back(doc);
+       }
+
+       LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+       // create and queue the task
+       json responses = json::array();
+       bool error = false;
+       {
+           std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+           ctx_server.queue_results.add_waiting_tasks(tasks);
+           ctx_server.queue_tasks.post(tasks);
+
+           // get the result
+           std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+           ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+               for (const auto & res : results) {
+                   responses.push_back(res.data);
+               }
+           }, [&](const json & error_data) {
+               res_error(res, error_data);
+               error = true;
+           });
+       }
+
+       if (error) {
+           return;
+       }
+
+       // write JSON response
+       json root = format_response_rerank(body, responses);
+       res_ok(res, root);
+   };
+
    const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
        json result = json::array();
        for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
@@ -3119,6 +3298,10 @@ int main(int argc, char ** argv) {
    svr->Post("/embedding", handle_embeddings); // legacy
    svr->Post("/embeddings", handle_embeddings);
    svr->Post("/v1/embeddings", handle_embeddings);
+   svr->Post("/rerank", handle_rerank);
+   svr->Post("/reranking", handle_rerank);
+   svr->Post("/v1/rerank", handle_rerank);
+   svr->Post("/v1/reranking", handle_rerank);
    svr->Post("/tokenize", handle_tokenize);
    svr->Post("/detokenize", handle_detokenize);
    // LoRA adapters hotswap
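For reference, a minimal client-side sketch of exercising the new endpoint (illustrative only: it assumes a server built from this patch, started with `--reranking` and a reranking model, listening on localhost:8080, and it reuses cpp-httplib and nlohmann::json, which the server already vendors):

```cpp
// Hypothetical standalone client for the new /v1/rerank endpoint.
#include <cstdio>

#include "httplib.h"
#include "json.hpp"

int main() {
    httplib::Client cli("http://localhost:8080");

    // the handler requires a string "query" and a non-empty string array "documents"
    const nlohmann::json body = {
        {"query",     "what is panda?"},
        {"documents", {"hi", "the giant panda is a bear species endemic to china"}},
    };

    // also reachable as /rerank, /reranking and /v1/reranking
    auto res = cli.Post("/v1/rerank", body.dump(), "application/json");
    if (res && res->status == 200) {
        // each per-document result carries its zero-based "index" and a raw "score"
        std::printf("%s\n", res->body.c_str());
    } else {
        std::printf("request failed\n");
    }
    return 0;
}
```

The overall shape of the HTTP response comes from `format_response_rerank`, which lies outside this diff, so only the per-document `index`/`score` fields produced by `send_rerank` are assumed here.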