1
0

embedding.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. #include "arg.h"
  2. #include "common.h"
  3. #include "log.h"
  4. #include "llama.h"
  5. #include <ctime>
  6. #include <algorithm>
  7. #if defined(_MSC_VER)
  8. #pragma warning(disable: 4244 4267) // possible loss of data
  9. #endif
  // Split `s` on every occurrence of `separator` and return the pieces in order.
  // The piece after the last separator is always appended, so the result is never
  // empty: "a\nb\n" -> {"a", "b", ""} and "" -> {""}.
  // An empty separator cannot delimit anything (std::string::find("") matches at
  // every position and the scan would never advance), so in that case the whole
  // input is returned as a single piece instead of looping forever.
  static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
      std::vector<std::string> lines;

      if (separator.empty()) {
          lines.push_back(s);
          return lines;
      }

      size_t start = 0;
      size_t end = s.find(separator);
      while (end != std::string::npos) {
          lines.push_back(s.substr(start, end - start));
          start = end + separator.length();
          end = s.find(separator, start);
      }
      lines.push_back(s.substr(start)); // Add the last part

      return lines;
  }
  22. static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
  23. size_t n_tokens = tokens.size();
  24. for (size_t i = 0; i < n_tokens; i++) {
  25. common_batch_add(batch, tokens[i], i, { seq_id }, true);
  26. }
  27. }
  28. static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
  29. const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
  30. // clear previous kv_cache values (irrelevant for embeddings)
  31. llama_memory_clear(llama_get_memory(ctx), true);
  32. // run model
  33. LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
  34. if (llama_decode(ctx, batch) < 0) {
  35. LOG_ERR("%s : failed to process\n", __func__);
  36. }
  37. for (int i = 0; i < batch.n_tokens; i++) {
  38. if (!batch.logits[i]) {
  39. continue;
  40. }
  41. const float * embd = nullptr;
  42. int embd_pos = 0;
  43. if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
  44. // try to get token embeddings
  45. embd = llama_get_embeddings_ith(ctx, i);
  46. embd_pos = i;
  47. GGML_ASSERT(embd != NULL && "failed to get token embeddings");
  48. } else {
  49. // try to get sequence embeddings - supported only when pooling_type is not NONE
  50. embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
  51. embd_pos = batch.seq_id[i][0];
  52. GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
  53. }
  54. float * out = output + embd_pos * n_embd;
  55. common_embd_normalize(embd, out, n_embd, embd_norm);
  56. }
  57. }
  58. int main(int argc, char ** argv) {
  59. common_params params;
  60. if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) {
  61. return 1;
  62. }
  63. common_init();
  64. params.embedding = true;
  65. // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
  66. // --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
  67. // in order to support any number of prompts
  68. if (params.n_parallel == 1) {
  69. LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
  70. params.kv_unified = true;
  71. }
  72. // utilize the full context
  73. if (params.n_batch < params.n_ctx) {
  74. LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx);
  75. params.n_batch = params.n_ctx;
  76. }
  77. // for non-causal models, batch size must be equal to ubatch size
  78. if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
  79. params.n_ubatch = params.n_batch;
  80. }
  81. // get max number of sequences per batch
  82. const int n_seq_max = llama_max_parallel_sequences();
  83. llama_backend_init();
  84. llama_numa_init(params.numa);
  85. // load the model
  86. common_init_result llama_init = common_init_from_params(params);
  87. llama_model * model = llama_init.model.get();
  88. llama_context * ctx = llama_init.context.get();
  89. if (model == NULL) {
  90. LOG_ERR("%s: unable to load model\n", __func__);
  91. return 1;
  92. }
  93. const llama_vocab * vocab = llama_model_get_vocab(model);
  94. const int n_ctx_train = llama_model_n_ctx_train(model);
  95. const int n_ctx = llama_n_ctx(ctx);
  96. const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
  97. if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
  98. LOG_ERR("%s: computing embeddings in encoder-decoder models is not supported\n", __func__);
  99. return 1;
  100. }
  101. if (n_ctx > n_ctx_train) {
  102. LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n",
  103. __func__, n_ctx_train, n_ctx);
  104. }
  105. // print system information
  106. {
  107. LOG_INF("\n");
  108. LOG_INF("%s\n", common_params_get_system_info(params).c_str());
  109. }
  110. // split the prompt into lines
  111. std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
  112. // max batch size
  113. const uint64_t n_batch = params.n_batch;
  114. // get added sep and eos token, if any
  115. const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
  116. const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
  117. const char * rerank_prompt = llama_model_chat_template(model, "rerank");
  118. // tokenize the prompts and trim
  119. std::vector<std::vector<int32_t>> inputs;
  120. for (const auto & prompt : prompts) {
  121. std::vector<llama_token> inp;
  122. // split classification pairs and insert expected separator tokens
  123. if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
  124. std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
  125. if (rerank_prompt != nullptr) {
  126. const std::string query = pairs[0];
  127. const std::string doc = pairs[1];
  128. std::string final_prompt = rerank_prompt;
  129. string_replace_all(final_prompt, "{query}" , query);
  130. string_replace_all(final_prompt, "{document}", doc );
  131. inp = common_tokenize(vocab, final_prompt, true, true);
  132. } else {
  133. std::string final_prompt;
  134. for (size_t i = 0; i < pairs.size(); i++) {
  135. final_prompt += pairs[i];
  136. if (i != pairs.size() - 1) {
  137. if (!added_eos_token.empty()) {
  138. final_prompt += added_eos_token;
  139. }
  140. if (!added_sep_token.empty()) {
  141. final_prompt += added_sep_token;
  142. }
  143. }
  144. }
  145. inp = common_tokenize(ctx, final_prompt, true, true);
  146. }
  147. } else {
  148. inp = common_tokenize(ctx, prompt, true, true);
  149. }
  150. if (inp.size() > n_batch) {
  151. LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
  152. __func__, (long long int) inp.size(), (long long int) n_batch);
  153. return 1;
  154. }
  155. inputs.push_back(inp);
  156. }
  157. // check if the last token is SEP/EOS
  158. // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
  159. for (auto & inp : inputs) {
  160. if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
  161. LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
  162. LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
  163. }
  164. }
  165. // tokenization stats
  166. if (params.verbose_prompt) {
  167. for (int i = 0; i < (int) inputs.size(); i++) {
  168. LOG_INF("%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
  169. LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
  170. for (int j = 0; j < (int) inputs[i].size(); j++) {
  171. LOG("%6d -> '%s'\n", inputs[i][j], common_token_to_piece(ctx, inputs[i][j]).c_str());
  172. }
  173. LOG("\n\n");
  174. }
  175. }
  176. // initialize batch
  177. const int n_prompts = prompts.size();
  178. struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
  179. // count number of embeddings
  180. int n_embd_count = 0;
  181. if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
  182. for (int k = 0; k < n_prompts; k++) {
  183. n_embd_count += inputs[k].size();
  184. }
  185. } else {
  186. n_embd_count = n_prompts;
  187. }
  188. // allocate output
  189. const int n_embd = llama_model_n_embd(model);
  190. std::vector<float> embeddings(n_embd_count * n_embd, 0);
  191. float * emb = embeddings.data();
  192. // break into batches
  193. int e = 0; // number of embeddings already stored
  194. int s = 0; // number of prompts in current batch
  195. for (int k = 0; k < n_prompts; k++) {
  196. // clamp to n_batch tokens
  197. auto & inp = inputs[k];
  198. const uint64_t n_toks = inp.size();
  199. // encode if at capacity
  200. if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
  201. float * out = emb + e * n_embd;
  202. batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
  203. e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
  204. s = 0;
  205. common_batch_clear(batch);
  206. }
  207. // add to batch
  208. batch_add_seq(batch, inp, s);
  209. s += 1;
  210. }
  211. // final batch
  212. float * out = emb + e * n_embd;
  213. batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
  214. if (params.embd_out.empty()) {
  215. LOG("\n");
  216. if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
  217. for (int j = 0; j < n_embd_count; j++) {
  218. LOG("embedding %d: ", j);
  219. for (int i = 0; i < std::min(3, n_embd); i++) {
  220. if (params.embd_normalize == 0) {
  221. LOG("%6.0f ", emb[j * n_embd + i]);
  222. } else {
  223. LOG("%9.6f ", emb[j * n_embd + i]);
  224. }
  225. }
  226. LOG(" ... ");
  227. for (int i = n_embd - 3; i < n_embd; i++) {
  228. if (params.embd_normalize == 0) {
  229. LOG("%6.0f ", emb[j * n_embd + i]);
  230. } else {
  231. LOG("%9.6f ", emb[j * n_embd + i]);
  232. }
  233. }
  234. LOG("\n");
  235. }
  236. } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
  237. const uint32_t n_cls_out = llama_model_n_cls_out(model);
  238. std::vector<std::string> cls_out_labels;
  239. for (uint32_t i = 0; i < n_cls_out; i++) {
  240. const char * label = llama_model_cls_label(model, i);
  241. const std::string label_i(label == nullptr ? "" : label);
  242. cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i);
  243. }
  244. for (int j = 0; j < n_embd_count; j++) {
  245. for (uint32_t i = 0; i < n_cls_out; i++) {
  246. // NOTE: if you change this log - update the tests in ci/run.sh
  247. if (n_cls_out == 1) {
  248. LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
  249. } else {
  250. LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
  251. }
  252. }
  253. }
  254. } else {
  255. // print the first part of the embeddings or for a single prompt, the full embedding
  256. for (int j = 0; j < n_prompts; j++) {
  257. LOG("embedding %d: ", j);
  258. for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
  259. if (params.embd_normalize == 0) {
  260. LOG("%6.0f ", emb[j * n_embd + i]);
  261. } else {
  262. LOG("%9.6f ", emb[j * n_embd + i]);
  263. }
  264. }
  265. LOG("\n");
  266. }
  267. // print cosine similarity matrix
  268. if (n_prompts > 1) {
  269. LOG("\n");
  270. LOG("cosine similarity matrix:\n\n");
  271. for (int i = 0; i < n_prompts; i++) {
  272. LOG("%6.6s ", prompts[i].c_str());
  273. }
  274. LOG("\n");
  275. for (int i = 0; i < n_prompts; i++) {
  276. for (int j = 0; j < n_prompts; j++) {
  277. float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
  278. LOG("%6.2f ", sim);
  279. }
  280. LOG("%1.10s", prompts[i].c_str());
  281. LOG("\n");
  282. }
  283. }
  284. }
  285. }
  286. if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
  287. const bool notArray = params.embd_out != "array";
  288. LOG(notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "[");
  289. for (int j = 0;;) { // at least one iteration (one prompt)
  290. if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
  291. LOG("[");
  292. for (int i = 0;;) { // at least one iteration (n_embd > 0)
  293. LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
  294. i++;
  295. if (i < n_embd) LOG(","); else break;
  296. }
  297. LOG(notArray ? "]\n }" : "]");
  298. j++;
  299. if (j < n_embd_count) LOG(notArray ? ",\n" : ","); else break;
  300. }
  301. LOG(notArray ? "\n ]" : "]\n");
  302. if (params.embd_out == "json+" && n_prompts > 1) {
  303. LOG(",\n \"cosineSimilarity\": [\n");
  304. for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
  305. LOG(" [");
  306. for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
  307. float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
  308. LOG("%6.2f", sim);
  309. j++;
  310. if (j < n_embd_count) LOG(", "); else break;
  311. }
  312. LOG(" ]");
  313. i++;
  314. if (i < n_embd_count) LOG(",\n"); else break;
  315. }
  316. LOG("\n ]");
  317. }
  318. if (notArray) LOG("\n}\n");
  319. }
  320. LOG("\n");
  321. llama_perf_context_print(ctx);
  322. // clean up
  323. llama_batch_free(batch);
  324. llama_backend_free();
  325. return 0;
  326. }