#include "llama.h"
#include <cstdio>
#include <cstdlib>   // exit, free
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n    %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
    printf("\n");
}
int main(int argc, char ** argv) {
    std::string model_path;
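    // ngl = 99 requests offloading (up to) all model layers to the GPU;
    // n_ctx is the size of the context window in tokens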
    int ngl = 99;
    int n_ctx = 2048;

    // parse command line arguments
    for (int i = 1; i < argc; i++) {
        try {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) {
                    model_path = argv[++i];
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-c") == 0) {
                if (i + 1 < argc) {
                    n_ctx = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-ngl") == 0) {
                if (i + 1 < argc) {
                    ngl = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else {
                print_usage(argc, argv);
                return 1;
            }
        } catch (std::exception & e) {
            fprintf(stderr, "error: %s\n", e.what());
            print_usage(argc, argv);
            return 1;
        }
    }
    if (model_path.empty()) {
        print_usage(argc, argv);
        return 1;
    }
    // only print errors
    llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
        if (level >= GGML_LOG_LEVEL_ERROR) {
            fprintf(stderr, "%s", text);
        }
    }, nullptr);

    // load dynamic backends
    ggml_backend_load_all();

    // initialize the model
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
    if (!model) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // initialize the context
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_ctx;
    ctx_params.n_batch = n_ctx;

    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
        return 1;
    }

    // initialize the sampler
    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
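    // the chain below filters with min-p (drop tokens below 5% of the top token's
    // probability, keeping at least 1), applies temperature 0.8, then draws the
    // final token from the remaining distribution using the default seed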
    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // helper function to evaluate a prompt and generate a response
    auto generate = [&](const std::string & prompt) {
        std::string response;
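
        // add special tokens (e.g. BOS) only while the KV cache is still empty,
        // i.e. on the first turn of the conversation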
        const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;

        // tokenize the prompt
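        // with a NULL output buffer llama_tokenize returns the negated number of
        // tokens required, which sizes the vector for the real call below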
        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
        if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
            GGML_ABORT("failed to tokenize the prompt\n");
        }

        // prepare a batch for the prompt
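        // llama_batch_get_one wraps the token array as a single-sequence batch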
        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
        llama_token new_token_id;
        while (true) {
            // check if we have enough space in the context to evaluate this batch
            int n_ctx = llama_n_ctx(ctx);
            int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
            if (n_ctx_used + batch.n_tokens > n_ctx) {
                printf("\033[0m\n");
                fprintf(stderr, "context size exceeded\n");
                exit(0);
            }
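            // llama_decode evaluates the batch and appends it to the KV cache, so
            // after the prompt only each newly sampled token has to be submitted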
            if (llama_decode(ctx, batch)) {
                GGML_ABORT("failed to decode\n");
            }

            // sample the next token
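            // index -1 samples from the logits of the last token in the batch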
            new_token_id = llama_sampler_sample(smpl, ctx, -1);

            // is it an end of generation?
            if (llama_vocab_is_eog(vocab, new_token_id)) {
                break;
            }

            // convert the token to a string, print it and add it to the response
            char buf[256];
            int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
            if (n < 0) {
                GGML_ABORT("failed to convert token to piece\n");
            }
            std::string piece(buf, n);
            printf("%s", piece.c_str());
            fflush(stdout);
            response += piece;

            // prepare the next batch with the sampled token
            batch = llama_batch_get_one(&new_token_id, 1);
        }

        return response;
    };

    std::vector<llama_chat_message> messages;
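    // buffer for the templated conversation; sized to the context and grown on demand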
    std::vector<char> formatted(llama_n_ctx(ctx));
    int prev_len = 0;
    while (true) {
        // get user input
        printf("\033[32m> \033[0m");
        std::string user;
        std::getline(std::cin, user);
        if (user.empty()) {
            break;
        }
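
        // use the chat template bundled with the model; nullptr selects the model's default template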
        const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);

        // add the user input to the message list and format it
        messages.push_back({"user", strdup(user.c_str())});
        int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
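        // the call returns the total length it needs; if the buffer was too small,
        // grow it and apply the template again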
        if (new_len > (int)formatted.size()) {
            formatted.resize(new_len);
            new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        }
        if (new_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }

        // remove previous messages to obtain the prompt to generate the response
        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);

        // generate a response
        printf("\033[33m");
        std::string response = generate(prompt);
        printf("\n\033[0m");

        // add the response to the messages
        messages.push_back({"assistant", strdup(response.c_str())});
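        // record the formatted length up to and including the assistant reply (without
        // a new generation prompt), so the next turn only feeds the newly added part
        // of the conversation to generate()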
        prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
        if (prev_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }
    }

    // free resources
    for (auto & msg : messages) {
        free(const_cast<char *>(msg.content));
    }

    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_model_free(model);

    return 0;
}