@@ -39,12 +39,12 @@ static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 
-static bool file_exists(const std::string &path) {
+static bool file_exists(const std::string & path) {
     std::ifstream f(path.c_str());
     return f.good();
 }
 
-static bool file_is_empty(const std::string &path) {
+static bool file_is_empty(const std::string & path) {
     std::ifstream f;
     f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
     f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
@@ -117,6 +117,14 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) {
     LOG_TEE("%s", text);
 }
 
+static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
+    llama_chat_msg new_msg{role, content};
+    auto formatted = llama_chat_format_single(
+        model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+    chat_msgs.push_back({role, content});
+    return formatted;
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
     g_params = &params;
@@ -190,6 +198,7 @@ int main(int argc, char ** argv) {
     llama_model * model;
     llama_context * ctx;
     llama_context * ctx_guidance = NULL;
+    std::vector<llama_chat_msg> chat_msgs;
     g_model = &model;
     g_ctx = &ctx;
 
@@ -215,6 +224,8 @@ int main(int argc, char ** argv) {
                 __func__, n_ctx_train, n_ctx);
     }
 
+    LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str());
+
     // print system information
     {
         LOG_TEE("\n");
@@ -249,16 +260,21 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
-        LOG("tokenize the prompt\n");
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
-    } else {
-        LOG("use session tokens\n");
-        embd_inp = session_tokens;
-    }
+    {
+        auto prompt = params.conversation
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+            : params.prompt;
+        if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
+            LOG("tokenize the prompt\n");
+            embd_inp = ::llama_tokenize(ctx, prompt, true, true);
+        } else {
+            LOG("use session tokens\n");
+            embd_inp = session_tokens;
+        }
 
-    LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
-    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+        LOG("prompt: \"%s\"\n", log_tostr(prompt));
+        LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+    }
 
     // Should not run without any tokens
     if (embd_inp.empty()) {
@@ -478,6 +494,7 @@ int main(int argc, char ** argv) {
     std::vector<int> input_tokens; g_input_tokens = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
     std::ostringstream output_ss; g_output_ss = &output_ss;
+    std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
 
     // the first thing we will do is to output the prompt, so set color accordingly
     console::set_display(console::prompt);
@@ -793,11 +810,18 @@ int main(int argc, char ** argv) {
                         is_antiprompt = true;
                     }
 
+                    chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
                     is_interacting = true;
                     printf("\n");
                 }
             }
 
+            // if the current token is not EOG, we add it to the current assistant message
+            if (params.conversation) {
+                auto id = llama_sampling_last(ctx_sampling);
+                assistant_ss << llama_token_to_piece(ctx, id, false);
+            }
+
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
@@ -848,8 +872,12 @@ int main(int argc, char ** argv) {
                     string_process_escapes(buffer);
                 }
 
+                std::string user_inp = params.conversation
+                    ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                    : std::move(buffer);
+                // TODO: one inconvenience of the current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                 const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
+                const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);
                 const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
 
                 LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
@@ -864,6 +892,9 @@ int main(int argc, char ** argv) {
                     output_ss << llama_token_to_piece(ctx, token);
                 }
 
+                // reset assistant message
+                assistant_ss.str("");
+
                 n_remain -= line_inp.size();
                 LOG("n_remain: %d\n", n_remain);
             } else {
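
// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): chat_add_and_format() above is
// expected to return only the newly formatted portion of the conversation, so
// the caller tokenizes and feeds just that delta. Below is a minimal,
// self-contained model of the same idea; format_delta() and apply_template are
// hypothetical stand-ins, not the llama.cpp API.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct chat_msg { std::string role, content; };

// render the history with and without the new message and return the suffix
// that the new message added (the real helper delegates the rendering to the
// model's chat template instead of a toy renderer)
static std::string format_delta(
        const std::function<std::string(const std::vector<chat_msg> &)> & apply_template,
        std::vector<chat_msg> & history, const chat_msg & new_msg) {
    const std::string before = apply_template(history);
    history.push_back(new_msg);
    const std::string after  = apply_template(history);
    return after.substr(before.size()); // only the newly appended text
}

int main() {
    // toy template: "<role>: <content>\n" per message
    const auto tmpl = [](const std::vector<chat_msg> & msgs) {
        std::string out;
        for (const auto & m : msgs) { out += m.role + ": " + m.content + "\n"; }
        return out;
    };

    std::vector<chat_msg> history;
    std::cout << format_delta(tmpl, history, {"system", "You are a helpful assistant."});
    std::cout << format_delta(tmpl, history, {"user", "Hello!"});
    // each call prints only its own turn, never re-emitting earlier history
    return 0;
}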