|
|
@@ -729,10 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
|
|
|
|
|
|
// Function to tokenize the prompt
|
|
|
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
|
|
|
- std::vector<llama_token> & prompt_tokens) {
|
|
|
- const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
|
|
|
+ std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
|
|
|
+ const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
|
|
|
+
|
|
|
+ const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
|
|
|
prompt_tokens.resize(n_prompt_tokens);
|
|
|
- if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
|
|
|
+ if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
|
|
|
true) < 0) {
|
|
|
printe("failed to tokenize the prompt\n");
|
|
|
return -1;
|
|
|
@@ -778,7 +780,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
|
|
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
|
|
|
|
|
|
std::vector<llama_token> tokens;
|
|
|
- if (tokenize_prompt(vocab, prompt, tokens) < 0) {
|
|
|
+ if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
|
|
|
return 1;
|
|
|
}
|
|
|
|