// diffusion-cli.cpp
#include "arg.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
#include "log.h"

#include <limits.h>

#include <string>
#include <vector>
#include <algorithm>
#include <cmath>
#include <limits>
#include <random>
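
// Called after every diffusion step with the current state of the token buffer;
// return false to stop generation early.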
typedef bool (*diffusion_step_callback_t)(int32_t step,
                                          int32_t total_steps,
                                          const llama_token * tokens,
                                          int32_t n_tokens,
                                          void * user_data);
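
// Strategies for deciding which masked positions to unmask at each step:
//   ORIGIN       - unmask each position independently with probability 1 - s/t
//   MASKGIT_PLUS - rank positions by the probability of their sampled token
//   TOPK_MARGIN  - rank by the margin between the top-1 and top-2 probabilities
//   ENTROPY      - rank by the (negative) entropy of the token distribution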
enum diffusion_alg {
    DIFFUSION_ALG_ORIGIN = 0,
    DIFFUSION_ALG_MASKGIT_PLUS = 1,
    DIFFUSION_ALG_TOPK_MARGIN = 2,
    DIFFUSION_ALG_ENTROPY = 3,
};
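
// Sampling and scheduling parameters for diffusion_generate();
// see diffusion_default_params() for the defaults.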
struct diffusion_params {
    int32_t steps;
    float eps;
    float temperature;
    float top_p;
    int32_t top_k;
    llama_token mask_token_id;
    enum diffusion_alg algorithm;
    float alg_temp;
    diffusion_step_callback_t step_callback;
    void * step_callback_user_data;
    int32_t seed;
};

static diffusion_params diffusion_default_params() {
    diffusion_params params = {};
    params.steps = 64;
    params.eps = 1e-3f;
    params.temperature = 0.2f;
    params.top_p = 0.95f;
    params.top_k = 0;
    params.mask_token_id = LLAMA_TOKEN_NULL;
    params.algorithm = DIFFUSION_ALG_ORIGIN;
    params.alg_temp = 0.0f;
    params.step_callback = nullptr;
    params.step_callback_user_data = nullptr;
    params.seed = 0;
    return params;
}
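
// Masked-diffusion generation: output_tokens is seeded with the prompt and
// padded to max_length with mask tokens, then iteratively denoised. Each step
// decodes the whole sequence with non-causal attention, samples tokens for the
// still-masked positions, and commits a schedule-dependent subset of them.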
static void diffusion_generate(llama_context * ctx,
                               const llama_token * input_tokens,
                               llama_token * output_tokens,
                               int32_t n_input,
                               int32_t max_length,
                               struct diffusion_params params,
                               int32_t & n_generated) {
    n_generated = 0;
    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
        return;
    }

    const llama_model * model = llama_get_model(ctx);

    // Initialize with input and pad with mask tokens
    std::copy(input_tokens, input_tokens + n_input, output_tokens);
    std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);

    std::mt19937 rng(params.seed);
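
    // Linear timestep schedule from t = 1.0 down to t = eps:
    //   t_i = 1 - (i / steps) * (1 - eps)
    // The fraction of masked tokens revealed at step i is derived from
    // 1 - s/t with t = timesteps[i], s = timesteps[i + 1].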
    std::vector<float> timesteps(params.steps + 1);
    for (int32_t i = 0; i <= params.steps; i++) {
        timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
    }

    llama_set_causal_attn(ctx, false);

    int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

    std::vector<llama_token_data> candidates(n_vocab);

    std::vector<llama_token_data> conf_candidates;
    conf_candidates.reserve(max_length);

    std::vector<int32_t> mask_positions;
    mask_positions.reserve(max_length);
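
    // Token sampler chain: top-k -> top-p -> temperature -> draw from the
    // resulting distribution; a separate dist sampler is used later for
    // stochastic position selection when alg_temp > 0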
    struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
    if (params.top_k > 0) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
    }
    if (params.top_p < 1.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
    }
    if (params.temperature > 0.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
    }
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));

    struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);

    llama_batch batch = llama_batch_init(max_length, 0, 1);
    batch.n_tokens = max_length;

    int64_t total_sampling_time = 0;
    int64_t total_time = 0;
    int64_t time_start = ggml_time_us();

    for (int32_t step = 0; step < params.steps; step++) {
        if (params.step_callback) {
            if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
                break;
            }
        }

        for (int32_t i = 0; i < max_length; i++) {
            batch.token[i] = output_tokens[i];
            batch.pos[i] = i;
            batch.n_seq_id[i] = 1;
            batch.seq_id[i][0] = 0;
            batch.logits[i] = 1;
        }

        int ret = llama_decode(ctx, batch);
        if (ret != 0) {
            LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
            break;
        }

        float * raw_logits = llama_get_logits(ctx);
        if (!raw_logits) {
            LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
            break;
        }
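
        // NOTE: the predictions are shifted by one position - the logits row
        // used for position pos is the one at index pos - 1 (position 0 falls
        // back to row 0)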
        auto get_logits_for_pos = [&](int32_t pos) -> const float * {
            return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
        };

        int64_t time_start_sampling = ggml_time_us();

        mask_positions.clear();
        for (int32_t i = 0; i < max_length; i++) {
            if (output_tokens[i] == params.mask_token_id) {
                mask_positions.push_back(i);
            }
        }
        if (mask_positions.empty()) {
            break;
        }

        float t = timesteps[step];
        float s = timesteps[step + 1];
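
        // ORIGIN: unmask each masked position independently with probability
        // 1 - s/t; the final step unmasks everything that is left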
        if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
            float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;

            for (int32_t pos : mask_positions) {
                if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
                    const float * pos_logits = get_logits_for_pos(pos);
                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                        candidates[token_id].id = token_id;
                        candidates[token_id].logit = pos_logits[token_id];
                        candidates[token_id].p = 0.0f;
                    }

                    llama_token_data_array cur_p = {
                        /* .data     = */ candidates.data(),
                        /* .size     = */ (size_t) n_vocab,  // reset size to full vocab
                        /* .selected = */ -1,
                        /* .sorted   = */ false,
                    };

                    llama_sampler_apply(sampler, &cur_p);
                    output_tokens[pos] = cur_p.data[cur_p.selected].id;
                }
            }
        } else {
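            // Confidence-based algorithms: sample a token for every masked
            // position, score each candidate, then commit only the
            // highest-confidence fraction dictated by the timestep schedule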
            std::vector<std::pair<float, int32_t>> confidences;
            std::vector<llama_token> sampled_tokens(mask_positions.size());

            for (size_t i = 0; i < mask_positions.size(); i++) {
                int32_t pos = mask_positions[i];
                const float * pos_logits = get_logits_for_pos(pos);

                for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                    candidates[token_id].logit = pos_logits[token_id];
                    candidates[token_id].p = 0.0f;
                    candidates[token_id].id = token_id;
                }

                llama_token_data_array cur_p = {
                    /* .data     = */ candidates.data(),
                    /* .size     = */ candidates.size(),
                    /* .selected = */ -1,
                    /* .sorted   = */ false,
                };

                llama_sampler_apply(sampler, &cur_p);

                llama_token sampled_token = cur_p.data[cur_p.selected].id;

                float confidence = 0.0f;
                if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
                    // negative entropy: closer to zero means a more peaked distribution
                    const float epsilon = 1e-10f;
                    for (size_t j = 0; j < cur_p.size; j++) {
                        float prob = cur_p.data[j].p;
                        confidence += prob * logf(prob + epsilon);
                    }
                } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
                    confidence = cur_p.data[0].p - cur_p.data[1].p;
                } else {
                    confidence = cur_p.data[cur_p.selected].p;
                }

                sampled_tokens[i] = sampled_token;
                confidences.emplace_back(confidence, i);
            }
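
            // Number of positions to commit this step: a (1 - s/t) fraction of
            // the currently masked positions; all of them on the final step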
            int32_t num_transfer = (step < params.steps - 1)
                ? (int32_t) (mask_positions.size() * (1.0f - s / t))
                : (int32_t) mask_positions.size();

            if (num_transfer > 0) {
                if (params.alg_temp == 0.0f) {
                    // Deterministic: keep the num_transfer highest-confidence positions
                    std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
                                      [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
                                          if (a.first != b.first) {
                                              return a.first > b.first;
                                          }
                                          return a.second < b.second;
                                      });
                } else {
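                    // Stochastic selection: build a distribution over positions
                    // from temperature-scaled confidences (unmasked positions
                    // get -inf) and draw num_transfer positions with the dist
                    // sampler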
                    conf_candidates.clear();
                    for (int32_t pos = 0; pos < max_length; pos++) {
                        float conf_logit = -std::numeric_limits<float>::infinity();

                        auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
                        if (it != mask_positions.end()) {
                            size_t mask_idx = std::distance(mask_positions.begin(), it);
                            conf_logit = confidences[mask_idx].first / params.alg_temp;  // apply temperature scaling
                        }

                        conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
                    }

                    llama_token_data_array conf_array = {
                        /* .data     = */ conf_candidates.data(),
                        /* .size     = */ conf_candidates.size(),
                        /* .selected = */ -1,
                        /* .sorted   = */ false,
                    };

                    for (int32_t i = 0; i < num_transfer; i++) {
                        // apply the distribution sampler to get the selected index
                        llama_sampler_apply(dist_sampler, &conf_array);
                        int selected_idx = conf_array.selected;
                        confidences[i].second = conf_candidates[selected_idx].id;

                        conf_candidates[selected_idx].p = 0.0f;  // mark as used
                        conf_array.selected = -1;
                    }
                }

                if (params.alg_temp == 0.0f) {
                    // Deterministic - use confidence order
                    for (int32_t i = 0; i < num_transfer; i++) {
                        int32_t mask_idx = confidences[i].second;
                        int32_t pos = mask_positions[mask_idx];
                        llama_token token = sampled_tokens[mask_idx];
                        output_tokens[pos] = token;
                    }
                } else {
                    for (int32_t i = 0; i < num_transfer; i++) {
                        int32_t pos = confidences[i].second;
                        auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
                        if (it != mask_positions.end()) {
                            int32_t mask_idx = std::distance(mask_positions.begin(), it);
                            output_tokens[pos] = sampled_tokens[mask_idx];
                        }
                    }
                }
            }
        }

        int64_t time_end_sampling = ggml_time_us();
        total_sampling_time += time_end_sampling - time_start_sampling;
    }

    int64_t time_end = ggml_time_us();
    total_time += time_end - time_start;

    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
            total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);

    llama_batch_free(batch);
    llama_sampler_free(sampler);
    llama_sampler_free(dist_sampler);

    n_generated = max_length;
}
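
// Wrap the raw prompt in the model's chat template (a single user message plus
// the generation prompt) when chat templating is enabled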
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
    if (!use_chat_template) {
        return prompt;
    }

    auto chat_templates = common_chat_templates_init(model, "");

    common_chat_templates_inputs inputs;
    common_chat_msg user_msg;
    user_msg.role = "user";
    user_msg.content = prompt;
    inputs.add_generation_prompt = true;
    inputs.messages.push_back(user_msg);

    auto result = common_chat_templates_apply(chat_templates.get(), inputs);

    return result.prompt;
}

struct callback_data {
    const common_params_diffusion * diff_params;
    const llama_vocab * vocab;
    int32_t n_input;
};
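
// Progress callback: draws a progress bar each step; in visual mode it also
// clears the terminal and prints the partially denoised sequence, rendering
// still-masked positions as spaces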
static bool diffusion_step_callback(int32_t step,
                                    int32_t total_steps,
                                    const llama_token * tokens,
                                    int32_t n_tokens,
                                    void * user_data) {
    callback_data * data = static_cast<callback_data *>(user_data);

    auto print_progress_bar = [](int32_t step, int32_t total_steps) {
        int progress_percent = (step * 100) / total_steps;
        int progress_bars = (step * 50) / total_steps;
        LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
                step,
                total_steps,
                std::string(progress_bars, '=').c_str(),
                std::string(50 - progress_bars, ' ').c_str(),
                progress_percent);
    };

    if (data->diff_params->visual_mode) {
        // Visual mode: clear screen and move cursor to top-left
        LOG_INF("\033[2J\033[H");

        print_progress_bar(step, total_steps);

        LOG_INF("\n");

        std::string current_text = " ";

        for (int32_t i = data->n_input; i < n_tokens; i++) {
            std::string token_str;
            if (tokens[i] != llama_vocab_mask(data->vocab)) {
                char piece[256];
                int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
                if (n_chars > 0) {
                    // construct from the returned length; avoids writing a
                    // terminator one past the end when the piece fills the buffer
                    token_str.assign(piece, n_chars);
                }
            } else {
                token_str = " ";
            }

            current_text += token_str;
        }

        LOG_INF("%s\n", current_text.c_str());
    } else {
        print_progress_bar(step, total_steps);
    }

    return true;
}

int main(int argc, char ** argv) {
    ggml_time_init();

    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
        return 1;
    }

    const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
    const char * alg_name = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
                                alg_names[params.diffusion.algorithm] :
                                "UNKNOWN";

    common_init();
    llama_backend_init();

    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = params.n_gpu_layers;
    model_params.devices = params.devices.data();
    model_params.use_mmap = params.use_mmap;
    model_params.use_mlock = params.use_mlock;
    model_params.check_tensors = params.check_tensors;

    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
    if (!model) {
        LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
        return 1;
    }

    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = params.n_ctx;
    ctx_params.n_batch = params.n_batch;
    ctx_params.n_ubatch = params.n_ubatch;
    ctx_params.flash_attn = params.flash_attn;
    ctx_params.no_perf = params.no_perf;
    ctx_params.type_k = params.cache_type_k;
    ctx_params.type_v = params.cache_type_v;

    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        LOG_ERR("error: failed to create context\n");
        llama_model_free(model);
        return 1;
    }

    llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);

    const llama_vocab * vocab = llama_model_get_vocab(model);
    std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);

    std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
                                                            /*add special tokens*/ true,
                                                            /*parse special*/ true);
    int n_input = input_tokens.size();

    if (n_input >= params.n_ctx) {
        LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
        llama_free(ctx);
        llama_model_free(model);
        return 1;
    }

    struct diffusion_params ldiff_params = diffusion_default_params();
    ldiff_params.steps = params.diffusion.steps;
    ldiff_params.eps = params.diffusion.eps;
    ldiff_params.temperature = params.sampling.temp;
    ldiff_params.top_p = params.sampling.top_p;
    ldiff_params.top_k = params.sampling.top_k;
    ldiff_params.algorithm = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
    ldiff_params.alg_temp = params.diffusion.alg_temp;
    ldiff_params.seed = params.sampling.seed;

    llama_token mask_token_id = llama_vocab_mask(vocab);
    GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);

    LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
    LOG_INF("diffusion_params: - %-25s u32         = %d\n", "steps", params.diffusion.steps);
    LOG_INF("diffusion_params: - %-25s f32         = %.6f\n", "eps", params.diffusion.eps);
    LOG_INF("diffusion_params: - %-25s u32         = %d (%s)\n", "algorithm", params.diffusion.algorithm, alg_name);
    LOG_INF("diffusion_params: - %-25s f32         = %.3f\n", "alg_temp", params.diffusion.alg_temp);

    ldiff_params.mask_token_id = mask_token_id;

    callback_data cb_data = { &params.diffusion, vocab, n_input };
    ldiff_params.step_callback = diffusion_step_callback;
    ldiff_params.step_callback_user_data = &cb_data;

    int32_t n_generated = 0;

    std::vector<llama_token> output_tokens(params.n_ubatch);
    diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch,
                       ldiff_params, n_generated);

    if (n_generated > 0) {
        if (params.diffusion.visual_mode) {
            // clear screen and move cursor to top-left
            LOG_INF("\033[2J\033[H");
        }

        output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
        std::string output_data = common_detokenize(vocab, output_tokens, false);
        LOG_INF("\n%s\n", output_data.c_str());
    } else {
        LOG_ERR("Error: diffusion generation failed\n");
    }

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();

    return 0;
}