diffusion-cli.cpp

#include "arg.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
#include "log.h"

#include <limits.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <random>
#include <string>
#include <vector>
enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };

// Unified transfer scheduling methods
enum transfer_schedule {
    TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
    BLOCK_BASED    = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens
};

typedef bool (*diffusion_step_callback_t)(int32_t             step,
                                          int32_t             total_steps,
                                          const llama_token * tokens,
                                          int32_t             n_tokens,
                                          void *              user_data);
struct diffusion_params {
    int32_t                   steps         = 0;
    float                     temperature   = 0.0f;
    llama_token               mask_token_id = LLAMA_TOKEN_NULL;
    diffusion_step_callback_t step_callback = nullptr;
    void *                    step_callback_user_data = nullptr;
    int32_t                   seed          = 0;
    bool                      visual_mode   = false;
    bool                      shift_logits  = false;            // Shift logits by -1 after decode
    float                     top_p         = 0.0f;
    int32_t                   top_k         = 0;
    diffusion_algorithm       algorithm     = CONFIDENCE_BASED;
    transfer_schedule         schedule      = TIMESTEP_BASED;
    float                     cfg_scale     = 0.0f;             // Scale for classifier-free guidance
    float                     eps           = 0.0f;             // Epsilon for timestep scheduling
    int32_t                   block_length  = 0;                // Block size (for block-based scheduling)
    float                     alg_temp      = 0.0f;             // Algorithm temperature (0.0 = deterministic)
    bool                      add_gumbel_noise = false;         // Add Gumbel noise to the logits if temperature > 0.0
    int32_t                   max_length    = 0;                // Maximum sequence length
};
struct callback_data {
    diffusion_params *  diff_params;
    const llama_vocab * vocab;
    int32_t             n_input;
};
static float calculate_confidence(const llama_token_data_array & cur_p,
                                  diffusion_algorithm            algorithm,
                                  std::mt19937 &                 rng) {
    switch (algorithm) {
        case CONFIDENCE_BASED:
            return cur_p.data[cur_p.selected].p; // Selected token probability
        case ENTROPY_BASED:
            {
                float       entropy = 0.0f;
                const float epsilon = 1e-10f;
                for (size_t i = 0; i < cur_p.size; i++) {
                    float prob = cur_p.data[i].p;
                    entropy += prob * logf(prob + epsilon);
                }
                return -entropy; // Higher entropy = lower confidence
            }
        case MARGIN_BASED:
            return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
        case RANDOM:
            {
                std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
                return uniform(rng); // Random confidence
            }
        case ORIGIN:
            return cur_p.data[cur_p.selected].p;
        default:
            return 0.0f;
    }
}
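
// For illustration only (hypothetical numbers, not part of the original code):
// with a candidate list whose probabilities are {0.6, 0.3, 0.1} and the top
// token selected, CONFIDENCE_BASED and ORIGIN return 0.6 (the selected token's
// probability), MARGIN_BASED returns 0.6 - 0.3 = 0.3, and ENTROPY_BASED returns
// the Shannon entropy -(0.6*ln0.6 + 0.3*ln0.3 + 0.1*ln0.1) ~= 0.90 nats.
// RANDOM ignores the distribution entirely and draws a uniform value in [0, 1).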
// Unified transfer count calculation function
static int32_t calculate_transfer_count(int32_t                      step,
                                        int32_t                      total_steps,
                                        int32_t                      remaining_masked,
                                        transfer_schedule            schedule,
                                        float                        eps,
                                        const std::vector<int32_t> & num_transfer_tokens = {}) {
    switch (schedule) {
        case TIMESTEP_BASED:
            {
                float t = 1.0f - (float) step / total_steps * (1.0f - eps);
                float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
                float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
                return (int32_t) (remaining_masked * p_transfer);
            }
        case BLOCK_BASED:
            if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
                return num_transfer_tokens[step];
            }
            return remaining_masked / (total_steps - step); // Fallback
        default:
            return remaining_masked / (total_steps - step);
    }
}
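
// Worked example (hypothetical numbers, for illustration only): with
// total_steps = 4, eps = 0 and 100 masked positions, step 0 gives
// t = 1.00, s = 0.75 and p_transfer = 1 - 0.75/1.00 = 0.25, so 25 positions
// are unmasked. Later steps unmask a growing fraction of whatever is still
// masked, and the final step forces p_transfer = 1.0f so that no mask tokens
// survive the schedule.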
static bool diffusion_step_callback(int32_t             step,
                                    int32_t             total_steps,
                                    const llama_token * tokens,
                                    int32_t             n_tokens,
                                    void *              user_data) {
    callback_data * data = static_cast<callback_data *>(user_data);

    auto print_progress_bar = [](int32_t step, int32_t total_steps) {
        int progress_percent = (step * 100) / total_steps;
        int progress_bars    = (step * 50) / total_steps;
        LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
                step,
                total_steps,
                std::string(progress_bars, '=').c_str(),
                std::string(50 - progress_bars, ' ').c_str(),
                progress_percent);
    };

    if (data->diff_params->visual_mode) {
        // Visual mode: redraw the whole sequence on every step
        LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
        print_progress_bar(step, total_steps);
        LOG_INF("\n");

        std::string current_text = " ";
        for (int32_t i = data->n_input; i < n_tokens; i++) {
            std::string token_str;
            if (tokens[i] != llama_vocab_mask(data->vocab)) {
                char piece[256];
                int  n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
                if (n_chars > 0) {
                    // Construct from the returned length to avoid writing past the buffer
                    token_str = std::string(piece, n_chars);
                }
            } else {
                token_str = " ";
            }
            current_text += token_str;
        }
        LOG_INF("%s\n", current_text.c_str());
    } else {
        print_progress_bar(step, total_steps);
    }
    return true;
}
static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
    if (temperature == 0.0f) {
        return;
    }

    std::uniform_real_distribution<double> uniform(0.0, 1.0);
    for (int32_t i = 0; i < n_vocab; i++) {
        double noise = uniform(rng);
        // Prevent log(0)
        noise               = std::max(noise, 1e-20);
        double gumbel_noise = std::pow(-std::log(noise), temperature);
        logits[i]           = std::exp(logits[i]) / gumbel_noise;
    }
}
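
// Side note (not from the original source): log(exp(logit) / (-log u)^T) equals
// logit + T * (-log(-log u)), i.e. a Gumbel-max perturbation with the noise
// scaled by the temperature T. With T = 1, taking the argmax of the perturbed
// values is equivalent to sampling from softmax(logits); as T -> 0 it collapses
// to plain greedy argmax (and the function above returns early at exactly 0).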
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
    std::vector<int32_t> num_transfer_tokens(steps);

    int32_t base      = mask_count / steps;
    int32_t remainder = mask_count % steps;

    for (int32_t i = 0; i < steps; i++) {
        num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
    }

    return num_transfer_tokens;
}
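
// Worked example (hypothetical numbers, for illustration only): mask_count = 10
// and steps = 4 give base = 2 and remainder = 2, so the per-step budget is
// {3, 3, 2, 2}. The remainder is spread over the earliest steps and the budget
// always sums to mask_count, so every masked position in the block is
// eventually committed.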
static void diffusion_generate(llama_context *          ctx,
                               const llama_token *      input_tokens,
                               llama_token *            output_tokens,
                               int32_t                  n_input,
                               const diffusion_params & params,
                               int32_t &                n_generated) {
    n_generated = 0;
    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
        return;
    }

    const llama_model * model = llama_get_model(ctx);

    // Initialize with input and pad with mask tokens
    std::copy(input_tokens, input_tokens + n_input, output_tokens);
    std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);

    std::mt19937 rng(params.seed);

    llama_set_causal_attn(ctx, false);

    int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

    std::vector<llama_token_data> candidates(n_vocab);
    std::vector<llama_token_data> conf_candidates;
    conf_candidates.reserve(params.max_length);

    std::vector<int32_t> mask_positions;
    mask_positions.reserve(params.max_length);

    // Setup sampler chain
    struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
    if (params.top_k > 0) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
    }
    if (params.top_p < 1.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
    }
    if (params.temperature > 0.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
    }
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));

    struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);

    llama_batch batch = llama_batch_init(params.max_length, 0, 1);
    batch.n_tokens    = params.max_length;

    // Pre-allocate buffers for CFG if needed
    int32_t                  logits_size = n_vocab * params.max_length;
    std::vector<float>       cond_logits_buffer;
    std::vector<llama_token> un_x_buffer;
    if (params.cfg_scale > 0.0f) {
        cond_logits_buffer.resize(logits_size);
        un_x_buffer.resize(params.max_length);
    }

    // For block-based processing
    std::vector<int32_t> num_transfer_tokens;

    int32_t num_blocks      = 1;
    int32_t steps_per_block = params.steps;

    if (params.schedule == BLOCK_BASED) {
        GGML_ASSERT(params.max_length % params.block_length == 0);
        num_blocks = params.max_length / params.block_length;
        GGML_ASSERT(params.steps % num_blocks == 0);
        steps_per_block = params.steps / num_blocks;
    }

    std::vector<float> confidence(params.max_length);

    int64_t total_sampling_time = 0;
    int64_t total_time          = 0;
    int64_t time_start          = ggml_time_us();
    for (int block_num = 0; block_num < num_blocks; block_num++) {
        int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
        int32_t block_end   = (params.schedule == BLOCK_BASED) ?
                                  std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
                                  params.max_length;

        // Count masked tokens in current block for block-based processing
        if (params.schedule == BLOCK_BASED) {
            int32_t block_mask_count = 0;
            for (int i = block_start; i < block_end; i++) {
                if (output_tokens[i] == params.mask_token_id) {
                    block_mask_count++;
                }
            }
            num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
        }

        for (int32_t step = 0; step < steps_per_block; step++) {
            int32_t global_step = block_num * steps_per_block + step;

            if (params.step_callback) {
                if (!params.step_callback(
                        global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
                    break;
                }
            }

            // Setup batch
            for (int32_t i = 0; i < params.max_length; i++) {
                batch.token[i]     = output_tokens[i];
                batch.pos[i]       = i;
                batch.n_seq_id[i]  = 1;
                batch.seq_id[i][0] = 0;
                batch.logits[i]    = 1;
            }
            float * logits = nullptr;

            if (params.cfg_scale > 0.0f) {
                // Conditional pass over the current (partially unmasked) sequence
                int ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode conditional pass at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                float * cond_logits_ptr = llama_get_logits(ctx);
                std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));

                // Unconditional pass: mask out the input (prompt) tokens
                std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
                for (int32_t i = 0; i < n_input; i++) {
                    un_x_buffer[i] = params.mask_token_id;
                }
                for (int32_t i = 0; i < params.max_length; i++) {
                    batch.token[i] = un_x_buffer[i];
                }
                ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode unconditional pass at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                float * uncond_logits = llama_get_logits(ctx);

                // Apply CFG: uncond + (scale + 1) * (cond - uncond)
                for (int32_t i = 0; i < logits_size; i++) {
                    cond_logits_buffer[i] =
                        uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
                }
                logits = cond_logits_buffer.data();
            } else {
                int ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                logits = llama_get_logits(ctx);
            }

            if (!logits) {
                LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
                break;
            }

            auto get_logits_for_pos = [&](int32_t pos) -> const float * {
                if (params.shift_logits) {
                    return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
                }
                return logits + pos * n_vocab;
            };
            int64_t time_start_sampling = ggml_time_us();

            mask_positions.clear();
            for (int32_t i = 0; i < params.max_length; i++) {
                if (output_tokens[i] == params.mask_token_id) {
                    // For block-based scheduling, only consider the current block
                    if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
                        mask_positions.push_back(i);
                    }
                }
            }

            if (mask_positions.empty()) {
                break;
            }

            if (params.add_gumbel_noise && params.temperature > 0.0f) {
                add_gumbel_noise(logits, n_vocab, params.temperature, rng);
            }

            if (params.algorithm == ORIGIN) {
                int32_t transfer_count = calculate_transfer_count(
                    step, steps_per_block, (int32_t) mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
                float p_transfer = (float) transfer_count / mask_positions.size();

                // Unmask each masked position independently with probability p_transfer
                for (int32_t pos : mask_positions) {
                    if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
                        const float * pos_logits = get_logits_for_pos(pos);
                        for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                            candidates[token_id].id    = token_id;
                            candidates[token_id].logit = pos_logits[token_id];
                            candidates[token_id].p     = 0.0f;
                        }

                        llama_token_data_array cur_p = {
                            candidates.data(),
                            (size_t) n_vocab,
                            -1,
                            false,
                        };

                        llama_sampler_apply(sampler, &cur_p);
                        output_tokens[pos] = cur_p.data[cur_p.selected].id;
                    }
                }
            } else {
                std::vector<std::pair<float, int32_t>> confidences;
                std::vector<llama_token>               sampled_tokens(mask_positions.size());

                // Sample a token for every masked position and record its confidence
                for (size_t i = 0; i < mask_positions.size(); i++) {
                    int32_t       pos        = mask_positions[i];
                    const float * pos_logits = get_logits_for_pos(pos);
                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                        candidates[token_id].logit = pos_logits[token_id];
                        candidates[token_id].p     = 0.0f;
                        candidates[token_id].id    = token_id;
                    }

                    llama_token_data_array cur_p = {
                        candidates.data(),
                        candidates.size(),
                        -1,
                        false,
                    };

                    llama_sampler_apply(sampler, &cur_p);

                    llama_token sampled_token = cur_p.data[cur_p.selected].id;
                    float       conf          = calculate_confidence(cur_p, params.algorithm, rng);

                    sampled_tokens[i] = sampled_token;
                    confidences.emplace_back(conf, i);
                }

                int32_t transfer_count = calculate_transfer_count(
                    step, steps_per_block, (int32_t) mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);

                if (transfer_count > 0) {
                    if (params.alg_temp == 0.0f) {
                        // Deterministic: commit the transfer_count highest-confidence positions
                        std::partial_sort(confidences.begin(),
                                          confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
                                          confidences.end(),
                                          [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
                                              if (a.first != b.first) {
                                                  return a.first > b.first;
                                              }
                                              return a.second < b.second;
                                          });

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            int32_t mask_idx   = confidences[i].second;
                            int32_t pos        = mask_positions[mask_idx];
                            output_tokens[pos] = sampled_tokens[mask_idx];
                        }
                    } else {
                        // Stochastic: treat confidence / alg_temp as logits and sample positions to commit
                        conf_candidates.clear();
                        for (size_t i = 0; i < confidences.size(); i++) {
                            float conf_logit = confidences[i].first / params.alg_temp;
                            conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
                        }

                        llama_token_data_array conf_array = {
                            conf_candidates.data(),
                            conf_candidates.size(),
                            -1,
                            false,
                        };

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            llama_sampler_apply(dist_sampler, &conf_array);

                            int32_t mask_idx = conf_array.selected;
                            int32_t pos      = mask_positions[mask_idx];

                            output_tokens[pos]          = sampled_tokens[mask_idx];
                            conf_candidates[mask_idx].p = 0.0f;
                            conf_array.selected         = -1;
                        }
                    }
                }
            }

            int64_t time_end_sampling = ggml_time_us();
            total_sampling_time += time_end_sampling - time_start_sampling;
        }
    }
    int64_t time_end = ggml_time_us();
    total_time += time_end - time_start;

    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
            total_time / 1000.0,
            total_time / 1000.0 / params.steps,
            total_sampling_time / 1000.0 / params.steps);

    llama_batch_free(batch);
    llama_sampler_free(sampler);
    llama_sampler_free(dist_sampler);

    n_generated = params.max_length;
}
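
// Summary of the loop above: every step decodes the full (partially masked)
// sequence with causal attention disabled, samples a candidate token for each
// masked position, and then commits only a subset of those candidates -- either
// at random (ORIGIN) or by confidence ranking -- with the subset size dictated
// by the timestep- or block-based schedule. The loop ends when no masked
// positions remain or all steps are exhausted.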
static std::string format_input_text(const std::string & prompt,
                                     const std::string & system_prompt,
                                     bool                use_chat_template,
                                     llama_model *       model) {
    if (!use_chat_template) {
        return prompt;
    }

    auto chat_templates = common_chat_templates_init(model, "");

    common_chat_templates_inputs inputs;
    common_chat_msg              system_msg;

    if (!system_prompt.empty()) {
        system_msg.role    = "system";
        system_msg.content = system_prompt;
        inputs.messages.push_back(system_msg);
    }

    common_chat_msg user_msg;
    user_msg.role    = "user";
    user_msg.content = prompt;

    inputs.messages.push_back(user_msg);
    inputs.add_generation_prompt = true;

    auto result = common_chat_templates_apply(chat_templates.get(), inputs);
    return result.prompt;
}
int main(int argc, char ** argv) {
    ggml_time_init();

    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
        return 1;
    }

    common_init();
    llama_backend_init();

    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers       = params.n_gpu_layers;
    model_params.devices            = params.devices.data();
    model_params.use_mmap           = params.use_mmap;
    model_params.use_mlock          = params.use_mlock;
    model_params.check_tensors      = params.check_tensors;

    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
    if (!model) {
        LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
        return 1;
    }

    if (!llama_model_is_diffusion(model)) {
        LOG_ERR("error: unsupported model for diffusion\n");
        llama_model_free(model);
        return 1;
    }
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx                = params.n_ctx;
    ctx_params.n_batch              = params.n_batch;
    ctx_params.n_ubatch             = params.n_ubatch;
    ctx_params.flash_attn_type      = params.flash_attn_type;
    ctx_params.no_perf              = params.no_perf;
    ctx_params.type_k               = params.cache_type_k;
    ctx_params.type_v               = params.cache_type_v;

    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        LOG_ERR("error: failed to create context\n");
        llama_model_free(model);
        return 1;
    }

    llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);

    const llama_vocab * vocab = llama_model_get_vocab(model);

    std::string formatted_prompt =
        format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);

    std::vector<llama_token> input_tokens = common_tokenize(vocab,
                                                            formatted_prompt,
                                                            /*add special tokens*/ true,
                                                            /*parse special*/ true);

    int n_input = input_tokens.size();

    if (n_input >= params.n_ctx) {
        LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
        llama_free(ctx);
        llama_model_free(model);
        return 1;
    }
    llama_token mask_token_id = llama_vocab_mask(vocab);
    GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);

    bool visual_mode = params.diffusion.visual_mode;

    int32_t                  n_generated = 0;
    std::vector<llama_token> output_tokens(params.n_ubatch);

    struct diffusion_params diff_params;

    char shift_logits_str[8];
    if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) {
        diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0);
    } else {
        diff_params.shift_logits = true;
    }

    // Use either eps or block length, but not both
    GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));

    if (params.diffusion.eps) {
        diff_params.schedule = TIMESTEP_BASED;
        diff_params.eps      = params.diffusion.eps;
    } else if (params.diffusion.block_length) {
        diff_params.schedule     = BLOCK_BASED;
        diff_params.block_length = params.diffusion.block_length;
    }
    diff_params.mask_token_id    = mask_token_id;
    diff_params.seed             = params.sampling.seed;
    diff_params.temperature      = params.sampling.temp;
    diff_params.steps            = params.diffusion.steps;
    diff_params.algorithm        = static_cast<diffusion_algorithm>(params.diffusion.algorithm);
    diff_params.alg_temp         = params.diffusion.alg_temp;
    diff_params.cfg_scale        = params.diffusion.cfg_scale;
    diff_params.max_length       = params.n_ubatch;
    diff_params.top_p            = params.sampling.top_p;
    diff_params.top_k            = params.sampling.top_k;
    diff_params.visual_mode      = params.diffusion.visual_mode;
    diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise;

    diff_params.step_callback = diffusion_step_callback;

    callback_data cb_data = { &diff_params, vocab, n_input };

    diff_params.step_callback_user_data = &cb_data;
    const char * alg_names[]   = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
    const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };

    const char * alg_name =
        (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
    const char * sched_name =
        (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN";

    LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
    LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps);
    LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length);
    LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
    LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
    LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
    if (diff_params.schedule == TIMESTEP_BASED) {
        LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
        LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
    }
    if (diff_params.schedule == BLOCK_BASED) {
        LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
        LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
    }
    diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);

    if (n_generated > 0) {
        if (visual_mode) {
            // Clear screen and move cursor to top-left
            LOG_INF("\033[2J\033[H");
        }

        output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
        std::string output_data = common_detokenize(vocab, output_tokens, false);
        LOG_INF("\n%s\n", output_data.c_str());
    } else {
        LOG_INF("Error: diffusion generation failed\n");
    }

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();

    return 0;
}