download.cpp 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150
  1. #include "arg.h"
  2. #include "common.h"
  3. #include "gguf.h" // for reading GGUF splits
  4. #include "log.h"
  5. #include "download.h"
  6. #define JSON_ASSERT GGML_ASSERT
  7. #include <nlohmann/json.hpp>
  8. #include <algorithm>
  9. #include <filesystem>
  10. #include <fstream>
  11. #include <future>
  12. #include <map>
  13. #include <mutex>
  14. #include <regex>
  15. #include <string>
  16. #include <thread>
  17. #include <vector>
  18. #if defined(LLAMA_USE_CURL)
  19. #include <curl/curl.h>
  20. #include <curl/easy.h>
  21. #elif defined(LLAMA_USE_HTTPLIB)
  22. #include "http.h"
  23. #endif
  24. #ifndef __EMSCRIPTEN__
  25. #ifdef __linux__
  26. #include <linux/limits.h>
  27. #elif defined(_WIN32)
  28. # if !defined(PATH_MAX)
  29. # define PATH_MAX MAX_PATH
  30. # endif
  31. #elif defined(_AIX)
  32. #include <sys/limits.h>
  33. #else
  34. #include <sys/syslimits.h>
  35. #endif
  36. #endif
  37. #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
  38. // isatty
  39. #if defined(_WIN32)
  40. #include <io.h>
  41. #else
  42. #include <unistd.h>
  43. #endif
  44. using json = nlohmann::ordered_json;
  45. //
  46. // downloader
  47. //
  48. // validate repo name format: owner/repo
  49. static bool validate_repo_name(const std::string & repo) {
  50. static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");
  51. return std::regex_match(repo, repo_regex);
  52. }
  53. static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
  54. // we use "=" to avoid clashing with other component, while still being allowed on windows
  55. std::string fname = "manifest=" + repo + "=" + tag + ".json";
  56. if (!validate_repo_name(repo)) {
  57. throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
  58. }
  59. string_replace_all(fname, "/", "=");
  60. return fs_get_cache_file(fname);
  61. }
  62. static std::string read_file(const std::string & fname) {
  63. std::ifstream file(fname);
  64. if (!file) {
  65. throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
  66. }
  67. std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
  68. file.close();
  69. return content;
  70. }
  71. static void write_file(const std::string & fname, const std::string & content) {
  72. const std::string fname_tmp = fname + ".tmp";
  73. std::ofstream file(fname_tmp);
  74. if (!file) {
  75. throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
  76. }
  77. try {
  78. file << content;
  79. file.close();
  80. // Makes write atomic
  81. if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
  82. LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
  83. // If rename fails, try to delete the temporary file
  84. if (remove(fname_tmp.c_str()) != 0) {
  85. LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
  86. }
  87. }
  88. } catch (...) {
  89. // If anything fails, try to delete the temporary file
  90. if (remove(fname_tmp.c_str()) != 0) {
  91. LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
  92. }
  93. throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
  94. }
  95. }
  96. static void write_etag(const std::string & path, const std::string & etag) {
  97. const std::string etag_path = path + ".etag";
  98. write_file(etag_path, etag);
  99. LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
  100. }
  101. static std::string read_etag(const std::string & path) {
  102. std::string none;
  103. const std::string etag_path = path + ".etag";
  104. if (std::filesystem::exists(etag_path)) {
  105. std::ifstream etag_in(etag_path);
  106. if (!etag_in) {
  107. LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
  108. return none;
  109. }
  110. std::string etag;
  111. std::getline(etag_in, etag);
  112. return etag;
  113. }
  114. // no etag file, but maybe there is an old .json
  115. // remove this code later
  116. const std::string metadata_path = path + ".json";
  117. if (std::filesystem::exists(metadata_path)) {
  118. std::ifstream metadata_in(metadata_path);
  119. try {
  120. nlohmann::json metadata_json;
  121. metadata_in >> metadata_json;
  122. LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
  123. metadata_json.dump().c_str());
  124. if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
  125. std::string etag = metadata_json.at("etag");
  126. write_etag(path, etag);
  127. if (!std::filesystem::remove(metadata_path)) {
  128. LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
  129. }
  130. return etag;
  131. }
  132. } catch (const nlohmann::json::exception & e) {
  133. LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
  134. }
  135. }
  136. return none;
  137. }
  138. #ifdef LLAMA_USE_CURL
  139. //
  140. // CURL utils
  141. //
  142. using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
  143. // cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
  144. struct curl_slist_ptr {
  145. struct curl_slist * ptr = nullptr;
  146. ~curl_slist_ptr() {
  147. if (ptr) {
  148. curl_slist_free_all(ptr);
  149. }
  150. }
  151. };
  152. static CURLcode common_curl_perf(CURL * curl) {
  153. CURLcode res = curl_easy_perform(curl);
  154. if (res != CURLE_OK) {
  155. LOG_ERR("%s: curl_easy_perform() failed\n", __func__);
  156. }
  157. return res;
  158. }
  159. // Send a HEAD request to retrieve the etag and last-modified headers
  160. struct common_load_model_from_url_headers {
  161. std::string etag;
  162. std::string last_modified;
  163. std::string accept_ranges;
  164. };
  165. struct FILE_deleter {
  166. void operator()(FILE * f) const { fclose(f); }
  167. };
  168. static size_t common_header_callback(char * buffer, size_t, size_t n_items, void * userdata) {
  169. common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
  170. static std::regex header_regex("([^:]+): (.*)\r\n");
  171. static std::regex etag_regex("ETag", std::regex_constants::icase);
  172. static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
  173. static std::regex accept_ranges_regex("Accept-Ranges", std::regex_constants::icase);
  174. std::string header(buffer, n_items);
  175. std::smatch match;
  176. if (std::regex_match(header, match, header_regex)) {
  177. const std::string & key = match[1];
  178. const std::string & value = match[2];
  179. if (std::regex_match(key, match, etag_regex)) {
  180. headers->etag = value;
  181. } else if (std::regex_match(key, match, last_modified_regex)) {
  182. headers->last_modified = value;
  183. } else if (std::regex_match(key, match, accept_ranges_regex)) {
  184. headers->accept_ranges = value;
  185. }
  186. }
  187. return n_items;
  188. }
  189. static size_t common_write_callback(void * data, size_t size, size_t nmemb, void * fd) {
  190. return std::fwrite(data, size, nmemb, static_cast<FILE *>(fd));
  191. }
  192. // helper function to hide password in URL
  193. static std::string llama_download_hide_password_in_url(const std::string & url) {
  194. // Use regex to match and replace the user[:password]@ pattern in URLs
  195. // Pattern: scheme://[user[:password]@]host[...]
  196. static const std::regex url_regex(R"(^(?:[A-Za-z][A-Za-z0-9+.-]://)(?:[^/@]+@)?.$)");
  197. std::smatch match;
  198. if (std::regex_match(url, match, url_regex)) {
  199. // match[1] = scheme (e.g., "https://")
  200. // match[2] = user[:password]@ part
  201. // match[3] = rest of URL (host and path)
  202. return match[1].str() + "********@" + match[3].str();
  203. }
  204. return url; // No credentials found or malformed URL
  205. }
  206. static void common_curl_easy_setopt_head(CURL * curl, const std::string & url) {
  207. // Set the URL, allow to follow http redirection
  208. curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
  209. curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
  210. # if defined(_WIN32)
  211. // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
  212. // operating system. Currently implemented under MS-Windows.
  213. curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
  214. # endif
  215. curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
  216. curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress
  217. curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, common_header_callback);
  218. }
  219. static void common_curl_easy_setopt_get(CURL * curl) {
  220. curl_easy_setopt(curl, CURLOPT_NOBODY, 0L);
  221. curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, common_write_callback);
  222. // display download progress
  223. curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
  224. }
  225. static bool common_pull_file(CURL * curl, const std::string & path_temporary) {
  226. if (std::filesystem::exists(path_temporary)) {
  227. const std::string partial_size = std::to_string(std::filesystem::file_size(path_temporary));
  228. LOG_INF("%s: server supports range requests, resuming download from byte %s\n", __func__, partial_size.c_str());
  229. const std::string range_str = partial_size + "-";
  230. curl_easy_setopt(curl, CURLOPT_RANGE, range_str.c_str());
  231. }
  232. // Always open file in append mode could be resuming
  233. std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "ab"));
  234. if (!outfile) {
  235. LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str());
  236. return false;
  237. }
  238. common_curl_easy_setopt_get(curl);
  239. curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile.get());
  240. return common_curl_perf(curl) == CURLE_OK;
  241. }
  242. static bool common_download_head(CURL * curl,
  243. curl_slist_ptr & http_headers,
  244. const std::string & url,
  245. const std::string & bearer_token) {
  246. if (!curl) {
  247. LOG_ERR("%s: error initializing libcurl\n", __func__);
  248. return false;
  249. }
  250. http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
  251. // Check if hf-token or bearer-token was specified
  252. if (!bearer_token.empty()) {
  253. std::string auth_header = "Authorization: Bearer " + bearer_token;
  254. http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
  255. }
  256. curl_easy_setopt(curl, CURLOPT_HTTPHEADER, http_headers.ptr);
  257. common_curl_easy_setopt_head(curl, url);
  258. return common_curl_perf(curl) == CURLE_OK;
  259. }
  260. // download one single file from remote URL to local path
  261. static bool common_download_file_single_online(const std::string & url,
  262. const std::string & path,
  263. const std::string & bearer_token,
  264. const common_header_list & custom_headers) {
  265. static const int max_attempts = 3;
  266. static const int retry_delay_seconds = 2;
  267. for (int i = 0; i < max_attempts; ++i) {
  268. std::string etag;
  269. // Check if the file already exists locally
  270. const auto file_exists = std::filesystem::exists(path);
  271. if (file_exists) {
  272. etag = read_etag(path);
  273. } else {
  274. LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
  275. }
  276. bool head_request_ok = false;
  277. bool should_download = !file_exists; // by default, we should download if the file does not exist
  278. // Initialize libcurl
  279. curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
  280. common_load_model_from_url_headers headers;
  281. curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
  282. curl_slist_ptr http_headers;
  283. for (const auto & h : custom_headers) {
  284. std::string s = h.first + ": " + h.second;
  285. http_headers.ptr = curl_slist_append(http_headers.ptr, s.c_str());
  286. }
  287. const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
  288. if (!was_perform_successful) {
  289. head_request_ok = false;
  290. }
  291. long http_code = 0;
  292. curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
  293. if (http_code == 200) {
  294. head_request_ok = true;
  295. } else {
  296. LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
  297. head_request_ok = false;
  298. }
  299. // if head_request_ok is false, we don't have the etag or last-modified headers
  300. // we leave should_download as-is, which is true if the file does not exist
  301. bool should_download_from_scratch = false;
  302. if (head_request_ok) {
  303. // check if ETag or Last-Modified headers are different
  304. // if it is, we need to download the file again
  305. if (!etag.empty() && etag != headers.etag) {
  306. LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
  307. headers.etag.c_str());
  308. should_download = true;
  309. should_download_from_scratch = true;
  310. }
  311. }
  312. const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none";
  313. if (should_download) {
  314. if (file_exists &&
  315. !accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
  316. LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
  317. if (remove(path.c_str()) != 0) {
  318. LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
  319. return false;
  320. }
  321. }
  322. const std::string path_temporary = path + ".downloadInProgress";
  323. if (should_download_from_scratch) {
  324. if (std::filesystem::exists(path_temporary)) {
  325. if (remove(path_temporary.c_str()) != 0) {
  326. LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
  327. return false;
  328. }
  329. }
  330. if (std::filesystem::exists(path)) {
  331. if (remove(path.c_str()) != 0) {
  332. LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
  333. return false;
  334. }
  335. }
  336. }
  337. if (head_request_ok) {
  338. write_etag(path, headers.etag);
  339. }
  340. // start the download
  341. LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
  342. __func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(),
  343. headers.etag.c_str(), headers.last_modified.c_str());
  344. const bool was_pull_successful = common_pull_file(curl.get(), path_temporary);
  345. if (!was_pull_successful) {
  346. if (i + 1 < max_attempts) {
  347. const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
  348. LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
  349. std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
  350. } else {
  351. LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
  352. }
  353. continue;
  354. }
  355. long http_code = 0;
  356. curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
  357. if (http_code < 200 || http_code >= 400) {
  358. LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
  359. return false;
  360. }
  361. if (rename(path_temporary.c_str(), path.c_str()) != 0) {
  362. LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
  363. return false;
  364. }
  365. } else {
  366. LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
  367. }
  368. break;
  369. }
  370. return true;
  371. }
  372. std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
  373. curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
  374. curl_slist_ptr http_headers;
  375. std::vector<char> res_buffer;
  376. curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
  377. curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
  378. curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
  379. curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 0L);
  380. typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
  381. auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
  382. auto data_vec = static_cast<std::vector<char> *>(data);
  383. data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb);
  384. return size * nmemb;
  385. };
  386. curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
  387. curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer);
  388. #if defined(_WIN32)
  389. curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
  390. #endif
  391. if (params.timeout > 0) {
  392. curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout);
  393. }
  394. if (params.max_size > 0) {
  395. curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
  396. }
  397. http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
  398. for (const auto & header : params.headers) {
  399. std::string header_ = header.first + ": " + header.second;
  400. http_headers.ptr = curl_slist_append(http_headers.ptr, header_.c_str());
  401. }
  402. curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
  403. CURLcode res = curl_easy_perform(curl.get());
  404. if (res != CURLE_OK) {
  405. std::string error_msg = curl_easy_strerror(res);
  406. throw std::runtime_error("error: cannot make GET request: " + error_msg);
  407. }
  408. long res_code;
  409. curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
  410. return { res_code, std::move(res_buffer) };
  411. }
  412. #elif defined(LLAMA_USE_HTTPLIB)
  413. class ProgressBar {
  414. static inline std::mutex mutex;
  415. static inline std::map<const ProgressBar *, int> lines;
  416. static inline int max_line = 0;
  417. static void cleanup(const ProgressBar * line) {
  418. lines.erase(line);
  419. if (lines.empty()) {
  420. max_line = 0;
  421. }
  422. }
  423. static bool is_output_a_tty() {
  424. #if defined(_WIN32)
  425. return _isatty(_fileno(stdout));
  426. #else
  427. return isatty(1);
  428. #endif
  429. }
  430. public:
  431. ProgressBar() = default;
  432. ~ProgressBar() {
  433. std::lock_guard<std::mutex> lock(mutex);
  434. cleanup(this);
  435. }
  436. void update(size_t current, size_t total) {
  437. if (!is_output_a_tty()) {
  438. return;
  439. }
  440. if (!total) {
  441. return;
  442. }
  443. std::lock_guard<std::mutex> lock(mutex);
  444. if (lines.find(this) == lines.end()) {
  445. lines[this] = max_line++;
  446. std::cout << "\n";
  447. }
  448. int lines_up = max_line - lines[this];
  449. size_t width = 50;
  450. size_t pct = (100 * current) / total;
  451. size_t pos = (width * current) / total;
  452. std::cout << "\033[s";
  453. if (lines_up > 0) {
  454. std::cout << "\033[" << lines_up << "A";
  455. }
  456. std::cout << "\033[2K\r["
  457. << std::string(pos, '=')
  458. << (pos < width ? ">" : "")
  459. << std::string(width - pos, ' ')
  460. << "] " << std::setw(3) << pct << "% ("
  461. << current / (1024 * 1024) << " MB / "
  462. << total / (1024 * 1024) << " MB) "
  463. << "\033[u";
  464. std::cout.flush();
  465. if (current == total) {
  466. cleanup(this);
  467. }
  468. }
  469. ProgressBar(const ProgressBar &) = delete;
  470. ProgressBar & operator=(const ProgressBar &) = delete;
  471. };
  472. static bool common_pull_file(httplib::Client & cli,
  473. const std::string & resolve_path,
  474. const std::string & path_tmp,
  475. bool supports_ranges,
  476. size_t existing_size,
  477. size_t & total_size) {
  478. std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
  479. if (!ofs.is_open()) {
  480. LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
  481. return false;
  482. }
  483. httplib::Headers headers;
  484. if (supports_ranges && existing_size > 0) {
  485. headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
  486. }
  487. const char * func = __func__; // avoid __func__ inside a lambda
  488. size_t downloaded = existing_size;
  489. size_t progress_step = 0;
  490. ProgressBar bar;
  491. auto res = cli.Get(resolve_path, headers,
  492. [&](const httplib::Response &response) {
  493. if (existing_size > 0 && response.status != 206) {
  494. LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
  495. return false;
  496. }
  497. if (existing_size == 0 && response.status != 200) {
  498. LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
  499. return false;
  500. }
  501. if (total_size == 0 && response.has_header("Content-Length")) {
  502. try {
  503. size_t content_length = std::stoull(response.get_header_value("Content-Length"));
  504. total_size = existing_size + content_length;
  505. } catch (const std::exception &e) {
  506. LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
  507. }
  508. }
  509. return true;
  510. },
  511. [&](const char *data, size_t len) {
  512. ofs.write(data, len);
  513. if (!ofs) {
  514. LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
  515. return false;
  516. }
  517. downloaded += len;
  518. progress_step += len;
  519. if (progress_step >= total_size / 1000 || downloaded == total_size) {
  520. bar.update(downloaded, total_size);
  521. progress_step = 0;
  522. }
  523. return true;
  524. },
  525. nullptr
  526. );
  527. if (!res) {
  528. LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
  529. return false;
  530. }
  531. return true;
  532. }
  533. // download one single file from remote URL to local path
  534. static bool common_download_file_single_online(const std::string & url,
  535. const std::string & path,
  536. const std::string & bearer_token,
  537. const common_header_list & custom_headers) {
  538. static const int max_attempts = 3;
  539. static const int retry_delay_seconds = 2;
  540. auto [cli, parts] = common_http_client(url);
  541. httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
  542. if (!bearer_token.empty()) {
  543. default_headers.insert({"Authorization", "Bearer " + bearer_token});
  544. }
  545. for (const auto & h : custom_headers) {
  546. default_headers.emplace(h.first, h.second);
  547. }
  548. cli.set_default_headers(default_headers);
  549. const bool file_exists = std::filesystem::exists(path);
  550. std::string last_etag;
  551. if (file_exists) {
  552. last_etag = read_etag(path);
  553. } else {
  554. LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
  555. }
  556. for (int i = 0; i < max_attempts; ++i) {
  557. auto head = cli.Head(parts.path);
  558. bool head_ok = head && head->status >= 200 && head->status < 300;
  559. if (!head_ok) {
  560. LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
  561. if (file_exists) {
  562. LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
  563. return true;
  564. }
  565. }
  566. std::string etag;
  567. if (head_ok && head->has_header("ETag")) {
  568. etag = head->get_header_value("ETag");
  569. }
  570. size_t total_size = 0;
  571. if (head_ok && head->has_header("Content-Length")) {
  572. try {
  573. total_size = std::stoull(head->get_header_value("Content-Length"));
  574. } catch (const std::exception& e) {
  575. LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
  576. }
  577. }
  578. bool supports_ranges = false;
  579. if (head_ok && head->has_header("Accept-Ranges")) {
  580. supports_ranges = head->get_header_value("Accept-Ranges") != "none";
  581. }
  582. bool should_download_from_scratch = false;
  583. if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
  584. LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
  585. last_etag.c_str(), etag.c_str());
  586. should_download_from_scratch = true;
  587. }
  588. if (file_exists) {
  589. if (!should_download_from_scratch) {
  590. LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
  591. return true;
  592. }
  593. LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
  594. if (remove(path.c_str()) != 0) {
  595. LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
  596. return false;
  597. }
  598. }
  599. const std::string path_temporary = path + ".downloadInProgress";
  600. size_t existing_size = 0;
  601. if (std::filesystem::exists(path_temporary)) {
  602. if (supports_ranges && !should_download_from_scratch) {
  603. existing_size = std::filesystem::file_size(path_temporary);
  604. } else if (remove(path_temporary.c_str()) != 0) {
  605. LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
  606. return false;
  607. }
  608. }
  609. // start the download
  610. LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
  611. __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
  612. const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
  613. if (!was_pull_successful) {
  614. if (i + 1 < max_attempts) {
  615. const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
  616. LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
  617. std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
  618. } else {
  619. LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
  620. }
  621. continue;
  622. }
  623. if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
  624. LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
  625. return false;
  626. }
  627. if (!etag.empty()) {
  628. write_etag(path, etag);
  629. }
  630. break;
  631. }
  632. return true;
  633. }
  634. std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
  635. const common_remote_params & params) {
  636. auto [cli, parts] = common_http_client(url);
  637. httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
  638. for (const auto & header : params.headers) {
  639. headers.emplace(header.first, header.second);
  640. }
  641. if (params.timeout > 0) {
  642. cli.set_read_timeout(params.timeout, 0);
  643. cli.set_write_timeout(params.timeout, 0);
  644. }
  645. std::vector<char> buf;
  646. auto res = cli.Get(parts.path, headers,
  647. [&](const char *data, size_t len) {
  648. buf.insert(buf.end(), data, data + len);
  649. return params.max_size == 0 ||
  650. buf.size() <= static_cast<size_t>(params.max_size);
  651. },
  652. nullptr
  653. );
  654. if (!res) {
  655. throw std::runtime_error("error: cannot make GET request");
  656. }
  657. return { res->status, std::move(buf) };
  658. }
  659. #endif // LLAMA_USE_CURL
  660. #if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
  661. static bool common_download_file_single(const std::string & url,
  662. const std::string & path,
  663. const std::string & bearer_token,
  664. bool offline,
  665. const common_header_list & headers) {
  666. if (!offline) {
  667. return common_download_file_single_online(url, path, bearer_token, headers);
  668. }
  669. if (!std::filesystem::exists(path)) {
  670. LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
  671. return false;
  672. }
  673. LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
  674. return true;
  675. }
  676. // download multiple files from remote URLs to local paths
  677. // the input is a vector of pairs <url, path>
  678. static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
  679. const std::string & bearer_token,
  680. bool offline,
  681. const common_header_list & headers) {
  682. // Prepare download in parallel
  683. std::vector<std::future<bool>> futures_download;
  684. futures_download.reserve(urls.size());
  685. for (auto const & item : urls) {
  686. futures_download.push_back(
  687. std::async(
  688. std::launch::async,
  689. [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
  690. return common_download_file_single(it.first, it.second, bearer_token, offline, headers);
  691. },
  692. item
  693. )
  694. );
  695. }
  696. // Wait for all downloads to complete
  697. for (auto & f : futures_download) {
  698. if (!f.get()) {
  699. return false;
  700. }
  701. }
  702. return true;
  703. }
  704. bool common_download_model(const common_params_model & model,
  705. const std::string & bearer_token,
  706. bool offline,
  707. const common_header_list & headers) {
  708. // Basic validation of the model.url
  709. if (model.url.empty()) {
  710. LOG_ERR("%s: invalid model url\n", __func__);
  711. return false;
  712. }
  713. if (!common_download_file_single(model.url, model.path, bearer_token, offline, headers)) {
  714. return false;
  715. }
  716. // check for additional GGUFs split to download
  717. int n_split = 0;
  718. {
  719. struct gguf_init_params gguf_params = {
  720. /*.no_alloc = */ true,
  721. /*.ctx = */ NULL,
  722. };
  723. auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
  724. if (!ctx_gguf) {
  725. LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
  726. return false;
  727. }
  728. auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
  729. if (key_n_split >= 0) {
  730. n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
  731. }
  732. gguf_free(ctx_gguf);
  733. }
  734. if (n_split > 1) {
  735. char split_prefix[PATH_MAX] = {0};
  736. char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
  737. // Verify the first split file format
  738. // and extract split URL and PATH prefixes
  739. {
  740. if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
  741. LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
  742. return false;
  743. }
  744. if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
  745. LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
  746. return false;
  747. }
  748. }
  749. std::vector<std::pair<std::string, std::string>> urls;
  750. for (int idx = 1; idx < n_split; idx++) {
  751. char split_path[PATH_MAX] = {0};
  752. llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
  753. char split_url[LLAMA_MAX_URL_LENGTH] = {0};
  754. llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
  755. if (std::string(split_path) == model.path) {
  756. continue; // skip the already downloaded file
  757. }
  758. urls.push_back({split_url, split_path});
  759. }
  760. // Download in parallel
  761. common_download_file_multiple(urls, bearer_token, offline, headers);
  762. }
  763. return true;
  764. }
  765. common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
  766. const std::string & bearer_token,
  767. bool offline,
  768. const common_header_list & custom_headers) {
  769. auto parts = string_split<std::string>(hf_repo_with_tag, ':');
  770. std::string tag = parts.size() > 1 ? parts.back() : "latest";
  771. std::string hf_repo = parts[0];
  772. if (string_split<std::string>(hf_repo, '/').size() != 2) {
  773. throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
  774. }
  775. std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
  776. // headers
  777. common_header_list headers = custom_headers;
  778. headers.push_back({"Accept", "application/json"});
  779. if (!bearer_token.empty()) {
  780. headers.push_back({"Authorization", "Bearer " + bearer_token});
  781. }
  782. // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
  783. // User-Agent header is already set in common_remote_get_content, no need to set it here
  784. // make the request
  785. common_remote_params params;
  786. params.headers = headers;
  787. long res_code = 0;
  788. std::string res_str;
  789. bool use_cache = false;
  790. std::string cached_response_path = get_manifest_path(hf_repo, tag);
  791. if (!offline) {
  792. try {
  793. auto res = common_remote_get_content(url, params);
  794. res_code = res.first;
  795. res_str = std::string(res.second.data(), res.second.size());
  796. } catch (const std::exception & e) {
  797. LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
  798. }
  799. }
  800. if (res_code == 0) {
  801. if (std::filesystem::exists(cached_response_path)) {
  802. LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
  803. res_str = read_file(cached_response_path);
  804. res_code = 200;
  805. use_cache = true;
  806. } else {
  807. throw std::runtime_error(
  808. offline ? "error: failed to get manifest (offline mode)"
  809. : "error: failed to get manifest (check your internet connection)");
  810. }
  811. }
  812. std::string ggufFile;
  813. std::string mmprojFile;
  814. if (res_code == 200 || res_code == 304) {
  815. try {
  816. auto j = json::parse(res_str);
  817. if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) {
  818. ggufFile = j["ggufFile"]["rfilename"].get<std::string>();
  819. }
  820. if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
  821. mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
  822. }
  823. } catch (const std::exception & e) {
  824. throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
  825. }
  826. if (!use_cache) {
  827. // if not using cached response, update the cache file
  828. write_file(cached_response_path, res_str);
  829. }
  830. } else if (res_code == 401) {
  831. throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
  832. } else {
  833. throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
  834. }
  835. // check response
  836. if (ggufFile.empty()) {
  837. throw std::runtime_error("error: model does not have ggufFile");
  838. }
  839. return { hf_repo, ggufFile, mmprojFile };
  840. }
  841. //
  842. // Docker registry functions
  843. //
  844. static std::string common_docker_get_token(const std::string & repo) {
  845. std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
  846. common_remote_params params;
  847. auto res = common_remote_get_content(url, params);
  848. if (res.first != 200) {
  849. throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
  850. }
  851. std::string response_str(res.second.begin(), res.second.end());
  852. nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
  853. if (!response.contains("token")) {
  854. throw std::runtime_error("Docker registry token response missing 'token' field");
  855. }
  856. return response["token"].get<std::string>();
  857. }
  858. std::string common_docker_resolve_model(const std::string & docker) {
  859. // Parse ai/smollm2:135M-Q4_0
  860. size_t colon_pos = docker.find(':');
  861. std::string repo, tag;
  862. if (colon_pos != std::string::npos) {
  863. repo = docker.substr(0, colon_pos);
  864. tag = docker.substr(colon_pos + 1);
  865. } else {
  866. repo = docker;
  867. tag = "latest";
  868. }
  869. // ai/ is the default
  870. size_t slash_pos = docker.find('/');
  871. if (slash_pos == std::string::npos) {
  872. repo.insert(0, "ai/");
  873. }
  874. LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
  875. try {
  876. // --- helper: digest validation ---
  877. auto validate_oci_digest = [](const std::string & digest) -> std::string {
  878. // Expected: algo:hex ; start with sha256 (64 hex chars)
  879. // You can extend this map if supporting other algorithms in future.
  880. static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
  881. std::smatch m;
  882. if (!std::regex_match(digest, m, re)) {
  883. throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
  884. }
  885. // normalize hex to lowercase
  886. std::string normalized = digest;
  887. std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
  888. return std::tolower(c);
  889. });
  890. return normalized;
  891. };
  892. std::string token = common_docker_get_token(repo); // Get authentication token
  893. // Get manifest
  894. // TODO: cache the manifest response so that it appears in the model list
  895. const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
  896. std::string manifest_url = url_prefix + "/manifests/" + tag;
  897. common_remote_params manifest_params;
  898. manifest_params.headers.push_back({"Authorization", "Bearer " + token});
  899. manifest_params.headers.push_back({"Accept",
  900. "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
  901. });
  902. auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
  903. if (manifest_res.first != 200) {
  904. throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
  905. }
  906. std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
  907. nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
  908. std::string gguf_digest; // Find the GGUF layer
  909. if (manifest.contains("layers")) {
  910. for (const auto & layer : manifest["layers"]) {
  911. if (layer.contains("mediaType")) {
  912. std::string media_type = layer["mediaType"].get<std::string>();
  913. if (media_type == "application/vnd.docker.ai.gguf.v3" ||
  914. media_type.find("gguf") != std::string::npos) {
  915. gguf_digest = layer["digest"].get<std::string>();
  916. break;
  917. }
  918. }
  919. }
  920. }
  921. if (gguf_digest.empty()) {
  922. throw std::runtime_error("No GGUF layer found in Docker manifest");
  923. }
  924. // Validate & normalize digest
  925. gguf_digest = validate_oci_digest(gguf_digest);
  926. LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
  927. // Prepare local filename
  928. std::string model_filename = repo;
  929. std::replace(model_filename.begin(), model_filename.end(), '/', '_');
  930. model_filename += "_" + tag + ".gguf";
  931. std::string local_path = fs_get_cache_file(model_filename);
  932. const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
  933. if (!common_download_file_single(blob_url, local_path, token, false, {})) {
  934. throw std::runtime_error("Failed to download Docker Model");
  935. }
  936. LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
  937. return local_path;
  938. } catch (const std::exception & e) {
  939. LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
  940. throw;
  941. }
  942. }
  943. #else
  944. common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
  945. throw std::runtime_error("download functionality is not enabled in this build");
  946. }
  947. bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
  948. throw std::runtime_error("download functionality is not enabled in this build");
  949. }
  // Stub for builds without LLAMA_USE_CURL / LLAMA_USE_HTTPLIB: always throws.
  std::string common_docker_resolve_model(const std::string & /* docker */) {
      const char * const reason = "download functionality is not enabled in this build";
      throw std::runtime_error(reason);
  }
  953. #endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB
  954. std::vector<common_cached_model_info> common_list_cached_models() {
  955. std::vector<common_cached_model_info> models;
  956. const std::string cache_dir = fs_get_cache_directory();
  957. const std::vector<common_file_info> files = fs_list(cache_dir, false);
  958. for (const auto & file : files) {
  959. if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
  960. common_cached_model_info model_info;
  961. model_info.manifest_path = file.path;
  962. std::string fname = file.name;
  963. string_replace_all(fname, ".json", ""); // remove extension
  964. auto parts = string_split<std::string>(fname, '=');
  965. if (parts.size() == 4) {
  966. // expect format: manifest=<user>=<model>=<tag>=<other>
  967. model_info.user = parts[1];
  968. model_info.model = parts[2];
  969. model_info.tag = parts[3];
  970. } else {
  971. // invalid format
  972. continue;
  973. }
  974. model_info.size = 0; // TODO: get GGUF size, not manifest size
  975. models.push_back(model_info);
  976. }
  977. }
  978. return models;
  979. }