|
|
@@ -5,6 +5,7 @@
|
|
|
#include <cstdio>
|
|
|
#include <string>
|
|
|
#include <vector>
|
|
|
+#include <set>
|
|
|
|
|
|
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
|
|
|
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
|
|
@@ -18,6 +19,7 @@ struct seq_draft {
|
|
|
std::vector<int> i_batch_tgt;
|
|
|
|
|
|
std::vector<llama_token> tokens;
|
|
|
+ std::vector<std::vector<llama_token_data>> dists;
|
|
|
|
|
|
struct llama_sampling_context * ctx_sampling;
|
|
|
};
|
|
|
@@ -37,12 +39,15 @@ int main(int argc, char ** argv) {
|
|
|
// max number of parallel drafting sequences (i.e. tree branches)
|
|
|
const int n_seq_dft = params.n_parallel;
|
|
|
|
|
|
- // probability threshold for accepting a token from the draft model
|
|
|
- const float p_accept = params.p_accept;
|
|
|
-
|
|
|
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
|
|
|
const float p_split = params.p_split;
|
|
|
|
|
|
+ if (params.seed == LLAMA_DEFAULT_SEED) {
|
|
|
+ params.seed = time(NULL);
|
|
|
+ }
|
|
|
+ std::default_random_engine rng(params.seed);
|
|
|
+ std::uniform_real_distribution<> u_dist;
|
|
|
+
|
|
|
#ifndef LOG_DISABLE_LOGS
|
|
|
log_set_target(log_filename_generator("speculative", "log"));
|
|
|
LOG_TEE("Log start\n");
|
|
|
@@ -166,7 +171,9 @@ int main(int argc, char ** argv) {
|
|
|
std::vector<seq_draft> drafts(n_seq_dft);
|
|
|
|
|
|
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
|
|
- params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
|
|
|
+ if (params.sparams.temp == 0) {
|
|
|
+ params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
|
|
|
+ }
|
|
|
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
|
|
@@ -182,12 +189,15 @@ int main(int argc, char ** argv) {
|
|
|
drafts[0].i_batch_tgt[0] = 0;
|
|
|
|
|
|
while (true) {
|
|
|
+ std::set<int> active_seqs = {};
|
|
|
+
|
|
|
// print current draft sequences
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
if (!drafts[s].active) {
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
+ active_seqs.insert(s);
|
|
|
const auto & tokens = drafts[s].tokens;
|
|
|
|
|
|
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
|
|
|
@@ -196,48 +206,156 @@ int main(int argc, char ** argv) {
|
|
|
int i_dft = 0;
|
|
|
int s_keep = 0;
|
|
|
|
|
|
+ llama_token token_id;
|
|
|
+ std::string token_str;
|
|
|
+
|
|
|
+ // loop until we fail to accept a drafted token or we run out of drafted tokens
|
|
|
while (true) {
|
|
|
- LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
|
|
|
|
- // sample from the target model
|
|
|
- llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
|
+ // check if the target token matches any of the drafts
|
|
|
+ // for stochastic sampling, attempt to match the token with the drafted tokens
|
|
|
+ {
|
|
|
+ bool accept = false;
|
|
|
+ if (params.sparams.temp > 0) {
|
|
|
+ // stochastic verification
|
|
|
+
|
|
|
+ llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
|
+ float p_tgt = 0, p_dft = 0;
|
|
|
+
|
|
|
+ // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
|
|
+
|
|
|
+ while (active_seqs.size() > 0) {
|
|
|
+ // randomly select a sequence to verify from active sequences
|
|
|
+ std::uniform_int_distribution<u_int> u_int_dist(0, active_seqs.size() - 1);
|
|
|
+ int s = *std::next(active_seqs.begin(), u_int_dist(rng));
|
|
|
+ if (i_dft >= (int) drafts[s].tokens.size()) {
|
|
|
+ drafts[s].active = false;
|
|
|
+ active_seqs.erase(s);
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (accept) {
|
|
|
+ // if we already accepted a token, we can skip the rest
|
|
|
+ if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
|
|
|
+ drafts[s].active = false;
|
|
|
+ active_seqs.erase(s);
|
|
|
+ }
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
|
|
|
+ float r = u_dist(rng);
|
|
|
+ llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
|
|
|
+ // acquire the token probabilities assigned by the draft and target models
|
|
|
+ for (size_t i = 0; i < dist_tgt.size; i++) {
|
|
|
+ if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
|
|
|
+ p_tgt = dist_tgt.data[i].p;
|
|
|
+ }
|
|
|
+ if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
|
|
|
+ p_dft = dist_dft.data[i].p;
|
|
|
+ }
|
|
|
+ if (p_tgt && p_dft) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
|
|
|
+ if (r <= p_tgt / p_dft) {
|
|
|
+ s_keep = s;
|
|
|
+ accept = true;
|
|
|
+ token_id = drafts[s].tokens[i_dft];
|
|
|
+ token_str = llama_token_to_piece(ctx_tgt, token_id);
|
|
|
+ llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
|
|
+
|
|
|
+ LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
|
|
|
+ break;
|
|
|
+ } else {
|
|
|
+ LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
|
|
|
+ drafts[s].active = false;
|
|
|
+
|
|
|
+ // calculate residual probability
|
|
|
+ GGML_ASSERT(dist_tgt.sorted);
|
|
|
+ GGML_ASSERT(dist_dft.sorted);
|
|
|
+ float sum_probs = 0.0f;
|
|
|
+
|
|
|
+ // sort dist by id
|
|
|
+ std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
|
|
+ return a.id < b.id;
|
|
|
+ });
|
|
|
+ std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
|
|
|
+ return a.id < b.id;
|
|
|
+ });
|
|
|
+
|
|
|
+ for (size_t i = 0; i < dist_tgt.size; i++) {
|
|
|
+ dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
|
|
|
+ sum_probs += dist_tgt.data[i].p;
|
|
|
+ }
|
|
|
+ for (size_t i = 0; i < dist_tgt.size; i++) {
|
|
|
+ dist_tgt.data[i].p /= sum_probs;
|
|
|
+ }
|
|
|
|
|
|
- llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
|
|
|
+ // sort dist_tgt by p desc
|
|
|
+ std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
|
|
+ return a.p > b.p;
|
|
|
+ });
|
|
|
+ }
|
|
|
|
|
|
- //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
|
|
+ active_seqs.erase(s);
|
|
|
+ for(int i = 0; i < n_seq_dft; i++) {
|
|
|
+ if (i == s) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
|
|
+ // synchronize active status for sequences with the same drafted token
|
|
|
+ drafts[i].active = drafts[i].active && accept;
|
|
|
+ if (!drafts[i].active) {
|
|
|
+ active_seqs.erase(s);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- const std::string token_str = llama_token_to_piece(ctx_tgt, id);
|
|
|
+ if (!accept) {
|
|
|
+ // all drafted tokens were rejected
|
|
|
+ // sample from the target model
|
|
|
+ LOG("all drafted tokens were rejected, sampling from residual distribution\n");
|
|
|
+ token_id = llama_sample_token(ctx_tgt, &dist_tgt);
|
|
|
+ llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
|
|
+ token_str = llama_token_to_piece(ctx_tgt, token_id);
|
|
|
+ }
|
|
|
|
|
|
- if (!params.use_color) {
|
|
|
- printf("%s", token_str.c_str());
|
|
|
- }
|
|
|
+ } else {
|
|
|
+ // greedy verification
|
|
|
|
|
|
- if (id == llama_token_eos(model_tgt)) {
|
|
|
- has_eos = true;
|
|
|
- }
|
|
|
+ // sample from the target model
|
|
|
+ LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
|
+ token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
|
|
|
|
- ++n_predict;
|
|
|
+ llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
|
|
|
|
|
- // check if the target token matches any of the drafts
|
|
|
- {
|
|
|
- bool matches = false;
|
|
|
+ //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
|
|
|
|
|
- for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
- if (!drafts[s].active) {
|
|
|
- continue;
|
|
|
- }
|
|
|
+ token_str = llama_token_to_piece(ctx_tgt, token_id);
|
|
|
|
|
|
- if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
|
|
|
- LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
|
|
|
+ for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
+ if (!drafts[s].active) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
|
|
|
- s_keep = s;
|
|
|
- matches = true;
|
|
|
- } else {
|
|
|
- drafts[s].active = false;
|
|
|
+ if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
|
|
|
+ LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
|
|
|
+
|
|
|
+ s_keep = s;
|
|
|
+ accept = true;
|
|
|
+ } else {
|
|
|
+ drafts[s].active = false;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if (matches) {
|
|
|
+ if (token_id == llama_token_eos(model_tgt)) {
|
|
|
+ has_eos = true;
|
|
|
+ }
|
|
|
+ ++n_predict;
|
|
|
+
|
|
|
+ if (accept) {
|
|
|
++n_accept;
|
|
|
++n_past_tgt;
|
|
|
++n_past_dft;
|
|
|
@@ -245,17 +363,21 @@ int main(int argc, char ** argv) {
|
|
|
if (params.use_color) {
|
|
|
// Color token according to its origin sequence
|
|
|
printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
|
|
|
- fflush(stdout);
|
|
|
+ } else {
|
|
|
+ printf("%s", token_str.c_str());
|
|
|
}
|
|
|
+ fflush(stdout);
|
|
|
continue;
|
|
|
+ } else {
|
|
|
+ printf("%s", token_str.c_str());
|
|
|
+ fflush(stdout);
|
|
|
+ break;
|
|
|
}
|
|
|
}
|
|
|
- if (params.use_color) {
|
|
|
- printf("%s", token_str.c_str());
|
|
|
- }
|
|
|
- fflush(stdout);
|
|
|
+ }
|
|
|
|
|
|
- LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
|
|
|
+ {
|
|
|
+ LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
|
|
|
|
|
|
// TODO: simplify
|
|
|
{
|
|
|
@@ -275,21 +397,21 @@ int main(int argc, char ** argv) {
|
|
|
drafts[s].active = false;
|
|
|
drafts[s].tokens.clear();
|
|
|
drafts[s].i_batch_tgt.clear();
|
|
|
+ drafts[s].dists.clear();
|
|
|
}
|
|
|
// note: will be erased after the speculation phase
|
|
|
- drafts[0].tokens.push_back(id);
|
|
|
+ drafts[0].tokens.push_back(token_id);
|
|
|
+ drafts[0].dists.push_back(std::vector<llama_token_data>());
|
|
|
drafts[0].i_batch_tgt.push_back(0);
|
|
|
|
|
|
llama_batch_clear(batch_dft);
|
|
|
- llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
|
|
|
+ llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
|
|
|
|
|
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
|
|
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
|
|
- llama_decode (ctx_dft, batch_dft);
|
|
|
+ llama_decode(ctx_dft, batch_dft);
|
|
|
|
|
|
++n_past_dft;
|
|
|
-
|
|
|
- break;
|
|
|
}
|
|
|
|
|
|
if (n_predict > params.n_predict || has_eos) {
|
|
|
@@ -334,12 +456,6 @@ int main(int argc, char ** argv) {
|
|
|
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
|
|
|
}
|
|
|
|
|
|
- if (cur_p[0].p < p_accept) {
|
|
|
- LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
|
|
|
- drafts[s].drafting = false;
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
std::vector<int> sa(1, s);
|
|
|
|
|
|
// attempt to split the branch if the probability is high enough
|
|
|
@@ -367,6 +483,7 @@ int main(int argc, char ** argv) {
|
|
|
drafts[n_seq_cur].skip = true;
|
|
|
|
|
|
drafts[n_seq_cur].tokens = drafts[s].tokens;
|
|
|
+ drafts[n_seq_cur].dists = drafts[s].dists;
|
|
|
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
|
|
|
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
|
|
|
|
|
|
@@ -389,6 +506,8 @@ int main(int argc, char ** argv) {
|
|
|
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
|
|
|
|
|
|
drafts[s].tokens.push_back(id);
|
|
|
+ // save cur_p.data into drafts[s].dists
|
|
|
+ drafts[s].dists.push_back(cur_p);
|
|
|
|
|
|
// add unique drafted tokens to the target batch
|
|
|
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
|
|
@@ -440,6 +559,7 @@ int main(int argc, char ** argv) {
|
|
|
}
|
|
|
|
|
|
drafts[s].tokens.erase(drafts[s].tokens.begin());
|
|
|
+ drafts[s].dists.erase(drafts[s].dists.begin());
|
|
|
}
|
|
|
}
|
|
|
|