|
|
@@ -57,11 +57,21 @@ struct callback_data {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+static bool has_pooling(llama_context * ctx) {
|
|
|
+ switch (llama_pooling_type(ctx)) {
|
|
|
+ case LLAMA_POOLING_TYPE_NONE:
|
|
|
+ case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
|
|
+ return false;
|
|
|
+ default:
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
struct output_data {
|
|
|
float * data_ptr = nullptr;
|
|
|
int data_size = 0;
|
|
|
std::string type_suffix;
|
|
|
- std::vector<float> storage;
|
|
|
+ std::vector<float> embd_norm;
|
|
|
std::string prompt;
|
|
|
std::vector<llama_token> tokens;
|
|
|
|
|
|
@@ -73,24 +83,32 @@ struct output_data {
|
|
|
prompt = params.prompt;
|
|
|
|
|
|
if (params.embedding) {
|
|
|
- const int n_embd = llama_model_n_embd_out(model);
|
|
|
- const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE;
|
|
|
- const int n_embd_count = pooling_enabled ? 1 : tokens.size();
|
|
|
- const int n_embeddings = n_embd * n_embd_count;
|
|
|
-
|
|
|
- float * embeddings;
|
|
|
- if (pooling_enabled) {
|
|
|
- embeddings = llama_get_embeddings_seq(ctx, 0);
|
|
|
- storage.resize(n_embeddings);
|
|
|
- common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize);
|
|
|
- embeddings = storage.data();
|
|
|
- } else {
|
|
|
- embeddings = llama_get_embeddings(ctx);
|
|
|
+ const int n_embd = llama_model_n_embd_out(model);
|
|
|
+ const bool pooling = has_pooling(ctx);
|
|
|
+ const int n_embd_count = pooling ? 1 : tokens.size();
|
|
|
+ const int n_floats = n_embd * n_embd_count;
|
|
|
+
|
|
|
+ float * embd_raw = pooling ? llama_get_embeddings_seq(ctx, 0) : llama_get_embeddings(ctx);
|
|
|
+ if (embd_raw == nullptr) {
|
|
|
+ throw std::runtime_error("failed to get embeddings from the model");
|
|
|
}
|
|
|
|
|
|
- data_ptr = embeddings;
|
|
|
- data_size = n_embeddings;
|
|
|
+ LOG_DBG("pooling_enabled: %s\n", pooling ? "true" : "false");
|
|
|
+ LOG_DBG("n_embd: %d\n", n_embd);
|
|
|
+ LOG_DBG("n_floats: %d\n", n_floats);
|
|
|
+ LOG_DBG("n_embd_count: %d\n", n_embd_count);
|
|
|
+
|
|
|
+ data_ptr = embd_raw;
|
|
|
+ data_size = n_floats;
|
|
|
type_suffix = "-embeddings";
|
|
|
+
|
|
|
+ if (params.embd_normalize >= 0) {
|
|
|
+ embd_norm.resize(n_floats);
|
|
|
+ for (int i = 0; i < n_embd_count; i++) {
|
|
|
+ common_embd_normalize(embd_raw+i*n_embd, embd_norm.data()+i*n_embd, n_embd, params.embd_normalize);
|
|
|
+ }
|
|
|
+ data_ptr = embd_norm.data();
|
|
|
+ }
|
|
|
} else {
|
|
|
const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
|
|
|
const int n_logits = llama_vocab_n_tokens(vocab);
|