#include "ggml.h"
#include "gguf.h"
#include "arg.h"
#include "common.h"
#include "llama.h"
#include "pca.hpp"
#include "mean.hpp"

#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif

#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif

#include <algorithm>
#include <climits>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>
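
//
// cvector-generator: builds a control-vector GGUF from pairs of positive/negative prompts.
//
// Each prompt pair is evaluated through the model and the per-layer hidden states
// ("l_out") are captured via an eval callback. The per-token differences
// (positive - negative) are accumulated over all pairs, reduced to one direction per
// layer with either PCA or a simple mean, and written out as a GGUF file that can be
// applied at inference time (e.g. via llama.cpp's --control-vector option).
//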

//////////////////////////////////////////////////
// utils

template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
    std::string ret;
    for (; begin != end; ++begin) {
        ret += common_token_to_piece(ctx, *begin);
    }
    return ret;
}

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n CPU only: %s -m ./llama-3.Q4_K_M.gguf\n", argv[0]);
    printf("\n with GPU: %s -m ./llama-3.Q4_K_M.gguf -ngl 99\n", argv[0]);
    printf("\n advanced: %s -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100\n", argv[0]);
    printf("\n using mean: %s -m ./llama-3.Q4_K_M.gguf --method mean\n", argv[0]);
    printf("\n");
}

//////////////////////////////////////////////////

// cb_eval is reused for each pair of positive - negative prompts
struct callback_data {
    ggml_context * ctx_ggml = nullptr; // holds v_pos, v_neg, v_diff_filtered

    int n_layers = 0;
    int n_tokens = 0;
    bool is_eval_pos = true;

    // each element of the vector corresponds to one layer
    std::vector<struct ggml_tensor *> v_pos;           // vector of matrices of size [n_embd, n_tokens]
    std::vector<struct ggml_tensor *> v_neg;           // vector of matrices of size [n_embd, n_tokens]
    std::vector<struct ggml_tensor *> v_diff_filtered; // vector of matrices of size [n_embd, n_nonzero_rows]. NOTE: n_nonzero_rows may be different for each layer

    // save a tensor into either v_pos or v_neg (decided by is_eval_pos)
    void save_tensor_for_layer(struct ggml_tensor * t) {
        GGML_ASSERT(t->type == GGML_TYPE_F32);

        if (ctx_ggml == nullptr) {
            // alloc a new ctx_ggml if needed
            struct ggml_init_params params_ggml = {
                /*.mem_size   =*/ ggml_tensor_overhead() * n_layers * 3u,
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };
            ctx_ggml = ggml_init(params_ggml);
        }

        // copy tensor data
        auto n_bytes = ggml_nbytes(t);
        struct ggml_tensor * t_layer = ggml_new_tensor_2d(ctx_ggml, t->type, t->ne[0], t->ne[1]);
        t_layer->data = malloc(n_bytes); // TODO @ngxson : get rid of this malloc somehow
        ggml_backend_tensor_get(t, t_layer->data, 0, n_bytes);
        ggml_set_name(t_layer, ggml_get_name(t));
        //print_debug_tensor(t_layer);

        if (is_eval_pos) {
            v_pos.push_back(t_layer);
        } else {
            v_neg.push_back(t_layer);
        }
    }

    // calculate diff (v_pos - v_neg) and place the result back into v_pos
    // all-zero rows in the diff tensor are also removed
    // NOTE: the final layer is ignored, so we only have (n_layers - 1) tensors to process
    std::vector<struct ggml_tensor *> calc_diff() {
        for (size_t il = 0; il < v_pos.size(); il++) {
            float * a = (float *) v_pos[il]->data;
            float * b = (float *) v_neg[il]->data;
            size_t n_elem = ggml_nelements(v_pos[il]);
            for (size_t j = 0; j < n_elem; j++) {
                a[j] -= b[j];
            }
            //print_debug_tensor(v_pos[il]);
            auto diff_filtered = filter_nonzero_rows(v_pos[il]);
            v_diff_filtered.push_back(diff_filtered);
        }
        return v_diff_filtered; // for convenience, we return the resulting std::vector
    }

    // delete zero rows from a given 2D tensor
    struct ggml_tensor * filter_nonzero_rows(struct ggml_tensor * a) {
        //printf("filter_nonzero_rows\n");
        auto is_row_all_zeros = [](struct ggml_tensor * t, int row, float eps) -> bool {
            // check if the given row contains only zero elements
            int n_cols = t->ne[0]; // hint: should be equal to n_embd
            for (int col = 0; col < n_cols; ++col) {
                if (ggml_get_f32_nd(t, col, row, 0, 0) > eps) {
                    return false;
                }
            }
            return true;
        };

        std::vector<int> rows_to_copy; // the indices of the non-zero rows (to be copied into diff_filtered)
        for (int i_row = 0; i_row < a->ne[1]; i_row++) {
            if (!is_row_all_zeros(a, i_row, 1e-6)) {
                rows_to_copy.push_back(i_row);
            }
        }

        // get "n_nonzero_rows" for the output "diff_filtered"
        int n_nonzero_rows = rows_to_copy.size();
        //printf("n_nonzero_rows: %d\n", n_nonzero_rows);
        int n_embd = a->ne[0];
        GGML_ASSERT(n_nonzero_rows > 0);

        // diff_filtered: [n_embd, n_nonzero_rows]
        struct ggml_tensor * diff_filtered = ggml_new_tensor_2d(
            ctx_ggml, GGML_TYPE_F32, n_embd, n_nonzero_rows);
        ggml_format_name(diff_filtered, "diff_filtered_%s", a->name);
        diff_filtered->data = malloc(ggml_nbytes(diff_filtered));

        // copy non-zero rows
        for (int dest_row = 0; dest_row < n_nonzero_rows; dest_row++) {
            int src_row = rows_to_copy[dest_row];
            for (int i = 0; i < n_embd; i++) {
                float src_elem = ggml_get_f32_nd(a, i, src_row, 0, 0);
                ggml_set_f32_nd(diff_filtered, i, dest_row, 0, 0, src_elem);
            }
        }

        //print_debug_tensor(diff_filtered);

        return diff_filtered;
    }

    // we don't implement a destructor, because we want to reuse callback_data. we just want to free the tensors
    void reset() {
        for (auto ptr : v_pos) free(ptr->data);
        for (auto ptr : v_neg) free(ptr->data);
        for (auto ptr : v_diff_filtered) free(ptr->data);
        v_pos.clear();
        v_neg.clear();
        v_diff_filtered.clear();
        if (ctx_ggml) {
            ggml_free(ctx_ggml);
        }
        ctx_ggml = nullptr;
    }
};

/**
 * process_ctx is used to store the ggml context for pre/post-processing the diff vectors
 * in short, input => v_diff and output => v_final
 */
struct train_context {
    ggml_context * ctx_ggml;
    int n_embd;
    int n_layers;

    /* pairs of prompts to be used for generating the final vector */
    std::vector<std::string> positive_entries;
    std::vector<std::string> negative_entries;

    // each element of the vector corresponds to one layer
    // NOTE: the last layer is discarded. therefore, we will have (n_layers - 1) elements here
    // NOTE (2): v_diff is transposed from v_diff_tmp
    std::vector<struct ggml_tensor *> v_diff;  // vector of matrices of size [m, n_embd] where m ~ n_tokens * n_completions (v_diff contains no zero-rows)
    std::vector<struct ggml_tensor *> v_final; // vector of vectors of size [n_embd] to be written to file

    // to easily re-alloc when concatenating v_diff, we temporarily store v_diff in a vector instead of a tensor
    // v_diff_tmp will get converted into v_diff later on
    std::vector<std::vector<uint8_t>> v_diff_tmp;

    train_context(int n_embd_, int n_layers_) {
        n_embd = n_embd_;
        n_layers = n_layers_;
        struct ggml_init_params params_ggml = {
            /*.mem_size   =*/ ggml_tensor_overhead() * (n_layers - 1) * 2u,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,
        };
        ctx_ggml = ggml_init(params_ggml);
        for (int il = 0; il < n_layers - 1; il++) {
            std::vector<uint8_t> empty;
            v_diff_tmp.push_back(empty);
            auto t = ggml_new_tensor_1d(ctx_ggml, GGML_TYPE_F32, n_embd);
            t->data = malloc(ggml_nbytes(t)); // TODO: get rid of malloc if possible
            v_final.push_back(t);
        }
    }

    // add new rows into the existing tensors in v_diff_tmp
    void concat_diff_tmp(const std::vector<struct ggml_tensor *> & diff_filtered) {
        GGML_ASSERT((int) diff_filtered.size() == n_layers - 1);
        for (int il = 0; il < n_layers - 1; il++) {
            auto t = diff_filtered[il];
            auto & diff_tmp = v_diff_tmp[il];
            size_t curr_size = diff_tmp.size();
            diff_tmp.resize(curr_size + ggml_nbytes(t));
            memcpy(diff_tmp.data() + curr_size, t->data, ggml_nbytes(t));
        }
    }

    // build the v_diff tensors from v_diff_tmp (v_diff needs to be transposed)
    // TODO @ngxson : maybe add an option NOT to transpose v_diff; will be useful for the "mean" method
    void build_v_diff(bool transpose) {
        printf("build_v_diff\n");
        for (int il = 0; il < n_layers - 1; il++) {
            auto & diff_tmp = v_diff_tmp[il];
            int n_elem = diff_tmp.size() / sizeof(float);
            GGML_ASSERT(n_elem % n_embd == 0);
            int n_rows = n_elem / n_embd;
            struct ggml_tensor * diff = transpose
                ? ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_rows, n_embd)
                : ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_embd, n_rows);
            ggml_set_name(diff, (std::string("diff_") + std::to_string(il)).c_str());
            diff->data = malloc(ggml_nbytes(diff)); // TODO: get rid of this malloc if possible
            if (transpose) {
                // copy data & transpose
                float * arr = (float *) diff_tmp.data();
                for (int ir = 0; ir < n_rows; ++ir) {
                    for (int ic = 0; ic < n_embd; ++ic) {
                        float f = arr[ir*n_embd + ic];
                        ggml_set_f32_nd(diff, ir, ic, 0, 0, f);
                    }
                }
            } else {
                // only copy
                memcpy(diff->data, diff_tmp.data(), ggml_nbytes(diff));
            }
            v_diff.push_back(diff);
            print_debug_tensor(diff);
            // free the memory of diff_tmp
            diff_tmp.resize(0);
        }
    }

    ~train_context() {
        for (auto ptr : v_final) free(ptr->data);
        for (auto ptr : v_diff) free(ptr->data);
        // no need to free v_diff_tmp, since we didn't use malloc for it
        ggml_free(ctx_ggml);
    }
};
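
// tokenized_prompt holds one tokenized positive/negative pair; the shorter sequence is
// padded so both have the same length (max_seq_len), which keeps the per-token
// hidden-state diff well-defined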
struct tokenized_prompt {
    std::vector<llama_token> tokens_pos;
    std::vector<llama_token> tokens_neg;
    size_t max_seq_len;

    tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) {
        const llama_model * model = llama_get_model(ctx);
        const llama_vocab * vocab = llama_model_get_vocab(model);
        const bool add_bos = llama_vocab_get_add_bos(vocab);
        tokens_pos = common_tokenize(ctx, pos, add_bos, true);
        tokens_neg = common_tokenize(ctx, neg, add_bos, true);
        max_seq_len = std::max(tokens_pos.size(), tokens_neg.size());
        padding_seq(ctx, tokens_pos, max_seq_len);
        padding_seq(ctx, tokens_neg, max_seq_len);
    }

    void padding_seq(llama_context * ctx, std::vector<llama_token> & tokens, size_t len) {
        // TODO: customize padding token
        std::vector<llama_token> pad_tokens = common_tokenize(ctx, " ", false);
        llama_token pad_tok = pad_tokens.back();
        while (tokens.size() < len) {
            tokens.push_back(pad_tok);
        }
    }
};

//////////////////////////////////////////////////

template <typename T>
static std::string to_string(const T & val) {
    std::stringstream ss;
    ss << val;
    return ss.str();
}
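
// load one prompt per line from the given file; empty lines are skipped when
// skip_empty_lines is set, and escape sequences (e.g. "\n") are expanded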
static std::vector<std::string> ctrlvec_load_prompt_file(std::string path, bool skip_empty_lines) {
    std::vector<std::string> output;
    std::ifstream file(path);
    if (!file.is_open()) {
        fprintf(stderr, "error: unable to open file: %s\n", path.c_str());
        exit(1);
    }
    std::string line;
    while (std::getline(file, line)) {
        bool is_skip = skip_empty_lines && line.empty();
        if (!is_skip) {
            string_process_escapes(line);
            output.push_back(line);
        }
    }
    file.close();
    return output;
}

//////////////////////////////////////////////////
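
// cb_eval is invoked by the backend scheduler for each graph node: first with
// ask == true, where we return whether we want this tensor's data, and then with
// ask == false, once the data can be read back. We only capture the per-layer output
// tensors ("l_out") whose second dimension matches the current number of tokens.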
static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
    auto * cb_data = (callback_data *) user_data;
    static const char * l_out_name = "l_out";
    const bool is_l_out = strncmp(t->name, l_out_name, strlen(l_out_name)) == 0;

    if (ask) {
        return is_l_out;
    }

    if (!is_l_out || t->ne[1] != cb_data->n_tokens) {
        return true;
    }

    // save the tensor to the current context
    cb_data->save_tensor_for_layer(t);
    return true;
}
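
// run a full forward pass over the given tokens; the KV cache is cleared first so the
// prompt is evaluated from position 0 and cb_eval sees the complete l_out tensors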
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_kv_cache_clear(ctx);
    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return false;
    }
    return true;
}
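
// write the per-layer control vectors to a GGUF file, together with the architecture
// ("controlvector"), the model hint (source model architecture) and the layer count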
static void export_gguf(const std::vector<struct ggml_tensor *> & v_ctrl, const std::string fname, const std::string model_hint) {
    struct gguf_context * ctx = gguf_init_empty();

    const std::string arch = "controlvector";
    gguf_set_val_str(ctx, "general.architecture", arch.c_str());
    gguf_set_val_str(ctx, (arch + ".model_hint").c_str(), model_hint.c_str());
    gguf_set_val_i32(ctx, (arch + ".layer_count").c_str(), v_ctrl.size());

    for (size_t i = 0; i < v_ctrl.size(); ++i) {
        gguf_add_tensor(ctx, v_ctrl[i]);
        print_debug_tensor(v_ctrl[i]);
        printf("Added tensor: %s\n", v_ctrl[i]->name);
    }

    printf("%s: writing file...\n", __func__);
    gguf_write_to_file(ctx, fname.c_str(), false);
    printf("%s: wrote file '%s'\n", __func__, fname.c_str());
    gguf_free(ctx);
}

/**
 * Load the positive and negative prompt files and store their lines as training entries.
 */
static int prepare_entries(common_params & params, train_context & ctx_train) {
    // load prompts
    std::vector<std::string> positive_prompts = ctrlvec_load_prompt_file(params.cvector_positive_file, true);
    std::vector<std::string> negative_prompts = ctrlvec_load_prompt_file(params.cvector_negative_file, true);
    if (positive_prompts.size() != negative_prompts.size()) {
        fprintf(stderr, "number of positive and negative prompts must be equal\n");
        return 1;
    }
    if (positive_prompts.empty()) {
        fprintf(stderr, "must provide at least one prompt pair\n");
        return 1;
    }
    ctx_train.positive_entries = positive_prompts;
    ctx_train.negative_entries = negative_prompts;
    return 0;
}

int main(int argc, char ** argv) {
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage)) {
        return 1;
    }

    if (params.n_pca_iterations % params.n_pca_batch != 0) {
        fprintf(stderr, "PCA iterations must be a multiple of the PCA batch size\n");
        return 1;
    }

    callback_data cb_data;

    // pass the callback to the backend scheduler
    // it will be executed for each node during the graph computation
    params.cb_eval = cb_eval;
    params.cb_eval_user_data = &cb_data;
    params.warmup = false;

    print_build_info();
    llama_backend_init();
    llama_numa_init(params.numa);

    // load the model to get hparams
    common_init_result llama_init = common_init_from_params(params);

    llama_model * model = llama_init.model.get();
    llama_context * ctx = llama_init.context.get();

    // int n_ctx = llama_n_ctx(ctx);
    int n_layers = llama_model_n_layer(model);
    int n_embd = llama_model_n_embd(model);

    // get model hint param (a.k.a. model arch name)
    char model_hint[128];
    llama_model_meta_val_str(model, "general.architecture", model_hint, 128);

    // init train_context
    train_context ctx_train(n_embd, n_layers);

    // load and prepare entries for training
    prepare_entries(params, ctx_train);

    // we have to pretokenize everything, because otherwise we don't know how much overhead to allocate for ctx_diffs_wrapped
    std::vector<tokenized_prompt> tokenized_prompts;
    size_t n_total_tokens = 0;
    for (size_t i = 0; i < ctx_train.positive_entries.size(); ++i) {
        tokenized_prompt t(ctx, ctx_train.positive_entries[i], ctx_train.negative_entries[i]);
        n_total_tokens += 2 * t.max_seq_len;
        tokenized_prompts.push_back(std::move(t));
    }

    std::cout << "n_total_tokens: " << n_total_tokens << std::endl;
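
    // evaluate each prompt pair: the positive sequence first, then the negative one;
    // cb_eval captures the per-layer l_out tensors for whichever side is being evaluated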
    for (size_t i = 0; i < ctx_train.positive_entries.size(); ++i) {
        bool success = false;
        tokenized_prompt t = tokenized_prompts[i];
        cb_data.n_layers = n_layers;
        cb_data.n_tokens = t.max_seq_len;

        printf("Evaluating prompt[%d/%d]: \"%s\" - \"%s\" (%d tokens)\n",
            (int) i+1, (int) ctx_train.positive_entries.size(),
            tokens_to_str(ctx, t.tokens_pos.cbegin(), t.tokens_pos.cend()).c_str(),
            tokens_to_str(ctx, t.tokens_neg.cbegin(), t.tokens_neg.cend()).c_str(),
            (int) t.max_seq_len);

        cb_data.is_eval_pos = true;
        success = get_hidden_layers(ctx, t.tokens_pos);
        if (!success) break;

        cb_data.is_eval_pos = false;
        success = get_hidden_layers(ctx, t.tokens_neg);
        if (!success) break;

        // calculate diff and remove all zero rows
        auto v_diff_filtered = cb_data.calc_diff();

        // save & concat the filtered v_diff to ctx_train
        ctx_train.concat_diff_tmp(v_diff_filtered);

        // reset for the next iteration
        cb_data.reset();
    }

    // done with the model, we can now free it to gain back some memory
    printf("Done evaluating prompts, unloading model...\n");

    bool use_pca = params.cvector_dimre_method == DIMRE_METHOD_PCA;

    // prepare ctx_train for PCA
    ctx_train.build_v_diff(use_pca);
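
    // reduce each layer's diff matrix to a single direction vector: PCA extracts the
    // dominant principal component of the diffs, while the mean method simply averages
    // them; both write their result into ctx_train.v_final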
    if (use_pca) {
        // run PCA
        PCA::pca_params pca_params;
        pca_params.n_threads    = params.cpuparams.n_threads;
        pca_params.n_batch      = params.n_pca_batch;
        pca_params.n_iterations = params.n_pca_iterations;
        PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final);
    } else {
        // run mean
        mean::run(ctx_train.v_diff, ctx_train.v_final);
    }

    // write output vectors to gguf
    export_gguf(ctx_train.v_final, params.cvector_outfile, model_hint);

    llama_backend_free();

    return 0;
}