run.cpp

#if defined(_WIN32)
# include <windows.h>
# include <io.h>
#else
# include <sys/file.h>
# include <sys/ioctl.h>
# include <unistd.h>
#endif

#if defined(LLAMA_USE_CURL)
# include <curl/curl.h>
#endif

#include <signal.h>

#include <climits>
#include <cstdarg>
#include <cstdio>
#include <cstring>
#include <filesystem>
#include <iostream>
#include <list>
#include <sstream>
#include <string>
#include <vector>

#include "common.h"
#include "json.hpp"
#include "linenoise.cpp/linenoise.h"
#include "llama-cpp.h"
#include "chat-template.hpp"

#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
[[noreturn]] static void sigint_handler(int) {
    printf("\n\033[0m");
    exit(0);  // not ideal, but it's the only way to guarantee exit in all cases
}
#endif
GGML_ATTRIBUTE_FORMAT(1, 2)
static std::string fmt(const char * fmt, ...) {
    va_list ap;
    va_list ap2;
    va_start(ap, fmt);
    va_copy(ap2, ap);
    const int size = vsnprintf(NULL, 0, fmt, ap);
    GGML_ASSERT(size >= 0 && size < INT_MAX);  // NOLINT
    std::string buf;
    buf.resize(size);
    const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
    GGML_ASSERT(size2 == size);
    va_end(ap2);
    va_end(ap);

    return buf;
}

GGML_ATTRIBUTE_FORMAT(1, 2)
static int printe(const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
    const int ret = vfprintf(stderr, fmt, args);
    va_end(args);

    return ret;
}
// Command-line options and argument parsing for llama-run
class Opt {
  public:
    int init(int argc, const char ** argv) {
        ctx_params           = llama_context_default_params();
        model_params         = llama_model_default_params();
        context_size_default = ctx_params.n_batch;
        ngl_default          = model_params.n_gpu_layers;
        common_params_sampling sampling;
        temperature_default = sampling.temp;

        if (argc < 2) {
            printe("Error: No arguments provided.\n");
            print_help();
            return 1;
        }

        // Parse arguments
        if (parse(argc, argv)) {
            printe("Error: Failed to parse arguments.\n");
            print_help();
            return 1;
        }

        // If help is requested, show help and exit
        if (help) {
            print_help();
            return 2;
        }

        ctx_params.n_batch        = context_size >= 0 ? context_size : context_size_default;
        ctx_params.n_ctx          = ctx_params.n_batch;
        model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default;
        temperature               = temperature >= 0 ? temperature : temperature_default;

        return 0;  // Success
    }

    llama_context_params ctx_params;
    llama_model_params   model_params;
    std::string model_;
    std::string user;
    bool  use_jinja    = false;
    int   context_size = -1, ngl = -1;
    float temperature  = -1;
    bool  verbose      = false;

  private:
    int   context_size_default = -1, ngl_default = -1;
    float temperature_default  = -1;
    bool  help                 = false;

    bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) {
        return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0;
    }

    int handle_option_with_value(int argc, const char ** argv, int & i, int & option_value) {
        if (i + 1 >= argc) {
            return 1;
        }

        option_value = std::atoi(argv[++i]);
        return 0;
    }

    int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) {
        if (i + 1 >= argc) {
            return 1;
        }

        option_value = std::atof(argv[++i]);
        return 0;
    }

    int parse(int argc, const char ** argv) {
        bool options_parsing = true;
        for (int i = 1, positional_args_i = 0; i < argc; ++i) {
            if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                if (handle_option_with_value(argc, argv, i, context_size) == 1) {
                    return 1;
                }
            } else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
                if (handle_option_with_value(argc, argv, i, ngl) == 1) {
                    return 1;
                }
            } else if (options_parsing && strcmp(argv[i], "--temp") == 0) {
                if (handle_option_with_value(argc, argv, i, temperature) == 1) {
                    return 1;
                }
            } else if (options_parsing &&
                       (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
                verbose = true;
            } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
                use_jinja = true;
            } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
                help = true;
                return 0;
            } else if (options_parsing && strcmp(argv[i], "--") == 0) {
                options_parsing = false;
            } else if (positional_args_i == 0) {
                if (!argv[i][0] || argv[i][0] == '-') {
                    return 1;
                }

                ++positional_args_i;
                model_ = argv[i];
            } else if (positional_args_i == 1) {
                ++positional_args_i;
                user = argv[i];
            } else {
                user += " " + std::string(argv[i]);
            }
        }

        return 0;
    }

    void print_help() const {
        printf(
            "Description:\n"
            "  Runs a llm\n"
            "\n"
            "Usage:\n"
            "  llama-run [options] model [prompt]\n"
            "\n"
            "Options:\n"
            "  -c, --context-size <value>\n"
            "      Context size (default: %d)\n"
            "  -n, --ngl <value>\n"
            "      Number of GPU layers (default: %d)\n"
            "  --temp <value>\n"
            "      Temperature (default: %.1f)\n"
            "  -v, --verbose, --log-verbose\n"
            "      Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
            "  -h, --help\n"
            "      Show help message\n"
            "\n"
            "Commands:\n"
            "  model\n"
            "      Model is a string with an optional prefix of \n"
            "      huggingface:// (hf://), ollama://, https:// or file://.\n"
            "      If no protocol is specified and a file exists in the specified\n"
            "      path, file:// is assumed, otherwise if a file does not exist in\n"
            "      the specified path, ollama:// is assumed. Models that are being\n"
            "      pulled are downloaded with .partial extension while being\n"
            "      downloaded and then renamed as the file without the .partial\n"
            "      extension when complete.\n"
            "\n"
            "Examples:\n"
            "  llama-run llama3\n"
            "  llama-run ollama://granite-code\n"
            "  llama-run ollama://smollm:135m\n"
            "  llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n"
            "  llama-run "
            "huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n"
            "  llama-run https://example.com/some-file1.gguf\n"
            "  llama-run some-file2.gguf\n"
            "  llama-run file://some-file3.gguf\n"
            "  llama-run --ngl 999 some-file4.gguf\n"
            "  llama-run --ngl 999 some-file5.gguf Hello World\n",
            context_size_default, ngl_default, temperature_default);
    }
};
struct progress_data {
    size_t file_size = 0;
    std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now();
    bool printed = false;
};

static int get_terminal_width() {
#if defined(_WIN32)
    CONSOLE_SCREEN_BUFFER_INFO csbi;
    GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi);
    return csbi.srWindow.Right - csbi.srWindow.Left + 1;
#else
    struct winsize w;
    ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
    return w.ws_col;
#endif
}
#ifdef LLAMA_USE_CURL
// RAII wrapper around a FILE * that can take an exclusive lock on the open file
// (used to guard concurrent writes to .partial download files)
class File {
  public:
    FILE * file = nullptr;

    FILE * open(const std::string & filename, const char * mode) {
        file = fopen(filename.c_str(), mode);

        return file;
    }

    int lock() {
        if (file) {
# ifdef _WIN32
            fd    = _fileno(file);
            hFile = (HANDLE) _get_osfhandle(fd);
            if (hFile == INVALID_HANDLE_VALUE) {
                fd = -1;
                return 1;
            }

            OVERLAPPED overlapped = {};
            if (!LockFileEx(hFile, LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY, 0, MAXDWORD, MAXDWORD,
                            &overlapped)) {
                fd = -1;
                return 1;
            }
# else
            fd = fileno(file);
            if (flock(fd, LOCK_EX | LOCK_NB) != 0) {
                fd = -1;
                return 1;
            }
# endif
        }

        return 0;
    }

    ~File() {
        if (fd >= 0) {
# ifdef _WIN32
            if (hFile != INVALID_HANDLE_VALUE) {
                OVERLAPPED overlapped = {};
                UnlockFileEx(hFile, 0, MAXDWORD, MAXDWORD, &overlapped);
            }
# else
            flock(fd, LOCK_UN);
# endif
        }

        if (file) {
            fclose(file);
        }
    }

