|
@@ -2581,12 +2581,14 @@ struct server_context {
|
|
|
continue;
|
|
continue;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
|
|
|
|
- if (embd == NULL) {
|
|
|
|
|
|
|
+ const float * embd = nullptr;
|
|
|
|
|
+ if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
|
|
embd = llama_get_embeddings_ith(ctx, i);
|
|
embd = llama_get_embeddings_ith(ctx, i);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if (embd == NULL) {
|
|
|
|
|
|
|
+ if (embd == nullptr) {
|
|
|
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
|
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
|
|
|
|
|
|
|
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
|
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
|
@@ -2594,12 +2596,12 @@ struct server_context {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// normalize only when there is pooling
|
|
// normalize only when there is pooling
|
|
|
- // TODO: configurable
|
|
|
|
|
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
|
|
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
|
|
|
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
|
|
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
|
|
|
res->embedding.push_back(embd_res);
|
|
res->embedding.push_back(embd_res);
|
|
|
|
|
+ break;
|
|
|
} else {
|
|
} else {
|
|
|
- res->embedding.push_back({ embd, embd + n_embd });
|
|
|
|
|
|
|
+ res->embedding.emplace_back(embd, embd + n_embd);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|