|
|
@@ -33,7 +33,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
|
|
|
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) {
|
|
|
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
|
|
|
|
|
// clear previous kv_cache values (irrelevant for embeddings)
|
|
|
@@ -65,8 +65,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
|
|
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
|
|
|
}
|
|
|
|
|
|
- float * out = output + embd_pos * n_embd;
|
|
|
- common_embd_normalize(embd, out, n_embd, embd_norm);
|
|
|
+ float * out = output + embd_pos * n_embd_out;
|
|
|
+ common_embd_normalize(embd, out, n_embd_out, embd_norm);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -252,8 +252,8 @@ int main(int argc, char ** argv) {
|
|
|
}
|
|
|
|
|
|
// allocate output
|
|
|
- const int n_embd = llama_model_n_embd(model);
|
|
|
- std::vector<float> embeddings(n_embd_count * n_embd, 0);
|
|
|
+ const int n_embd_out = llama_model_n_embd_out(model);
|
|
|
+ std::vector<float> embeddings(n_embd_count * n_embd_out, 0);
|
|
|
float * emb = embeddings.data();
|
|
|
|
|
|
// break into batches
|
|
|
@@ -267,8 +267,8 @@ int main(int argc, char ** argv) {
|
|
|
|
|
|
// encode if at capacity
|
|
|
if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
|
|
|
- float * out = emb + e * n_embd;
|
|
|
- batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
|
|
|
+ float * out = emb + e * n_embd_out;
|
|
|
+ batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
|
|
|
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
|
|
|
s = 0;
|
|
|
common_batch_clear(batch);
|
|
|
@@ -280,8 +280,8 @@ int main(int argc, char ** argv) {
|
|
|
}
|
|
|
|
|
|
// final batch
|
|
|
- float * out = emb + e * n_embd;
|
|
|
- batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
|
|
|
+ float * out = emb + e * n_embd_out;
|
|
|
+ batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
|
|
|
|
|
|
if (params.embd_out.empty()) {
|
|
|
LOG("\n");
|
|
|
@@ -289,19 +289,19 @@ int main(int argc, char ** argv) {
|
|
|
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
|
|
for (int j = 0; j < n_embd_count; j++) {
|
|
|
LOG("embedding %d: ", j);
|
|
|
- for (int i = 0; i < std::min(3, n_embd); i++) {
|
|
|
+ for (int i = 0; i < std::min(3, n_embd_out); i++) {
|
|
|
if (params.embd_normalize == 0) {
|
|
|
- LOG("%6.0f ", emb[j * n_embd + i]);
|
|
|
+ LOG("%6.0f ", emb[j * n_embd_out + i]);
|
|
|
} else {
|
|
|
- LOG("%9.6f ", emb[j * n_embd + i]);
|
|
|
+ LOG("%9.6f ", emb[j * n_embd_out + i]);
|
|
|
}
|
|
|
}
|
|
|
LOG(" ... ");
|
|
|
- for (int i = n_embd - 3; i < n_embd; i++) {
|
|
|
+ for (int i = n_embd_out - 3; i < n_embd_out; i++) {
|
|
|
if (params.embd_normalize == 0) {
|
|
|
- LOG("%6.0f ", emb[j * n_embd + i]);
|
|
|
+ LOG("%6.0f ", emb[j * n_embd_out + i]);
|
|
|
} else {
|
|
|
- LOG("%9.6f ", emb[j * n_embd + i]);
|
|
|
+ LOG("%9.6f ", emb[j * n_embd_out + i]);
|
|
|
}
|
|
|
}
|
|
|
LOG("\n");
|
|
|
@@ -320,9 +320,9 @@ int main(int argc, char ** argv) {
|
|
|
for (uint32_t i = 0; i < n_cls_out; i++) {
|
|
|
// NOTE: if you change this log - update the tests in ci/run.sh
|
|
|
if (n_cls_out == 1) {
|
|
|
- LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
|
|
|
+ LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]);
|
|
|
} else {
|
|
|
- LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
|
|
|
+ LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str());
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -330,11 +330,11 @@ int main(int argc, char ** argv) {
|
|
|
// print the first part of the embeddings or for a single prompt, the full embedding
|
|
|
for (int j = 0; j < n_prompts; j++) {
|
|
|
LOG("embedding %d: ", j);
|
|
|
- for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
|
|
|
+ for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) {
|
|
|
if (params.embd_normalize == 0) {
|
|
|
- LOG("%6.0f ", emb[j * n_embd + i]);
|
|
|
+ LOG("%6.0f ", emb[j * n_embd_out + i]);
|
|
|
} else {
|
|
|
- LOG("%9.6f ", emb[j * n_embd + i]);
|
|
|
+ LOG("%9.6f ", emb[j * n_embd_out + i]);
|
|
|
}
|
|
|
}
|
|
|
LOG("\n");
|
|
|
@@ -350,7 +350,7 @@ int main(int argc, char ** argv) {
|
|
|
LOG("\n");
|
|
|
for (int i = 0; i < n_prompts; i++) {
|
|
|
for (int j = 0; j < n_prompts; j++) {
|
|
|
- float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
|
|
+ float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
|
|
|
LOG("%6.2f ", sim);
|
|
|
}
|
|
|
LOG("%1.10s", prompts[i].c_str());
|
|
|
@@ -368,9 +368,9 @@ int main(int argc, char ** argv) {
|
|
|
if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
|
|
|
LOG("[");
|
|
|
for (int i = 0;;) { // at least one iteration (n_embd > 0)
|
|
|
- LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
|
|
|
+ LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]);
|
|
|
i++;
|
|
|
- if (i < n_embd) LOG(","); else break;
|
|
|
+ if (i < n_embd_out) LOG(","); else break;
|
|
|
}
|
|
|
LOG(notArray ? "]\n }" : "]");
|
|
|
j++;
|
|
|
@@ -383,7 +383,7 @@ int main(int argc, char ** argv) {
|
|
|
for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
|
|
|
LOG(" [");
|
|
|
for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
|
|
|
- float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
|
|
+ float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
|
|
|
LOG("%6.2f", sim);
|
|
|
j++;
|
|
|
if (j < n_embd_count) LOG(", "); else break;
|
|
|
@@ -397,7 +397,7 @@ int main(int argc, char ** argv) {
|
|
|
|
|
|
if (notArray) LOG("\n}\n");
|
|
|
} else if (params.embd_out == "raw") {
|
|
|
- print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
|
|
|
+ print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize);
|
|
|
}
|
|
|
|
|
|
LOG("\n");
|