@@ -117,7 +117,8 @@ int main(int argc, char ** argv) {
     llama_token id_last = inp.back();
 
     // all tokens currently in the target context
-    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+    llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
 
     int n_past = inp.size() - 1;
 
@@ -181,54 +182,44 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
         n_past    += ids.size() - 1;
-        n_drafted += batch_tgt.n_tokens - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
         n_accept  += ids.size() - 1;
+        n_predict += ids.size();
 
         // process the accepted tokens and update contexts
         //
         // this is the standard token post-processing that we normally do
        // in this case, we do it for a group of accepted tokens at once
         //
-        {
-            llama_token id;
-            std::string token_str;
-
-            for (size_t i = 0; i < ids.size(); ++i) {
-                id = ids[i];
-
-                ++n_predict;
-
-                if (llama_token_is_eog(model_tgt, id)) {
-                    has_eos = true;
-                    break;
-                }
-
-                token_str = common_token_to_piece(ctx_tgt, id);
+        for (size_t i = 0; i < ids.size(); ++i) {
+            prompt_tgt.push_back(id_last);
 
-                if (params.use_color && i + 1 < ids.size()) {
-                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
-                } else {
-                    LOG("%s", token_str.c_str());
-                }
-            }
+            id_last = ids[i];
 
-            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            if (llama_token_is_eog(model_tgt, id_last)) {
+                has_eos = true;
                 break;
             }
 
-            LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
 
-            {
-                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
-                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
             }
+        }
 
-            prompt_tgt.push_back(id_last);
-            prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
 
-            // remember the last accepted token for the next iteration
-            id_last = id;
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
         }
     }
 