|
@@ -5,67 +5,212 @@
|
|
|
#include "log.h"
|
|
#include "log.h"
|
|
|
|
|
|
|
|
#include <limits.h>
|
|
#include <limits.h>
|
|
|
-#include <string>
|
|
|
|
|
-#include <vector>
|
|
|
|
|
|
|
+
|
|
|
#include <algorithm>
|
|
#include <algorithm>
|
|
|
#include <cmath>
|
|
#include <cmath>
|
|
|
|
|
+#include <cstring>
|
|
|
#include <limits>
|
|
#include <limits>
|
|
|
#include <random>
|
|
#include <random>
|
|
|
|
|
+#include <string>
|
|
|
|
|
+#include <vector>
|
|
|
|
|
|
|
|
-typedef bool (*diffusion_step_callback_t)(int32_t step,
|
|
|
|
|
- int32_t total_steps,
|
|
|
|
|
- const llama_token * tokens,
|
|
|
|
|
- int32_t n_tokens,
|
|
|
|
|
- void * user_data);
|
|
|
|
|
-
|
|
|
|
|
-enum diffusion_alg {
|
|
|
|
|
- DIFFUSION_ALG_ORIGIN = 0,
|
|
|
|
|
- DIFFUSION_ALG_MASKGIT_PLUS = 1,
|
|
|
|
|
- DIFFUSION_ALG_TOPK_MARGIN = 2,
|
|
|
|
|
- DIFFUSION_ALG_ENTROPY = 3,
|
|
|
|
|
|
|
+enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };
|
|
|
|
|
+
|
|
|
|
|
+// Unified transfer scheduling methods
|
|
|
|
|
+enum transfer_schedule {
|
|
|
|
|
+ TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
|
|
|
|
|
+ BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
|
|
+typedef bool (*diffusion_step_callback_t)(int32_t step,
|
|
|
|
|
+ int32_t total_steps,
|
|
|
|
|
+ const llama_token * tokens,
|
|
|
|
|
+ int32_t n_tokens,
|
|
|
|
|
+ void * user_data);
|
|
|
|
|
+
|
|
|
struct diffusion_params {
|
|
struct diffusion_params {
|
|
|
- int32_t steps;
|
|
|
|
|
- float eps;
|
|
|
|
|
- float temperature;
|
|
|
|
|
- float top_p;
|
|
|
|
|
- int32_t top_k;
|
|
|
|
|
- llama_token mask_token_id;
|
|
|
|
|
- enum diffusion_alg algorithm;
|
|
|
|
|
- float alg_temp;
|
|
|
|
|
- diffusion_step_callback_t step_callback;
|
|
|
|
|
- void * step_callback_user_data;
|
|
|
|
|
- int32_t seed;
|
|
|
|
|
|
|
+ int32_t steps = 0;
|
|
|
|
|
+ float temperature = 0;
|
|
|
|
|
+ llama_token mask_token_id = LLAMA_TOKEN_NULL;
|
|
|
|
|
+ diffusion_step_callback_t step_callback = nullptr;
|
|
|
|
|
+ void * step_callback_user_data = nullptr;
|
|
|
|
|
+ int32_t seed = 0;
|
|
|
|
|
+ bool visual_mode = false;
|
|
|
|
|
+ bool shift_logits = false; // Shift logits by -1 after decode
|
|
|
|
|
+
|
|
|
|
|
+ float top_p = 0.;
|
|
|
|
|
+ int32_t top_k = 0.;
|
|
|
|
|
+
|
|
|
|
|
+ diffusion_algorithm algorithm = CONFIDENCE_BASED;
|
|
|
|
|
+ transfer_schedule schedule = TIMESTEP_BASED;
|
|
|
|
|
+
|
|
|
|
|
+ float cfg_scale = 0.; // Config scale for classifier-free guidance
|
|
|
|
|
+ float eps = 0.; // Timestep scheduling
|
|
|
|
|
+ int32_t block_length = 0; // Block size (for block scheduling)
|
|
|
|
|
+ float alg_temp = 0; // algorithm temperature (0.0 = deterministic)
|
|
|
|
|
+ bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0
|
|
|
|
|
+
|
|
|
|
|
+ int32_t max_length = 0; // Maximum sequence length
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
|
|
+struct callback_data {
|
|
|
|
|
+ diffusion_params * diff_params;
|
|
|
|
|
+ const llama_vocab * vocab;
|
|
|
|
|
+ int32_t n_input;
|
|
|
|
|
+};
|
|
|
|
|
+
|
|
|
|
|
+static float calculate_confidence(const llama_token_data_array & cur_p,
|
|
|
|
|
+ diffusion_algorithm algorithm,
|
|
|
|
|
+ std::mt19937 & rng) {
|
|
|
|
|
+ switch (algorithm) {
|
|
|
|
|
+ case CONFIDENCE_BASED:
|
|
|
|
|
+ return cur_p.data[cur_p.selected].p; // Selected token probability
|
|
|
|
|
+
|
|
|
|
|
+ case ENTROPY_BASED:
|
|
|
|
|
+ {
|
|
|
|
|
+ float entropy = 0.0f;
|
|
|
|
|
+ const float epsilon = 1e-10f;
|
|
|
|
|
+ for (size_t i = 0; i < cur_p.size; i++) {
|
|
|
|
|
+ float prob = cur_p.data[i].p;
|
|
|
|
|
+ entropy += prob * logf(prob + epsilon);
|
|
|
|
|
+ }
|
|
|
|
|
+ return -entropy; // Higher entropy = lower confidence
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ case MARGIN_BASED:
|
|
|
|
|
+ return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
|
|
|
|
|
+
|
|
|
|
|
+ case RANDOM:
|
|
|
|
|
+ {
|
|
|
|
|
+ std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
|
|
|
|
|
+ return uniform(rng); // Random confidence
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ case ORIGIN:
|
|
|
|
|
+ return cur_p.data[cur_p.selected].p;
|
|
|
|
|
+
|
|
|
|
|
+ default:
|
|
|
|
|
+ return 0.0f;
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Unified transfer count calculation function
|
|
|
|
|
+static int32_t calculate_transfer_count(int32_t step,
|
|
|
|
|
+ int32_t total_steps,
|
|
|
|
|
+ int32_t remaining_masked,
|
|
|
|
|
+ transfer_schedule schedule,
|
|
|
|
|
+ float eps,
|
|
|
|
|
+ const std::vector<int32_t> & num_transfer_tokens = {}) {
|
|
|
|
|
+ switch (schedule) {
|
|
|
|
|
+ case TIMESTEP_BASED:
|
|
|
|
|
+ {
|
|
|
|
|
+ float t = 1.0f - (float) step / total_steps * (1.0f - eps);
|
|
|
|
|
+ float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
|
|
|
|
|
+ float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
|
|
|
|
|
+ return (int32_t) (remaining_masked * p_transfer);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ case BLOCK_BASED:
|
|
|
|
|
+ if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
|
|
|
|
|
+ return num_transfer_tokens[step];
|
|
|
|
|
+ }
|
|
|
|
|
+ return remaining_masked / (total_steps - step); // Fallback
|
|
|
|
|
+
|
|
|
|
|
+ default:
|
|
|
|
|
+ return remaining_masked / (total_steps - step);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+static bool diffusion_step_callback(int32_t step,
|
|
|
|
|
+ int32_t total_steps,
|
|
|
|
|
+ const llama_token * tokens,
|
|
|
|
|
+ int32_t n_tokens,
|
|
|
|
|
+ void * user_data) {
|
|
|
|
|
+ (void) user_data;
|
|
|
|
|
+
|
|
|
|
|
+ callback_data * data = static_cast<callback_data *>(user_data);
|
|
|
|
|
+
|
|
|
|
|
+ auto print_progress_bar = [](int32_t step, int32_t total_steps) {
|
|
|
|
|
+ int progress_percent = (step * 100) / total_steps;
|
|
|
|
|
+ int progress_bars = (step * 50) / total_steps;
|
|
|
|
|
+ LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
|
|
|
|
|
+ step,
|
|
|
|
|
+ total_steps,
|
|
|
|
|
+ std::string(progress_bars, '=').c_str(),
|
|
|
|
|
+ std::string(50 - progress_bars, ' ').c_str(),
|
|
|
|
|
+ progress_percent);
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ if (data->diff_params->visual_mode) {
|
|
|
|
|
+ // Visual mode: clear
|
|
|
|
|
+ LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
|
|
|
|
|
+
|
|
|
|
|
+ print_progress_bar(step, total_steps);
|
|
|
|
|
+
|
|
|
|
|
+ LOG_INF("\n");
|
|
|
|
|
+
|
|
|
|
|
+ std::string current_text = " ";
|
|
|
|
|
+
|
|
|
|
|
+ for (int32_t i = data->n_input; i < n_tokens; i++) {
|
|
|
|
|
+ std::string token_str;
|
|
|
|
|
+ if (tokens[i] != llama_vocab_mask(data->vocab)) {
|
|
|
|
|
+ char piece[256];
|
|
|
|
|
+ int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
|
|
|
|
|
+ if (n_chars > 0) {
|
|
|
|
|
+ piece[n_chars] = '\0';
|
|
|
|
|
+ token_str = piece;
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ token_str = " ";
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ current_text += token_str;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
-static diffusion_params diffusion_default_params() {
|
|
|
|
|
- diffusion_params params = {};
|
|
|
|
|
- params.steps = 64;
|
|
|
|
|
- params.eps = 1e-3f;
|
|
|
|
|
- params.temperature = 0.2f;
|
|
|
|
|
- params.top_p = 0.95f;
|
|
|
|
|
- params.top_k = 0;
|
|
|
|
|
- params.mask_token_id = LLAMA_TOKEN_NULL;
|
|
|
|
|
- params.algorithm = DIFFUSION_ALG_ORIGIN;
|
|
|
|
|
- params.alg_temp = 0.0f;
|
|
|
|
|
- params.step_callback = nullptr;
|
|
|
|
|
- params.step_callback_user_data = nullptr;
|
|
|
|
|
- params.seed = 0;
|
|
|
|
|
- return params;
|
|
|
|
|
|
|
+ LOG_INF("%s\n", current_text.c_str());
|
|
|
|
|
+ } else {
|
|
|
|
|
+ print_progress_bar(step, total_steps);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return true;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-static void diffusion_generate(llama_context * ctx,
|
|
|
|
|
- const llama_token * input_tokens,
|
|
|
|
|
- llama_token * output_tokens,
|
|
|
|
|
- int32_t n_input,
|
|
|
|
|
- int32_t max_length,
|
|
|
|
|
- struct diffusion_params params,
|
|
|
|
|
- int32_t & n_generated) {
|
|
|
|
|
|
|
+static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
|
|
|
|
|
+ if (temperature == 0.0f) {
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ std::uniform_real_distribution<double> uniform(0.0, 1.0);
|
|
|
|
|
+ for (int32_t i = 0; i < n_vocab; i++) {
|
|
|
|
|
+ double noise = uniform(rng);
|
|
|
|
|
+ // Prevent log(0)
|
|
|
|
|
+ noise = std::max(noise, 1e-20);
|
|
|
|
|
+ double gumbel_noise = std::pow(-std::log(noise), temperature);
|
|
|
|
|
+ logits[i] = std::exp(logits[i]) / gumbel_noise;
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
|
|
|
|
|
+ std::vector<int32_t> num_transfer_tokens(steps);
|
|
|
|
|
+
|
|
|
|
|
+ int32_t base = mask_count / steps;
|
|
|
|
|
+ int32_t remainder = mask_count % steps;
|
|
|
|
|
+
|
|
|
|
|
+ for (int32_t i = 0; i < steps; i++) {
|
|
|
|
|
+ num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return num_transfer_tokens;
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
|
|
+static void diffusion_generate(llama_context * ctx,
|
|
|
|
|
+ const llama_token * input_tokens,
|
|
|
|
|
+ llama_token * output_tokens,
|
|
|
|
|
+ int32_t n_input,
|
|
|
|
|
+ const diffusion_params & params,
|
|
|
|
|
+ int32_t & n_generated) {
|
|
|
n_generated = 0;
|
|
n_generated = 0;
|
|
|
- if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
|
|
|
|
|
|
|
+ if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -73,27 +218,21 @@ static void diffusion_generate(llama_context * ctx,
|
|
|
|
|
|
|
|
// Initialize with input and pad with mask tokens
|
|
// Initialize with input and pad with mask tokens
|
|
|
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
|
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
|
|
- std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);
|
|
|
|
|
|
|
+ std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);
|
|
|
|
|
|
|
|
std::mt19937 rng(params.seed);
|
|
std::mt19937 rng(params.seed);
|
|
|
|
|
|
|
|
- std::vector<float> timesteps(params.steps + 1);
|
|
|
|
|
- for (int32_t i = 0; i <= params.steps; i++) {
|
|
|
|
|
- timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
llama_set_causal_attn(ctx, false);
|
|
llama_set_causal_attn(ctx, false);
|
|
|
|
|
|
|
|
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
|
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
|
|
|
|
|
|
|
std::vector<llama_token_data> candidates(n_vocab);
|
|
std::vector<llama_token_data> candidates(n_vocab);
|
|
|
-
|
|
|
|
|
std::vector<llama_token_data> conf_candidates;
|
|
std::vector<llama_token_data> conf_candidates;
|
|
|
- conf_candidates.reserve(max_length);
|
|
|
|
|
-
|
|
|
|
|
|
|
+ conf_candidates.reserve(params.max_length);
|
|
|
std::vector<int32_t> mask_positions;
|
|
std::vector<int32_t> mask_positions;
|
|
|
- mask_positions.reserve(max_length);
|
|
|
|
|
|
|
+ mask_positions.reserve(params.max_length);
|
|
|
|
|
|
|
|
|
|
+ // Setup sampler chain
|
|
|
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
|
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
|
|
if (params.top_k > 0) {
|
|
if (params.top_k > 0) {
|
|
|
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
|
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
|
@@ -108,210 +247,269 @@ static void diffusion_generate(llama_context * ctx,
|
|
|
|
|
|
|
|
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
|
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
|
|
|
|
|
|
|
- llama_batch batch = llama_batch_init(max_length, 0, 1);
|
|
|
|
|
- batch.n_tokens = max_length;
|
|
|
|
|
|
|
+ llama_batch batch = llama_batch_init(params.max_length, 0, 1);
|
|
|
|
|
+ batch.n_tokens = params.max_length;
|
|
|
|
|
|
|
|
- int64_t total_sampling_time = 0;
|
|
|
|
|
- int64_t total_time = 0;
|
|
|
|
|
|
|
+ // Pre-allocate buffers for CFG if needed
|
|
|
|
|
+ int32_t logits_size = n_vocab * params.max_length;
|
|
|
|
|
+ std::vector<float> cond_logits_buffer;
|
|
|
|
|
+ std::vector<llama_token> un_x_buffer;
|
|
|
|
|
+ if (params.cfg_scale > 0.0f) {
|
|
|
|
|
+ cond_logits_buffer.resize(logits_size);
|
|
|
|
|
+ un_x_buffer.resize(params.max_length);
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- int64_t time_start = ggml_time_us();
|
|
|
|
|
- for (int32_t step = 0; step < params.steps; step++) {
|
|
|
|
|
- if (params.step_callback) {
|
|
|
|
|
- if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // For block-based processing
|
|
|
|
|
+ std::vector<int32_t> num_transfer_tokens;
|
|
|
|
|
+ int32_t num_blocks = 1;
|
|
|
|
|
+ int32_t steps_per_block = params.steps;
|
|
|
|
|
|
|
|
- for (int32_t i = 0; i < max_length; i++) {
|
|
|
|
|
- batch.token[i] = output_tokens[i];
|
|
|
|
|
- batch.pos[i] = i;
|
|
|
|
|
- batch.n_seq_id[i] = 1;
|
|
|
|
|
- batch.seq_id[i][0] = 0;
|
|
|
|
|
- batch.logits[i] = 1;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if (params.schedule == BLOCK_BASED) {
|
|
|
|
|
+ GGML_ASSERT(params.max_length % params.block_length == 0);
|
|
|
|
|
+ num_blocks = params.max_length / params.block_length;
|
|
|
|
|
+ GGML_ASSERT(params.steps % num_blocks == 0);
|
|
|
|
|
+ steps_per_block = params.steps / num_blocks;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- int ret = llama_decode(ctx, batch);
|
|
|
|
|
- if (ret != 0) {
|
|
|
|
|
- LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ std::vector<float> confidence(params.max_length);
|
|
|
|
|
|
|
|
- float * raw_logits = llama_get_logits(ctx);
|
|
|
|
|
- if (!raw_logits) {
|
|
|
|
|
- LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
|
|
|
|
|
- break;
|
|
|
|
|
|
|
+ int64_t total_sampling_time = 0;
|
|
|
|
|
+ int64_t total_time = 0;
|
|
|
|
|
+ int64_t time_start = ggml_time_us();
|
|
|
|
|
+
|
|
|
|
|
+ for (int block_num = 0; block_num < num_blocks; block_num++) {
|
|
|
|
|
+ int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
|
|
|
|
|
+ int32_t block_end = (params.schedule == BLOCK_BASED) ?
|
|
|
|
|
+ std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
|
|
|
|
|
+ params.max_length;
|
|
|
|
|
+
|
|
|
|
|
+ // Count masked tokens in current block for block-based processing
|
|
|
|
|
+ if (params.schedule == BLOCK_BASED) {
|
|
|
|
|
+ int32_t block_mask_count = 0;
|
|
|
|
|
+ for (int i = block_start; i < block_end; i++) {
|
|
|
|
|
+ if (output_tokens[i] == params.mask_token_id) {
|
|
|
|
|
+ block_mask_count++;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
|
|
|
|
- return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
|
|
|
|
|
- };
|
|
|
|
|
-
|
|
|
|
|
- int64_t time_start_sampling = ggml_time_us();
|
|
|
|
|
|
|
+ for (int32_t step = 0; step < steps_per_block; step++) {
|
|
|
|
|
+ int32_t global_step = block_num * steps_per_block + step;
|
|
|
|
|
|
|
|
- mask_positions.clear();
|
|
|
|
|
- for (int32_t i = 0; i < max_length; i++) {
|
|
|
|
|
- if (output_tokens[i] == params.mask_token_id) {
|
|
|
|
|
- mask_positions.push_back(i);
|
|
|
|
|
|
|
+ if (params.step_callback) {
|
|
|
|
|
+ if (!params.step_callback(
|
|
|
|
|
+ global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
- if (mask_positions.empty()) {
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Setup batch
|
|
|
|
|
+ for (int32_t i = 0; i < params.max_length; i++) {
|
|
|
|
|
+ batch.token[i] = output_tokens[i];
|
|
|
|
|
+ batch.pos[i] = i;
|
|
|
|
|
+ batch.n_seq_id[i] = 1;
|
|
|
|
|
+ batch.seq_id[i][0] = 0;
|
|
|
|
|
+ batch.logits[i] = 1;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- float t = timesteps[step];
|
|
|
|
|
- float s = timesteps[step + 1];
|
|
|
|
|
|
|
+ float * logits = nullptr;
|
|
|
|
|
|
|
|
- if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
|
|
|
|
|
- float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;
|
|
|
|
|
|
|
+ if (params.cfg_scale > 0.0f) {
|
|
|
|
|
+ int ret = llama_decode(ctx, batch);
|
|
|
|
|
+ if (ret != 0) {
|
|
|
|
|
+ LOG_ERR("Failed to generate conditional");
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ float * cond_logits_ptr = llama_get_logits(ctx);
|
|
|
|
|
+ std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));
|
|
|
|
|
|
|
|
- for (int32_t pos : mask_positions) {
|
|
|
|
|
- if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
|
|
|
|
|
- const float * pos_logits = get_logits_for_pos(pos);
|
|
|
|
|
- for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
|
|
|
|
- candidates[token_id].id = token_id;
|
|
|
|
|
- candidates[token_id].logit = pos_logits[token_id];
|
|
|
|
|
- candidates[token_id].p = 0.0f;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Unconditional generation (mask input)
|
|
|
|
|
+ std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
|
|
|
|
|
+ for (int32_t i = 0; i < n_input; i++) {
|
|
|
|
|
+ un_x_buffer[i] = params.mask_token_id;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- llama_token_data_array cur_p = {
|
|
|
|
|
- /* .data = */ candidates.data(),
|
|
|
|
|
- /* .size = */ (size_t) n_vocab, // Reset size to full vocab
|
|
|
|
|
- /* .selected = */ -1,
|
|
|
|
|
- /* .sorted = */ false,
|
|
|
|
|
- };
|
|
|
|
|
|
|
+ for (int32_t i = 0; i < params.max_length; i++) {
|
|
|
|
|
+ batch.token[i] = un_x_buffer[i];
|
|
|
|
|
+ }
|
|
|
|
|
+ ret = llama_decode(ctx, batch);
|
|
|
|
|
+ if (ret != 0) {
|
|
|
|
|
+ LOG_ERR("Failed to generate unconditional");
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ float * uncond_logits = llama_get_logits(ctx);
|
|
|
|
|
|
|
|
- llama_sampler_apply(sampler, &cur_p);
|
|
|
|
|
- output_tokens[pos] = cur_p.data[cur_p.selected].id;
|
|
|
|
|
|
|
+ // Apply CFG
|
|
|
|
|
+ for (int32_t i = 0; i < logits_size; i++) {
|
|
|
|
|
+ cond_logits_buffer[i] =
|
|
|
|
|
+ uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
|
|
|
}
|
|
}
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- std::vector<std::pair<float, int32_t>> confidences;
|
|
|
|
|
- std::vector<llama_token> sampled_tokens(mask_positions.size());
|
|
|
|
|
-
|
|
|
|
|
- for (size_t i = 0; i < mask_positions.size(); i++) {
|
|
|
|
|
- int32_t pos = mask_positions[i];
|
|
|
|
|
- const float * pos_logits = get_logits_for_pos(pos);
|
|
|
|
|
-
|
|
|
|
|
- for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
|
|
|
|
- candidates[token_id].logit = pos_logits[token_id];
|
|
|
|
|
- candidates[token_id].p = 0.0f;
|
|
|
|
|
- candidates[token_id].id = token_id;
|
|
|
|
|
|
|
+ logits = cond_logits_buffer.data();
|
|
|
|
|
+ } else {
|
|
|
|
|
+ int ret = llama_decode(ctx, batch);
|
|
|
|
|
+ if (ret != 0) {
|
|
|
|
|
+ LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
|
|
|
|
|
+ break;
|
|
|
}
|
|
}
|
|
|
|
|
+ logits = llama_get_logits(ctx);
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- llama_token_data_array cur_p = {
|
|
|
|
|
- /* .data = */ candidates.data(),
|
|
|
|
|
- /* .size = */ candidates.size(),
|
|
|
|
|
- /* .selected = */ -1,
|
|
|
|
|
- /* .sorted = */ false,
|
|
|
|
|
- };
|
|
|
|
|
|
|
+ if (!logits) {
|
|
|
|
|
+ LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- llama_sampler_apply(sampler, &cur_p);
|
|
|
|
|
|
|
+ auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
|
|
|
|
+ if (params.shift_logits) {
|
|
|
|
|
+ return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
|
|
|
|
|
+ }
|
|
|
|
|
+ return logits + (pos) *n_vocab;
|
|
|
|
|
+ };
|
|
|
|
|
|
|
|
- llama_token sampled_token = cur_p.data[cur_p.selected].id;
|
|
|
|
|
|
|
+ int64_t time_start_sampling = ggml_time_us();
|
|
|
|
|
|
|
|
- float confidence = 0.0f;
|
|
|
|
|
- if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
|
|
|
|
|
- const float epsilon = 1e-10f;
|
|
|
|
|
- for (size_t j = 0; j < cur_p.size; j++) {
|
|
|
|
|
- float prob = cur_p.data[j].p;
|
|
|
|
|
- confidence += prob * logf(prob + epsilon);
|
|
|
|
|
|
|
+ mask_positions.clear();
|
|
|
|
|
+ for (int32_t i = 0; i < params.max_length; i++) {
|
|
|
|
|
+ if (output_tokens[i] == params.mask_token_id) {
|
|
|
|
|
+ // For block-based, only consider current block
|
|
|
|
|
+ if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
|
|
|
|
|
+ mask_positions.push_back(i);
|
|
|
}
|
|
}
|
|
|
- } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
|
|
|
|
|
- confidence = cur_p.data[0].p - cur_p.data[1].p;
|
|
|
|
|
- } else {
|
|
|
|
|
- confidence = cur_p.data[cur_p.selected].p;
|
|
|
|
|
}
|
|
}
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- sampled_tokens[i] = sampled_token;
|
|
|
|
|
- confidences.emplace_back(confidence, i);
|
|
|
|
|
|
|
+ if (mask_positions.empty()) {
|
|
|
|
|
+ break;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- int32_t num_transfer =
|
|
|
|
|
- (step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();
|
|
|
|
|
-
|
|
|
|
|
- if (num_transfer > 0) {
|
|
|
|
|
- if (params.alg_temp == 0.0f) {
|
|
|
|
|
- std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
|
|
|
|
|
- [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
|
|
|
|
|
- if (a.first != b.first) {
|
|
|
|
|
- return a.first > b.first;
|
|
|
|
|
- }
|
|
|
|
|
- return a.second < b.second;
|
|
|
|
|
- });
|
|
|
|
|
- } else {
|
|
|
|
|
- conf_candidates.clear();
|
|
|
|
|
-
|
|
|
|
|
- for (int32_t pos = 0; pos < max_length; pos++) {
|
|
|
|
|
- float conf_logit = -std::numeric_limits<float>::infinity();
|
|
|
|
|
-
|
|
|
|
|
- auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
|
|
|
|
|
- if (it != mask_positions.end()) {
|
|
|
|
|
- size_t mask_idx = std::distance(mask_positions.begin(), it);
|
|
|
|
|
- conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling
|
|
|
|
|
|
|
+ if (params.add_gumbel_noise && params.temperature > 0.0f) {
|
|
|
|
|
+ add_gumbel_noise(logits, n_vocab, params.temperature, rng);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (params.algorithm == ORIGIN) {
|
|
|
|
|
+ int32_t transfer_count = calculate_transfer_count(
|
|
|
|
|
+ step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
|
|
|
|
+ float p_transfer = (float) transfer_count / mask_positions.size();
|
|
|
|
|
+
|
|
|
|
|
+ for (int32_t pos : mask_positions) {
|
|
|
|
|
+ if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
|
|
|
|
|
+ const float * pos_logits = get_logits_for_pos(pos);
|
|
|
|
|
+ for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
|
|
|
|
+ candidates[token_id].id = token_id;
|
|
|
|
|
+ candidates[token_id].logit = pos_logits[token_id];
|
|
|
|
|
+ candidates[token_id].p = 0.0f;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
|
|
|
|
|
|
|
+ llama_token_data_array cur_p = {
|
|
|
|
|
+ candidates.data(),
|
|
|
|
|
+ (size_t) n_vocab,
|
|
|
|
|
+ -1,
|
|
|
|
|
+ false,
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ llama_sampler_apply(sampler, &cur_p);
|
|
|
|
|
+ output_tokens[pos] = cur_p.data[cur_p.selected].id;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ std::vector<std::pair<float, int32_t>> confidences;
|
|
|
|
|
+ std::vector<llama_token> sampled_tokens(mask_positions.size());
|
|
|
|
|
+
|
|
|
|
|
+ for (size_t i = 0; i < mask_positions.size(); i++) {
|
|
|
|
|
+ int32_t pos = mask_positions[i];
|
|
|
|
|
+ const float * pos_logits = get_logits_for_pos(pos);
|
|
|
|
|
+
|
|
|
|
|
+ for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
|
|
|
|
+ candidates[token_id].logit = pos_logits[token_id];
|
|
|
|
|
+ candidates[token_id].p = 0.0f;
|
|
|
|
|
+ candidates[token_id].id = token_id;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- llama_token_data_array conf_array = {
|
|
|
|
|
- /* .data = */ conf_candidates.data(),
|
|
|
|
|
- /* .size = */ conf_candidates.size(),
|
|
|
|
|
- /* .selected = */ -1,
|
|
|
|
|
- /* .sorted = */ false,
|
|
|
|
|
|
|
+ llama_token_data_array cur_p = {
|
|
|
|
|
+ candidates.data(),
|
|
|
|
|
+ candidates.size(),
|
|
|
|
|
+ -1,
|
|
|
|
|
+ false,
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
- for (int32_t i = 0; i < num_transfer; i++) {
|
|
|
|
|
- // Apply distribution sampler to get selected index
|
|
|
|
|
- llama_sampler_apply(dist_sampler, &conf_array);
|
|
|
|
|
- int selected_idx = conf_array.selected;
|
|
|
|
|
- confidences[i].second = conf_candidates[selected_idx].id;
|
|
|
|
|
|
|
+ llama_sampler_apply(sampler, &cur_p);
|
|
|
|
|
+ llama_token sampled_token = cur_p.data[cur_p.selected].id;
|
|
|
|
|
+
|
|
|
|
|
+ float conf = calculate_confidence(cur_p, params.algorithm, rng);
|
|
|
|
|
|
|
|
- conf_candidates[selected_idx].p = 0.0f;
|
|
|
|
|
- conf_array.selected = -1;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ sampled_tokens[i] = sampled_token;
|
|
|
|
|
+ confidences.emplace_back(conf, i);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if (params.alg_temp == 0.0f) {
|
|
|
|
|
- // Deterministic - use confidence order
|
|
|
|
|
- for (int32_t i = 0; i < num_transfer; i++) {
|
|
|
|
|
- int32_t mask_idx = confidences[i].second;
|
|
|
|
|
- int32_t pos = mask_positions[mask_idx];
|
|
|
|
|
- llama_token token = sampled_tokens[mask_idx];
|
|
|
|
|
- output_tokens[pos] = token;
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- for (int32_t i = 0; i < num_transfer; i++) {
|
|
|
|
|
- int32_t pos = confidences[i].second;
|
|
|
|
|
- auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
|
|
|
|
|
- if (it != mask_positions.end()) {
|
|
|
|
|
- int32_t mask_idx = std::distance(mask_positions.begin(), it);
|
|
|
|
|
|
|
+ int32_t transfer_count = calculate_transfer_count(
|
|
|
|
|
+ step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
|
|
|
|
+
|
|
|
|
|
+ if (transfer_count > 0) {
|
|
|
|
|
+ if (params.alg_temp == 0.0f) {
|
|
|
|
|
+ std::partial_sort(confidences.begin(),
|
|
|
|
|
+ confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
|
|
|
|
|
+ confidences.end(),
|
|
|
|
|
+ [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
|
|
|
|
|
+ if (a.first != b.first) {
|
|
|
|
|
+ return a.first > b.first;
|
|
|
|
|
+ }
|
|
|
|
|
+ return a.second < b.second;
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
|
|
|
|
+ int32_t mask_idx = confidences[i].second;
|
|
|
|
|
+ int32_t pos = mask_positions[mask_idx];
|
|
|
output_tokens[pos] = sampled_tokens[mask_idx];
|
|
output_tokens[pos] = sampled_tokens[mask_idx];
|
|
|
}
|
|
}
|
|
|
|
|
+ } else {
|
|
|
|
|
+ conf_candidates.clear();
|
|
|
|
|
+ for (size_t i = 0; i < confidences.size(); i++) {
|
|
|
|
|
+ float conf_logit = confidences[i].first / params.alg_temp;
|
|
|
|
|
+ conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ llama_token_data_array conf_array = {
|
|
|
|
|
+ conf_candidates.data(),
|
|
|
|
|
+ conf_candidates.size(),
|
|
|
|
|
+ -1,
|
|
|
|
|
+ false,
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
|
|
|
|
+ llama_sampler_apply(dist_sampler, &conf_array);
|
|
|
|
|
+ int32_t selected_idx = conf_array.selected;
|
|
|
|
|
+ int32_t mask_idx = selected_idx;
|
|
|
|
|
+ int32_t pos = mask_positions[mask_idx];
|
|
|
|
|
+ output_tokens[pos] = sampled_tokens[mask_idx];
|
|
|
|
|
+
|
|
|
|
|
+ conf_candidates[selected_idx].p = 0.0f;
|
|
|
|
|
+ conf_array.selected = -1;
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ int64_t time_end_sampling = ggml_time_us();
|
|
|
|
|
+ total_sampling_time += time_end_sampling - time_start_sampling;
|
|
|
}
|
|
}
|
|
|
- int64_t time_end_sampling = ggml_time_us();
|
|
|
|
|
- total_sampling_time += time_end_sampling - time_start_sampling;
|
|
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
int64_t time_end = ggml_time_us();
|
|
int64_t time_end = ggml_time_us();
|
|
|
total_time += time_end - time_start;
|
|
total_time += time_end - time_start;
|
|
|
|
|
|
|
|
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
|
|
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
|
|
|
- total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);
|
|
|
|
|
-
|
|
|
|
|
|
|
+ total_time / 1000.0,
|
|
|
|
|
+ total_time / 1000.0 / params.steps,
|
|
|
|
|
+ total_sampling_time / 1000.0 / params.steps);
|
|
|
|
|
|
|
|
llama_batch_free(batch);
|
|
llama_batch_free(batch);
|
|
|
llama_sampler_free(sampler);
|
|
llama_sampler_free(sampler);
|
|
|
llama_sampler_free(dist_sampler);
|
|
llama_sampler_free(dist_sampler);
|
|
|
|
|
|
|
|
- n_generated = max_length;
|
|
|
|
|
|
|
+ n_generated = params.max_length;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
|
|
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
|
|
|
if (!use_chat_template) {
|
|
if (!use_chat_template) {
|
|
|
return prompt;
|
|
return prompt;
|
|
@@ -331,66 +529,6 @@ static std::string format_input_text(const std::string & prompt, bool use_chat_t
|
|
|
return result.prompt;
|
|
return result.prompt;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-struct callback_data {
|
|
|
|
|
- const common_params_diffusion * diff_params;
|
|
|
|
|
- const llama_vocab * vocab;
|
|
|
|
|
- int32_t n_input;
|
|
|
|
|
-};
|
|
|
|
|
-
|
|
|
|
|
-static bool diffusion_step_callback(int32_t step,
|
|
|
|
|
- int32_t total_steps,
|
|
|
|
|
- const llama_token * tokens,
|
|
|
|
|
- int32_t n_tokens,
|
|
|
|
|
- void * user_data) {
|
|
|
|
|
- (void)user_data;
|
|
|
|
|
-
|
|
|
|
|
- callback_data * data = static_cast<callback_data *>(user_data);
|
|
|
|
|
-
|
|
|
|
|
- auto print_progress_bar = [](int32_t step, int32_t total_steps) {
|
|
|
|
|
- int progress_percent = (step * 100) / total_steps;
|
|
|
|
|
- int progress_bars = (step * 50) / total_steps;
|
|
|
|
|
- LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
|
|
|
|
|
- step,
|
|
|
|
|
- total_steps,
|
|
|
|
|
- std::string(progress_bars, '=').c_str(),
|
|
|
|
|
- std::string(50 - progress_bars, ' ').c_str(),
|
|
|
|
|
- progress_percent);
|
|
|
|
|
- };
|
|
|
|
|
-
|
|
|
|
|
- if (data->diff_params->visual_mode) {
|
|
|
|
|
- // Visual mode: clear
|
|
|
|
|
- LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
|
|
|
|
|
-
|
|
|
|
|
- print_progress_bar(step, total_steps);
|
|
|
|
|
-
|
|
|
|
|
- LOG_INF("\n");
|
|
|
|
|
-
|
|
|
|
|
- std::string current_text = " ";
|
|
|
|
|
-
|
|
|
|
|
- for (int32_t i = data->n_input; i < n_tokens; i++) {
|
|
|
|
|
- std::string token_str;
|
|
|
|
|
- if (tokens[i] != llama_vocab_mask(data->vocab)) {
|
|
|
|
|
- char piece[256];
|
|
|
|
|
- int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
|
|
|
|
|
- if (n_chars > 0) {
|
|
|
|
|
- piece[n_chars] = '\0';
|
|
|
|
|
- token_str = piece;
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- token_str = " ";
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- current_text += token_str;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- LOG_INF("%s\n", current_text.c_str());
|
|
|
|
|
- } else {
|
|
|
|
|
- print_progress_bar(step, total_steps);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- return true;
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
int main(int argc, char ** argv) {
|
|
int main(int argc, char ** argv) {
|
|
|
ggml_time_init();
|
|
ggml_time_init();
|
|
|
|
|
|
|
@@ -400,11 +538,6 @@ int main(int argc, char ** argv) {
|
|
|
return 1;
|
|
return 1;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
|
|
|
|
|
- const char * alg_name = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
|
|
|
|
|
- alg_names[params.diffusion.algorithm] :
|
|
|
|
|
- "UNKNOWN";
|
|
|
|
|
-
|
|
|
|
|
common_init();
|
|
common_init();
|
|
|
llama_backend_init();
|
|
llama_backend_init();
|
|
|
|
|
|
|
@@ -421,6 +554,12 @@ int main(int argc, char ** argv) {
|
|
|
return 1;
|
|
return 1;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if (!llama_model_is_diffusion(model)) {
|
|
|
|
|
+ LOG_ERR("error: unsupported model for diffusion");
|
|
|
|
|
+ llama_model_free(model);
|
|
|
|
|
+ return 1;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
llama_context_params ctx_params = llama_context_default_params();
|
|
llama_context_params ctx_params = llama_context_default_params();
|
|
|
ctx_params.n_ctx = params.n_ctx;
|
|
ctx_params.n_ctx = params.n_ctx;
|
|
|
ctx_params.n_batch = params.n_batch;
|
|
ctx_params.n_batch = params.n_batch;
|
|
@@ -442,10 +581,12 @@ int main(int argc, char ** argv) {
|
|
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
|
|
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
|
|
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
|
|
|
|
|
|
|
|
- std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
|
|
|
|
|
|
|
+ std::vector<llama_token> input_tokens = common_tokenize(vocab,
|
|
|
|
|
+ formatted_prompt,
|
|
|
/*add special tokens*/ true,
|
|
/*add special tokens*/ true,
|
|
|
/*parse special*/ true);
|
|
/*parse special*/ true);
|
|
|
- int n_input = input_tokens.size();
|
|
|
|
|
|
|
+
|
|
|
|
|
+ int n_input = input_tokens.size();
|
|
|
|
|
|
|
|
if (n_input >= params.n_ctx) {
|
|
if (n_input >= params.n_ctx) {
|
|
|
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
|
|
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
|
|
@@ -454,44 +595,79 @@ int main(int argc, char ** argv) {
|
|
|
return 1;
|
|
return 1;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- struct diffusion_params ldiff_params = diffusion_default_params();
|
|
|
|
|
- ldiff_params.steps = params.diffusion.steps;
|
|
|
|
|
- ldiff_params.eps = params.diffusion.eps;
|
|
|
|
|
- ldiff_params.temperature = params.sampling.temp;
|
|
|
|
|
- ldiff_params.top_p = params.sampling.top_p;
|
|
|
|
|
- ldiff_params.top_k = params.sampling.top_k;
|
|
|
|
|
- ldiff_params.algorithm = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
|
|
|
|
|
- ldiff_params.alg_temp = params.diffusion.alg_temp;
|
|
|
|
|
- ldiff_params.seed = params.sampling.seed;
|
|
|
|
|
-
|
|
|
|
|
llama_token mask_token_id = llama_vocab_mask(vocab);
|
|
llama_token mask_token_id = llama_vocab_mask(vocab);
|
|
|
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
|
|
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
|
|
|
|
|
|
|
|
- LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
|
|
|
|
|
- LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps);
|
|
|
|
|
- LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps);
|
|
|
|
|
- LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm,
|
|
|
|
|
- alg_name);
|
|
|
|
|
- LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp);
|
|
|
|
|
|
|
+ bool visual_mode = params.diffusion.visual_mode;
|
|
|
|
|
|
|
|
- ldiff_params.mask_token_id = mask_token_id;
|
|
|
|
|
|
|
+ int32_t n_generated = 0;
|
|
|
|
|
+ std::vector<llama_token> output_tokens(params.n_ubatch);
|
|
|
|
|
|
|
|
- callback_data cb_data = { ¶ms.diffusion, vocab, n_input };
|
|
|
|
|
|
|
+ struct diffusion_params diff_params;
|
|
|
|
|
|
|
|
- ldiff_params.step_callback = diffusion_step_callback;
|
|
|
|
|
- ldiff_params.step_callback_user_data = &cb_data;
|
|
|
|
|
|
|
+ char shift_logits_str[8];
|
|
|
|
|
+ if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) {
|
|
|
|
|
+ diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ diff_params.shift_logits = true;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- int32_t n_generated = 0;
|
|
|
|
|
|
|
+ //Use either eps or block length, but not both
|
|
|
|
|
+ GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));
|
|
|
|
|
|
|
|
- std::vector<llama_token> output_tokens(params.n_ubatch);
|
|
|
|
|
- diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch,
|
|
|
|
|
- ldiff_params, n_generated);
|
|
|
|
|
|
|
+ if (params.diffusion.eps) {
|
|
|
|
|
+ diff_params.schedule = TIMESTEP_BASED;
|
|
|
|
|
+ diff_params.eps = params.diffusion.eps;
|
|
|
|
|
+ } else if (params.diffusion.block_length) {
|
|
|
|
|
+ diff_params.schedule = BLOCK_BASED;
|
|
|
|
|
+ diff_params.block_length = params.diffusion.block_length;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ diff_params.mask_token_id = mask_token_id;
|
|
|
|
|
+ diff_params.seed = params.sampling.seed;
|
|
|
|
|
+ diff_params.temperature = params.sampling.temp;
|
|
|
|
|
+ diff_params.steps = params.diffusion.steps;
|
|
|
|
|
+ diff_params.algorithm = static_cast<diffusion_algorithm>(params.diffusion.algorithm);
|
|
|
|
|
+ diff_params.max_length = params.n_ubatch;
|
|
|
|
|
+ diff_params.top_p = params.sampling.top_p;
|
|
|
|
|
+ diff_params.top_k = params.sampling.top_k;
|
|
|
|
|
+ diff_params.visual_mode = params.diffusion.visual_mode;
|
|
|
|
|
+ diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise;
|
|
|
|
|
+
|
|
|
|
|
+ diff_params.step_callback = diffusion_step_callback;
|
|
|
|
|
+ callback_data cb_data = { &diff_params, vocab, n_input };
|
|
|
|
|
+ diff_params.step_callback_user_data = &cb_data;
|
|
|
|
|
+
|
|
|
|
|
+ const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
|
|
|
|
|
+ const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
|
|
|
|
|
+ const char * alg_name =
|
|
|
|
|
+ (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
|
|
|
|
|
+ const char * sched_name =
|
|
|
|
|
+ (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN";
|
|
|
|
|
+
|
|
|
|
|
+ LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
|
|
|
|
|
+ LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps);
|
|
|
|
|
+ LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length);
|
|
|
|
|
+ LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
|
|
|
|
|
+ LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
|
|
|
|
|
+ LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
|
|
|
|
|
+ if (diff_params.schedule == TIMESTEP_BASED) {
|
|
|
|
|
+ LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
|
|
|
|
|
+ LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
|
|
|
|
|
+ }
|
|
|
|
|
+ if (diff_params.schedule == BLOCK_BASED) {
|
|
|
|
|
+ LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
|
|
|
|
|
+ LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);
|
|
|
|
|
|
|
|
if (n_generated > 0) {
|
|
if (n_generated > 0) {
|
|
|
- if (params.diffusion.visual_mode) {
|
|
|
|
|
|
|
+ if (visual_mode) {
|
|
|
//clear screen and move cursor to top-left
|
|
//clear screen and move cursor to top-left
|
|
|
LOG_INF("\033[2J\033[H");
|
|
LOG_INF("\033[2J\033[H");
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
|
|
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
|
|
|
std::string output_data = common_detokenize(vocab, output_tokens, false);
|
|
std::string output_data = common_detokenize(vocab, output_tokens, false);
|
|
|
LOG_INF("\n%s\n", output_data.c_str());
|
|
LOG_INF("\n%s\n", output_data.c_str());
|