|
|
@@ -849,47 +849,44 @@ static json format_response_rerank(
|
|
|
const json & request,
|
|
|
const json & ranks,
|
|
|
bool is_tei_format,
|
|
|
- std::vector<std::string> & texts) {
|
|
|
- json res;
|
|
|
- if (is_tei_format) {
|
|
|
- // TEI response format
|
|
|
- res = json::array();
|
|
|
- bool return_text = json_value(request, "return_text", false);
|
|
|
- for (const auto & rank : ranks) {
|
|
|
- int index = json_value(rank, "index", 0);
|
|
|
- json elem = json{
|
|
|
- {"index", index},
|
|
|
- {"score", json_value(rank, "score", 0.0)},
|
|
|
- };
|
|
|
- if (return_text) {
|
|
|
- elem["text"] = std::move(texts[index]);
|
|
|
- }
|
|
|
- res.push_back(elem);
|
|
|
- }
|
|
|
- } else {
|
|
|
- // Jina response format
|
|
|
- json results = json::array();
|
|
|
- int32_t n_tokens = 0;
|
|
|
- for (const auto & rank : ranks) {
|
|
|
- results.push_back(json{
|
|
|
- {"index", json_value(rank, "index", 0)},
|
|
|
- {"relevance_score", json_value(rank, "score", 0.0)},
|
|
|
- });
|
|
|
-
|
|
|
- n_tokens += json_value(rank, "tokens_evaluated", 0);
|
|
|
- }
|
|
|
-
|
|
|
- res = json{
|
|
|
- {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
|
|
- {"object", "list"},
|
|
|
- {"usage", json{
|
|
|
- {"prompt_tokens", n_tokens},
|
|
|
- {"total_tokens", n_tokens}
|
|
|
- }},
|
|
|
- {"results", results}
|
|
|
+ std::vector<std::string> & texts,
|
|
|
+ int top_n) {
|
|
|
+ int32_t n_tokens = 0;
|
|
|
+ bool return_text = is_tei_format && json_value(request, "return_text", false);
|
|
|
+ std::vector<json> elements; // Temporary vector to hold unsorted elements
|
|
|
+ std::string score_label = is_tei_format ? "score" : "relevance_score";
|
|
|
+ for (const auto & rank : ranks) {
|
|
|
+ int index = json_value(rank, "index", 0);
|
|
|
+ json elem = json{
|
|
|
+ {"index", index},
|
|
|
+ {score_label, json_value(rank, "score", 0.0)},
|
|
|
};
|
|
|
+ n_tokens += json_value(rank, "tokens_evaluated", 0);
|
|
|
+ if (return_text) {
|
|
|
+ elem["text"] = std::move(texts[index]);
|
|
|
+ }
|
|
|
+ elements.push_back(elem);
|
|
|
}
|
|
|
|
|
|
+ std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) {
|
|
|
+ return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0);
|
|
|
+ });
|
|
|
+
|
|
|
+ elements.resize(std::min(top_n, (int)elements.size()));
|
|
|
+ json results = elements;
|
|
|
+
|
|
|
+ if (is_tei_format) return results;
|
|
|
+
|
|
|
+ json res = json{
|
|
|
+ {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
|
|
+ {"object", "list"},
|
|
|
+ {"usage", json{
|
|
|
+ {"prompt_tokens", n_tokens},
|
|
|
+ {"total_tokens", n_tokens}
|
|
|
+ }},
|
|
|
+ {"results", results}
|
|
|
+ };
|
|
|
+
|
|
|
return res;
|
|
|
}
|
|
|
|