  private:
    int fd = -1;
# ifdef _WIN32
    HANDLE hFile = nullptr;
# endif
};
// Minimal libcurl wrapper: downloads a URL to a file (resuming from a .partial file if present)
// or captures the response body into a string, optionally drawing a progress bar on stderr.
class HttpClient {
  public:
    int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
             const bool progress, std::string * response_str = nullptr) {
        std::string output_file_partial;
        curl = curl_easy_init();
        if (!curl) {
            return 1;
        }

        progress_data data;
        File out;
        if (!output_file.empty()) {
            output_file_partial = output_file + ".partial";
            if (!out.open(output_file_partial, "ab")) {
                printe("Failed to open file\n");
                return 1;
            }

            if (out.lock()) {
                printe("Failed to exclusively lock file\n");
                return 1;
            }
        }

        set_write_options(response_str, out);
        data.file_size = set_resume_point(output_file_partial);
        set_progress_options(progress, data);
        set_headers(headers);
        perform(url);
        if (!output_file.empty()) {
            std::filesystem::rename(output_file_partial, output_file);
        }

        return 0;
    }

    ~HttpClient() {
        if (chunk) {
            curl_slist_free_all(chunk);
        }

        if (curl) {
            curl_easy_cleanup(curl);
        }
    }

  private:
    CURL * curl = nullptr;
    struct curl_slist * chunk = nullptr;

    void set_write_options(std::string * response_str, const File & out) {
        if (response_str) {
            curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data);
            curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str);
        } else {
            curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data);
            curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.file);
        }
    }

    size_t set_resume_point(const std::string & output_file) {
        size_t file_size = 0;
        if (std::filesystem::exists(output_file)) {
            file_size = std::filesystem::file_size(output_file);
            curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast<curl_off_t>(file_size));
        }

        return file_size;
    }

    void set_progress_options(bool progress, progress_data & data) {
        if (progress) {
            curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
            curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data);
            curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, update_progress);
        }
    }

    void set_headers(const std::vector<std::string> & headers) {
        if (!headers.empty()) {
            if (chunk) {
                curl_slist_free_all(chunk);
                chunk = nullptr;
            }

            for (const auto & header : headers) {
                chunk = curl_slist_append(chunk, header.c_str());
            }

            curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk);
        }
    }

    void perform(const std::string & url) {
        CURLcode res;
        curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
        curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
        curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
        curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
        res = curl_easy_perform(curl);
        if (res != CURLE_OK) {
            printe("curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
        }
    }

    static std::string human_readable_time(double seconds) {
        int hrs  = static_cast<int>(seconds) / 3600;
        int mins = (static_cast<int>(seconds) % 3600) / 60;
        int secs = static_cast<int>(seconds) % 60;

        if (hrs > 0) {
            return fmt("%dh %02dm %02ds", hrs, mins, secs);
        } else if (mins > 0) {
            return fmt("%dm %02ds", mins, secs);
        } else {
            return fmt("%ds", secs);
        }
    }

    static std::string human_readable_size(curl_off_t size) {
        static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" };
        char   length   = sizeof(suffix) / sizeof(suffix[0]);
        int    i        = 0;
        double dbl_size = size;
        if (size > 1024) {
            for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) {
                dbl_size = size / 1024.0;
            }
        }

        return fmt("%.2f %s", dbl_size, suffix[i]);
    }

    static int update_progress(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t,
                               curl_off_t) {
        progress_data * data = static_cast<progress_data *>(ptr);
        if (total_to_download <= 0) {
            return 0;
        }

        total_to_download += data->file_size;
        const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size;
        const curl_off_t percentage = calculate_percentage(now_downloaded_plus_file_size, total_to_download);
        std::string progress_prefix = generate_progress_prefix(percentage);

        const double speed = calculate_speed(now_downloaded, data->start_time);
        const double tim   = (total_to_download - now_downloaded) / speed;
        std::string progress_suffix =
            generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, tim);

        int progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix);
        std::string progress_bar;
        generate_progress_bar(progress_bar_width, percentage, progress_bar);

        print_progress(progress_prefix, progress_bar, progress_suffix);
        data->printed = true;

        return 0;
    }

    static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) {
        return (now_downloaded_plus_file_size * 100) / total_to_download;
    }

    static std::string generate_progress_prefix(curl_off_t percentage) {
        return fmt("%3ld%% |", static_cast<long int>(percentage));
    }

    static double calculate_speed(curl_off_t now_downloaded,
                                  const std::chrono::steady_clock::time_point & start_time) {
        const auto now = std::chrono::steady_clock::now();
        const std::chrono::duration<double> elapsed_seconds = now - start_time;
        return now_downloaded / elapsed_seconds.count();
    }

    static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download,
                                                double speed, double estimated_time) {
        const int width = 10;
        return fmt("%*s/%*s%*s/s%*s", width, human_readable_size(now_downloaded_plus_file_size).c_str(), width,
                   human_readable_size(total_to_download).c_str(), width, human_readable_size(speed).c_str(), width,
                   human_readable_time(estimated_time).c_str());
    }

    static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) {
        int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 3;
        if (progress_bar_width < 1) {
            progress_bar_width = 1;
        }

        return progress_bar_width;
    }

    static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage,
                                             std::string & progress_bar) {
        const curl_off_t pos = (percentage * progress_bar_width) / 100;
        for (int i = 0; i < progress_bar_width; ++i) {
            progress_bar.append((i < pos) ? "█" : " ");
        }

        return progress_bar;
    }

    static void print_progress(const std::string & progress_prefix, const std::string & progress_bar,
                               const std::string & progress_suffix) {
        printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
               progress_suffix.c_str());
    }

    // Function to write data to a file
    static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
        FILE * out = static_cast<FILE *>(stream);
        return fwrite(ptr, size, nmemb, out);
    }

    // Function to capture data into a string
    static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) {
        std::string * str = static_cast<std::string *>(stream);
        str->append(static_cast<char *>(ptr), size * nmemb);
        return size * nmemb;
    }
};
#endif
// Owns the loaded model, context and sampler, and resolves model references
// (file://, hf://, ollama://, https:// or bare names), downloading them if needed.
class LlamaData {
  public:
    llama_model_ptr                 model;
    llama_sampler_ptr               sampler;
    llama_context_ptr               context;
    std::vector<llama_chat_message> messages;
    std::list<std::string>          msg_strs;
    std::vector<char>               fmtted;

