@@ -8,6 +8,11 @@
+#include <algorithm> // std::min, used by the vocab check below
+#include <cstring>   // std::strcmp, used by the vocab check below
 #include <string>
 #include <vector>
 
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
 struct seq_draft {
     bool active   = false;
     bool drafting = false;
@@ -64,6 +67,33 @@ int main(int argc, char ** argv) {
     params.n_gpu_layers = params.n_gpu_layers_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+        const int vocab_diff  = n_vocab_tgt > n_vocab_dft
+            ? n_vocab_tgt - n_vocab_dft
+            : n_vocab_dft - n_vocab_tgt;
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
+            fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    n_vocab_tgt, n_vocab_dft, vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return 1;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
+                fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
+                        llama_token_to_piece(ctx_tgt, i).c_str(),
+                        llama_token_to_piece(ctx_dft, i).c_str());
+                return 1;
+            }
+        }
+    }
+
     // tokenize the prompt
     std::vector<llama_token> inp;
     inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
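
Note on the hunk above: the check enforces "close enough" rather than exact vocab equality. Sizes may differ by up to SPEC_VOCAB_MAX_SIZE_DIFFERENCE tokens (e.g. padded embeddings), and the per-token text comparison starts at SPEC_VOCAB_CHECK_START_TOKEN_ID so the first few ids, which are typically special/control tokens whose text can legitimately differ, are skipped. Below is a minimal standalone sketch of the same logic factored into a helper; the function name spec_vocab_compatible is hypothetical and not part of the patch, but the llama.cpp calls (llama_n_vocab, llama_token_get_text) are the ones the patch itself uses.

    #include <algorithm> // std::min
    #include <cstring>   // std::strcmp

    #include "llama.h"

    #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
    #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

    // Hypothetical helper (not part of the patch): returns true if the draft
    // vocab is close enough to the target vocab for speculation.
    static bool spec_vocab_compatible(const llama_model * model_tgt, const llama_model * model_dft) {
        const int n_vocab_tgt = llama_n_vocab(model_tgt);
        const int n_vocab_dft = llama_n_vocab(model_dft);

        // the vocab sizes only need to match within a small tolerance
        const int vocab_diff = n_vocab_tgt > n_vocab_dft ? n_vocab_tgt - n_vocab_dft
                                                         : n_vocab_dft - n_vocab_tgt;
        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
            return false;
        }

        // skip the first few ids: special/control tokens whose text can
        // differ between otherwise compatible models
        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
            if (std::strcmp(llama_token_get_text(model_tgt, i),
                            llama_token_get_text(model_dft, i)) != 0) {
                return false;
            }
        }

        return true;
    }
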
@@ -227,6 +257,7 @@ int main(int argc, char ** argv) {
         llama_batch_add  (batch_dft, id, n_past_dft, { 0 }, true);
 
         llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
+        // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
         llama_decode         (ctx_dft, batch_dft);
 
         ++n_past_dft;
@@ -370,7 +401,7 @@ int main(int argc, char ** argv) {
             llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
         }
 
-        //LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
+        // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
         llama_decode(ctx_tgt, batch_tgt);
         ++n_past_tgt;
     }
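
The LOG change in the final hunk is a latent format-string fix: LOG() forwards its arguments to printf-style formatting, and LOG_BATCH_TOSTR_PRETTY(...) produces an object with a .c_str() method (a std::string), so passing it directly through "%s" would be undefined behavior if the line were ever uncommented. The same pitfall in miniature, using plain fprintf rather than anything llama.cpp-specific:

    #include <cstdio>
    #include <string>

    int main() {
        const std::string batch_desc = "seq 0: [ 42, 1337 ]";

        // fprintf(stderr, "%s\n", batch_desc);     // wrong: %s expects const char *
        fprintf(stderr, "%s\n", batch_desc.c_str()); // correct: matches the patch

        return 0;
    }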