|
|
@@ -7,13 +7,19 @@
|
|
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
|
|
#endif
|
|
|
|
|
|
-static std::vector<std::string> split_lines(const std::string & s) {
|
|
|
- std::string line;
|
|
|
+static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
|
|
|
std::vector<std::string> lines;
|
|
|
- std::stringstream ss(s);
|
|
|
- while (std::getline(ss, line)) {
|
|
|
- lines.push_back(line);
|
|
|
+ size_t start = 0;
|
|
|
+ size_t end = s.find(separator);
|
|
|
+
|
|
|
+ while (end != std::string::npos) {
|
|
|
+ lines.push_back(s.substr(start, end - start));
|
|
|
+ start = end + separator.length();
|
|
|
+ end = s.find(separator, start);
|
|
|
}
|
|
|
+
|
|
|
+ lines.push_back(s.substr(start)); // Add the last part
|
|
|
+
|
|
|
return lines;
|
|
|
}
|
|
|
|
|
|
@@ -24,7 +30,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
|
|
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
|
|
|
// clear previous kv_cache values (irrelevant for embeddings)
|
|
|
llama_kv_cache_clear(ctx);
|
|
|
|
|
|
@@ -44,13 +50,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
|
|
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
|
|
|
|
|
|
float * out = output + batch.seq_id[i][0] * n_embd;
|
|
|
- //TODO: I would also add a parameter here to enable normalization or not.
|
|
|
- /*fprintf(stdout, "unnormalized_embedding:");
|
|
|
- for (int hh = 0; hh < n_embd; hh++) {
|
|
|
- fprintf(stdout, "%9.6f ", embd[hh]);
|
|
|
- }
|
|
|
- fprintf(stdout, "\n");*/
|
|
|
- llama_embd_normalize(embd, out, n_embd);
|
|
|
+ llama_embd_normalize(embd, out, n_embd, embd_norm);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -110,7 +110,7 @@ int main(int argc, char ** argv) {
|
|
|
}
|
|
|
|
|
|
// split the prompt into lines
|
|
|
- std::vector<std::string> prompts = split_lines(params.prompt);
|
|
|
+ std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
|
|
|
|
|
|
// max batch size
|
|
|
const uint64_t n_batch = params.n_batch;
|
|
|
@@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
|
|
|
// encode if at capacity
|
|
|
if (batch.n_tokens + n_toks > n_batch) {
|
|
|
float * out = emb + p * n_embd;
|
|
|
- batch_decode(ctx, batch, out, s, n_embd);
|
|
|
+ batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
|
|
|
llama_batch_clear(batch);
|
|
|
p += s;
|
|
|
s = 0;
|
|
|
@@ -183,29 +183,78 @@ int main(int argc, char ** argv) {
|
|
|
|
|
|
// final batch
|
|
|
float * out = emb + p * n_embd;
|
|
|
- batch_decode(ctx, batch, out, s, n_embd);
|
|
|
-
|
|
|
- // print the first part of the embeddings or for a single prompt, the full embedding
|
|
|
- fprintf(stdout, "\n");
|
|
|
- for (int j = 0; j < n_prompts; j++) {
|
|
|
- fprintf(stdout, "embedding %d: ", j);
|
|
|
- for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
|
|
|
- fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
|
|
|
- }
|
|
|
- fprintf(stdout, "\n");
|
|
|
- }
|
|
|
+ batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
|
|
|
|
|
|
- // print cosine similarity matrix
|
|
|
- if (n_prompts > 1) {
|
|
|
+ if (params.embd_out.empty()) {
|
|
|
+ // print the first part of the embeddings or for a single prompt, the full embedding
|
|
|
fprintf(stdout, "\n");
|
|
|
- printf("cosine similarity matrix:\n\n");
|
|
|
- for (int i = 0; i < n_prompts; i++) {
|
|
|
- for (int j = 0; j < n_prompts; j++) {
|
|
|
- float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
|
|
- fprintf(stdout, "%6.2f ", sim);
|
|
|
+ for (int j = 0; j < n_prompts; j++) {
|
|
|
+ fprintf(stdout, "embedding %d: ", j);
|
|
|
+ for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
|
|
|
+ if (params.embd_normalize == 0) {
|
|
|
+ fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
|
|
|
+ } else {
|
|
|
+ fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ fprintf(stdout, "\n");
|
|
|
+ }
|
|
|
+
|
|
|
+ // print cosine similarity matrix
|
|
|
+ if (n_prompts > 1) {
|
|
|
+ fprintf(stdout, "\n");
|
|
|
+ printf("cosine similarity matrix:\n\n");
|
|
|
+ for (int i = 0; i < n_prompts; i++) {
|
|
|
+ fprintf(stdout, "%6.6s ", prompts[i].c_str());
|
|
|
}
|
|
|
fprintf(stdout, "\n");
|
|
|
+ for (int i = 0; i < n_prompts; i++) {
|
|
|
+ for (int j = 0; j < n_prompts; j++) {
|
|
|
+ float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
|
|
+ fprintf(stdout, "%6.2f ", sim);
|
|
|
+ }
|
|
|
+ fprintf(stdout, "%1.10s", prompts[i].c_str());
|
|
|
+ fprintf(stdout, "\n");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
|
|
|
+ const bool notArray = params.embd_out != "array";
|
|
|
+
|
|
|
+ fprintf(stdout, notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "[");
|
|
|
+ for (int j = 0;;) { // at least one iteration (one prompt)
|
|
|
+ if (notArray) fprintf(stdout, " {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
|
|
|
+ fprintf(stdout, "[");
|
|
|
+ for (int i = 0;;) { // at least one iteration (n_embd > 0)
|
|
|
+ fprintf(stdout, params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
|
|
|
+ i++;
|
|
|
+ if (i < n_embd) fprintf(stdout, ","); else break;
|
|
|
+ }
|
|
|
+ fprintf(stdout, notArray ? "]\n }" : "]");
|
|
|
+ j++;
|
|
|
+ if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break;
|
|
|
}
|
|
|
+ fprintf(stdout, notArray ? "\n ]" : "]\n");
|
|
|
+
|
|
|
+ if (params.embd_out == "json+" && n_prompts > 1) {
|
|
|
+ fprintf(stdout, ",\n \"cosineSimilarity\": [\n");
|
|
|
+ for (int i = 0;;) { // at least two iteration (n_prompts > 1)
|
|
|
+ fprintf(stdout, " [");
|
|
|
+ for (int j = 0;;) { // at least two iteration (n_prompts > 1)
|
|
|
+ float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
|
|
+ fprintf(stdout, "%6.2f", sim);
|
|
|
+ j++;
|
|
|
+ if (j < n_prompts) fprintf(stdout, ", "); else break;
|
|
|
+ }
|
|
|
+ fprintf(stdout, " ]");
|
|
|
+ i++;
|
|
|
+ if (i < n_prompts) fprintf(stdout, ",\n"); else break;
|
|
|
+ }
|
|
|
+ fprintf(stdout, "\n ]");
|
|
|
+ }
|
|
|
+
|
|
|
+ if (notArray) fprintf(stdout, "\n}\n");
|
|
|
}
|
|
|
|
|
|
// clean up
|