    int init(Opt & opt) {
        model = initialize_model(opt);
        if (!model) {
            return 1;
        }

        context = initialize_context(model, opt);
        if (!context) {
            return 1;
        }

        sampler = initialize_sampler(opt);
        return 0;
    }

  private:
#ifdef LLAMA_USE_CURL
    int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
                 const bool progress, std::string * response_str = nullptr) {
        HttpClient http;
        if (http.init(url, headers, output_file, progress, response_str)) {
            return 1;
        }

        return 0;
    }
#else
    int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
                 std::string * = nullptr) {
        printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
        return 1;
    }
#endif

    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
        // Find the second occurrence of '/' after protocol string
        size_t pos = model.find('/');
        pos = model.find('/', pos + 1);
        if (pos == std::string::npos) {
            return 1;
        }

        const std::string hfr = model.substr(0, pos);
        const std::string hff = model.substr(pos + 1);
        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
        return download(url, headers, bn, true);
    }

    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
        if (model.find('/') == std::string::npos) {
            model = "library/" + model;
        }

        std::string model_tag = "latest";
        size_t colon_pos = model.find(':');
        if (colon_pos != std::string::npos) {
            model_tag = model.substr(colon_pos + 1);
            model     = model.substr(0, colon_pos);
        }

        std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
        std::string manifest_str;
        const int ret = download(manifest_url, headers, "", false, &manifest_str);
        if (ret) {
            return ret;
        }

        nlohmann::json manifest = nlohmann::json::parse(manifest_str);
        std::string layer;
        for (const auto & l : manifest["layers"]) {
            if (l["mediaType"] == "application/vnd.ollama.image.model") {
                layer = l["digest"];
                break;
            }
        }

        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
        return download(blob_url, headers, bn, true);
    }

    std::string basename(const std::string & path) {
        const size_t pos = path.find_last_of("/\\");
        if (pos == std::string::npos) {
            return path;
        }

        return path.substr(pos + 1);
    }

    int remove_proto(std::string & model_) {
        const std::string::size_type pos = model_.find("://");
        if (pos == std::string::npos) {
            return 1;
        }

        model_ = model_.substr(pos + 3);  // Skip past "://"
        return 0;
    }

    int resolve_model(std::string & model_) {
        int ret = 0;
        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
            remove_proto(model_);

            return ret;
        }

        const std::string bn = basename(model_);
        const std::vector<std::string> headers = { "--header",
                                                   "Accept: application/vnd.docker.distribution.manifest.v2+json" };
        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
            remove_proto(model_);
            ret = huggingface_dl(model_, headers, bn);
        } else if (string_starts_with(model_, "ollama://")) {
            remove_proto(model_);
            ret = ollama_dl(model_, headers, bn);
        } else if (string_starts_with(model_, "https://")) {
            download(model_, headers, bn, true);
        } else {
            ret = ollama_dl(model_, headers, bn);
        }

        model_ = bn;

        return ret;
    }

    // Initializes the model and returns a unique pointer to it
    llama_model_ptr initialize_model(Opt & opt) {
        ggml_backend_load_all();
        resolve_model(opt.model_);
        printe(
            "\r%*s"
            "\rLoading model",
            get_terminal_width(), " ");
        llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params));
        if (!model) {
            printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
        }

        printe("\r%*s\r", static_cast<int>(sizeof("Loading model")), " ");
        return model;
    }

    // Initializes the context with the specified parameters
    llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
        llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params));
        if (!context) {
            printe("%s: error: failed to create the llama_context\n", __func__);
        }

        return context;
    }

    // Initializes and configures the sampler
    llama_sampler_ptr initialize_sampler(const Opt & opt) {
        llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
        llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature));
        llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

        return sampler;
    }
};
// Add a message to `messages` and store its content in `msg_strs`
static void add_message(const char * role, const std::string & text, LlamaData & llama_data) {
    llama_data.msg_strs.push_back(std::move(text));
    llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
}

// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append,
                               bool use_jinja) {
    if (use_jinja) {
        json messages = json::array();
        for (const auto & msg : llama_data.messages) {
            messages.push_back({
                { "role",    msg.role    },
                { "content", msg.content },
            });
        }

        try {
            auto result = tmpl.apply(messages, /* tools= */ json(), append);
            llama_data.fmtted.resize(result.size() + 1);
            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
            return result.size();
        } catch (const std::exception & e) {
            printe("failed to render the chat template: %s\n", e.what());
            return -1;
        }
    }

    int result = llama_chat_apply_template(
        tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
        append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
    if (append && result > static_cast<int>(llama_data.fmtted.size())) {
        llama_data.fmtted.resize(result);
        result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
                                           llama_data.messages.size(), append, llama_data.fmtted.data(),
                                           llama_data.fmtted.size());
    }

    return result;
}

// Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
                           std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
    const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;

    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
    prompt_tokens.resize(n_prompt_tokens);
    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
                       true) < 0) {
        printe("failed to tokenize the prompt\n");
        return -1;
    }

    return n_prompt_tokens;
}

// Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
    const int n_ctx      = llama_n_ctx(ctx.get());
    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
    if (n_ctx_used + batch.n_tokens > n_ctx) {
        printf("\033[0m\n");
        printe("context size exceeded\n");
        return 1;
    }

    return 0;
}

// convert the token to a string
static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) {
    char buf[256];
    int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true);
    if (n < 0) {
        printe("failed to convert token to piece\n");
        return 1;
    }

    piece = std::string(buf, n);
    return 0;
}

static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
    printf("%s", piece.c_str());
    fflush(stdout);
    response += piece;
}

// helper function to evaluate a prompt and generate a response
static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
    const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());

    std::vector<llama_token> tokens;
    if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
        return 1;
    }

    // prepare a batch for the prompt
    llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
    llama_token new_token_id;
    while (true) {
        check_context_size(llama_data.context, batch);
        if (llama_decode(llama_data.context.get(), batch)) {
            printe("failed to decode\n");
            return 1;
        }

        // sample the next token, check is it an end of generation?
        new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
        if (llama_vocab_is_eog(vocab, new_token_id)) {
            break;
        }

        std::string piece;
        if (convert_token_to_string(vocab, new_token_id, piece)) {
            return 1;
        }

        print_word_and_concatenate_to_response(piece, response);

        // prepare the next batch with the sampled token
        batch = llama_batch_get_one(&new_token_id, 1);
    }

    printf("\033[0m");
    return 0;
}
static int read_user_input(std::string & user_input) {
    static const char * prompt_prefix = "> ";
#ifdef WIN32
    printf(
        "\r%*s"
        "\r\033[0m%s",
        get_terminal_width(), " ", prompt_prefix);

    std::getline(std::cin, user_input);
    if (std::cin.eof()) {
        printf("\n");
        return 1;
    }
#else
    std::unique_ptr<char, decltype(&std::free)> line(const_cast<char *>(linenoise(prompt_prefix)), free);
    if (!line) {
        return 1;
    }

    user_input = line.get();
#endif

    if (user_input == "/bye") {
        return 1;
    }

    if (user_input.empty()) {
        return 2;
    }

#ifndef WIN32
    linenoiseHistoryAdd(line.get());
#endif

    return 0;  // Should have data in happy path
}

// Function to generate a response based on the prompt
static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response,
                             const bool stdout_a_terminal) {
    // Set response color
    if (stdout_a_terminal) {
        printf("\033[33m");
    }

    if (generate(llama_data, prompt, response)) {
        printe("failed to generate response\n");
        return 1;
    }

    // End response with color reset and newline
    printf("\n%s", stdout_a_terminal ? "\033[0m" : "");
    return 0;
}

// Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data,
                                                   const bool append, int & output_length, bool use_jinja) {
    const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
    if (new_len < 0) {
        printe("failed to apply the chat template\n");
        return -1;
    }

    output_length = new_len;
    return 0;
}

// Helper function to handle user input
static int handle_user_input(std::string & user_input, const std::string & user) {
    if (!user.empty()) {
        user_input = user;
        return 0;  // No need for interactive input
    }

    return read_user_input(user_input);  // Returns true if input ends the loop
}

static bool is_stdin_a_terminal() {
#if defined(_WIN32)
    HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE);
    DWORD mode;
    return GetConsoleMode(hStdin, &mode);
#else
    return isatty(STDIN_FILENO);
#endif
}

static bool is_stdout_a_terminal() {
#if defined(_WIN32)
    HANDLE hStdout = GetStdHandle(STD_OUTPUT_HANDLE);
    DWORD mode;
    return GetConsoleMode(hStdout, &mode);
#else
    return isatty(STDOUT_FILENO);
#endif
}

// Function to handle user input
static int get_user_input(std::string & user_input, const std::string & user) {
    while (true) {
        const int ret = handle_user_input(user_input, user);
        if (ret == 1) {
            return 1;
        }

        if (ret == 2) {
            continue;
        }

        break;
    }

    return 0;
}
// Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
    int prev_len = 0;
    llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
    auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
    GGML_ASSERT(chat_templates.template_default);
    static const bool stdout_a_terminal = is_stdout_a_terminal();
    while (true) {
        // Get user input
        std::string user_input;
        if (get_user_input(user_input, user) == 1) {
            return 0;
        }

        add_message("user", user.empty() ? user_input : user, llama_data);
        int new_len;
        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len,
                                                    use_jinja) < 0) {
            return 1;
        }

        std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
        std::string response;
        if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
            return 1;
        }

        if (!user.empty()) {
            break;
        }

        add_message("assistant", response, llama_data);
        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len,
                                                    use_jinja) < 0) {
            return 1;
        }
    }

    return 0;
}

static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
    const Opt * opt = static_cast<Opt *>(p);
    if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) {
        printe("%s", text);
    }
}

static std::string read_pipe_data() {
    std::ostringstream result;
    result << std::cin.rdbuf();  // Read all data from std::cin
    return result.str();
}

static void ctrl_c_handling() {
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = sigint_handler;
    sigemptyset(&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT, &sigint_action, NULL);
#elif defined(_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
}

int main(int argc, const char ** argv) {
    ctrl_c_handling();
    Opt opt;
    const int ret = opt.init(argc, argv);
    if (ret == 2) {
        return 0;
    } else if (ret) {
        return 1;
    }

    if (!is_stdin_a_terminal()) {
        if (!opt.user.empty()) {
            opt.user += "\n\n";
        }

        opt.user += read_pipe_data();
    }

    llama_log_set(log_callback, &opt);
    LlamaData llama_data;
    if (llama_data.init(opt)) {
        return 1;
    }

    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
        return 1;
    }

    return 0;
}