download.cpp 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181
  1. #include "arg.h"
  2. #include "common.h"
  3. #include "gguf.h" // for reading GGUF splits
  4. #include "log.h"
  5. #include "download.h"
  6. #define JSON_ASSERT GGML_ASSERT
  7. #include <nlohmann/json.hpp>
  8. #include <algorithm>
  9. #include <filesystem>
  10. #include <fstream>
  11. #include <future>
  12. #include <map>
  13. #include <mutex>
  14. #include <regex>
  15. #include <string>
  16. #include <thread>
  17. #include <vector>
  18. #if defined(LLAMA_USE_CURL)
  19. #include <curl/curl.h>
  20. #include <curl/easy.h>
  21. #elif defined(LLAMA_USE_HTTPLIB)
  22. #include "http.h"
  23. #endif
  24. #ifndef __EMSCRIPTEN__
  25. #ifdef __linux__
  26. #include <linux/limits.h>
  27. #elif defined(_WIN32)
  28. # if !defined(PATH_MAX)
  29. # define PATH_MAX MAX_PATH
  30. # endif
  31. #elif defined(_AIX)
  32. #include <sys/limits.h>
  33. #else
  34. #include <sys/syslimits.h>
  35. #endif
  36. #endif
  37. #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
  38. // isatty
  39. #if defined(_WIN32)
  40. #include <io.h>
  41. #else
  42. #include <unistd.h>
  43. #endif
  44. using json = nlohmann::ordered_json;
  45. //
  46. // downloader
  47. //
  48. // validate repo name format: owner/repo
  49. static bool validate_repo_name(const std::string & repo) {
  50. static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");
  51. return std::regex_match(repo, repo_regex);
  52. }
  53. static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
  54. // we use "=" to avoid clashing with other component, while still being allowed on windows
  55. std::string fname = "manifest=" + repo + "=" + tag + ".json";
  56. if (!validate_repo_name(repo)) {
  57. throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
  58. }
  59. string_replace_all(fname, "/", "=");
  60. return fs_get_cache_file(fname);
  61. }
  62. static std::string read_file(const std::string & fname) {
  63. std::ifstream file(fname);
  64. if (!file) {
  65. throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
  66. }
  67. std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
  68. file.close();
  69. return content;
  70. }
  71. static void write_file(const std::string & fname, const std::string & content) {
  72. const std::string fname_tmp = fname + ".tmp";
  73. std::ofstream file(fname_tmp);
  74. if (!file) {
  75. throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
  76. }
  77. try {
  78. file << content;
  79. file.close();
  80. // Makes write atomic
  81. if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
  82. LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
  83. // If rename fails, try to delete the temporary file
  84. if (remove(fname_tmp.c_str()) != 0) {
  85. LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
  86. }
  87. }
  88. } catch (...) {
  89. // If anything fails, try to delete the temporary file
  90. if (remove(fname_tmp.c_str()) != 0) {
  91. LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
  92. }
  93. throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
  94. }
  95. }
  96. static void write_etag(const std::string & path, const std::string & etag) {
  97. const std::string etag_path = path + ".etag";
  98. write_file(etag_path, etag);
  99. LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
  100. }
  101. static std::string read_etag(const std::string & path) {
  102. std::string none;
  103. const std::string etag_path = path + ".etag";
  104. if (std::filesystem::exists(etag_path)) {
  105. std::ifstream etag_in(etag_path);
  106. if (!etag_in) {
  107. LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
  108. return none;
  109. }
  110. std::string etag;
  111. std::getline(etag_in, etag);
  112. return etag;
  113. }
  114. // no etag file, but maybe there is an old .json
  115. // remove this code later
  116. const std::string metadata_path = path + ".json";
  117. if (std::filesystem::exists(metadata_path)) {
  118. std::ifstream metadata_in(metadata_path);
  119. try {
  120. nlohmann::json metadata_json;
  121. metadata_in >> metadata_json;
  122. LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
  123. metadata_json.dump().c_str());
  124. if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
  125. std::string etag = metadata_json.at("etag");
  126. write_etag(path, etag);
  127. if (!std::filesystem::remove(metadata_path)) {
  128. LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
  129. }
  130. return etag;
  131. }
  132. } catch (const nlohmann::json::exception & e) {
  133. LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
  134. }
  135. }
  136. return none;
  137. }
  138. static bool is_http_status_ok(int status) {
  139. return status >= 200 && status < 400;
  140. }
  141. std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag) {
  142. auto parts = string_split<std::string>(hf_repo_with_tag, ':');
  143. std::string tag = parts.size() > 1 ? parts.back() : "latest";
  144. std::string hf_repo = parts[0];
  145. if (string_split<std::string>(hf_repo, '/').size() != 2) {
  146. throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
  147. }
  148. return {hf_repo, tag};
  149. }
  150. #ifdef LLAMA_USE_CURL
  151. //
  152. // CURL utils
  153. //
  154. using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
  155. // cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
  156. struct curl_slist_ptr {
  157. struct curl_slist * ptr = nullptr;
  158. ~curl_slist_ptr() {
  159. if (ptr) {
  160. curl_slist_free_all(ptr);
  161. }
  162. }
  163. };
  164. static CURLcode common_curl_perf(CURL * curl) {
  165. CURLcode res = curl_easy_perform(curl);
  166. if (res != CURLE_OK) {
  167. LOG_ERR("%s: curl_easy_perform() failed\n", __func__);
  168. }
  169. return res;
  170. }
// Response headers of interest for a download.
// common_header_callback receives a pointer to this struct as its userdata and
// fills in a field each time it sees the matching header line; a field that was
// never assigned stays empty.
struct common_load_model_from_url_headers {
    std::string etag;           // value of the "ETag" response header
    std::string last_modified;  // value of the "Last-Modified" response header
    std::string accept_ranges;  // value of the "Accept-Ranges" response header
};
  177. struct FILE_deleter {
  178. void operator()(FILE * f) const { fclose(f); }
  179. };
  180. static size_t common_header_callback(char * buffer, size_t, size_t n_items, void * userdata) {
  181. common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
  182. static std::regex header_regex("([^:]+): (.*)\r\n");
  183. static std::regex etag_regex("ETag", std::regex_constants::icase);
  184. static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
  185. static std::regex accept_ranges_regex("Accept-Ranges", std::regex_constants::icase);
  186. std::string header(buffer, n_items);
  187. std::smatch match;
  188. if (std::regex_match(header, match, header_regex)) {
  189. const std::string & key = match[1];
  190. const std::string & value = match[2];
  191. if (std::regex_match(key, match, etag_regex)) {
  192. headers->etag = value;
  193. } else if (std::regex_match(key, match, last_modified_regex)) {
  194. headers->last_modified = value;
  195. } else if (std::regex_match(key, match, accept_ranges_regex)) {
  196. headers->accept_ranges = value;
  197. }
  198. }
  199. return n_items;
  200. }
  201. static size_t common_write_callback(void * data, size_t size, size_t nmemb, void * fd) {
  202. return std::fwrite(data, size, nmemb, static_cast<FILE *>(fd));
  203. }
  204. // helper function to hide password in URL
  205. static std::string llama_download_hide_password_in_url(const std::string & url) {
  206. // Use regex to match and replace the user[:password]@ pattern in URLs
  207. // Pattern: scheme://[user[:password]@]host[...]
  208. static const std::regex url_regex(R"(^(?:[A-Za-z][A-Za-z0-9+.-]://)(?:[^/@]+@)?.$)");
  209. std::smatch match;
  210. if (std::regex_match(url, match, url_regex)) {
  211. // match[1] = scheme (e.g., "https://")
  212. // match[2] = user[:password]@ part
  213. // match[3] = rest of URL (host and path)
  214. return match[1].str() + "********@" + match[3].str();
  215. }
  216. return url; // No credentials found or malformed URL
  217. }
  218. static void common_curl_easy_setopt_head(CURL * curl, const std::string & url) {
  219. // Set the URL, allow to follow http redirection
  220. curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
  221. curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
  222. # if defined(_WIN32)
  223. // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
  224. // operating system. Currently implemented under MS-Windows.
  225. curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
  226. # endif
  227. curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
  228. curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress
  229. curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, common_header_callback);
  230. }
  231. static void common_curl_easy_setopt_get(CURL * curl) {
  232. curl_easy_setopt(curl, CURLOPT_NOBODY, 0L);
  233. curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, common_write_callback);
  234. // display download progress
  235. curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
  236. }
  237. static bool common_pull_file(CURL * curl, const std::string & path_temporary) {
  238. if (std::filesystem::exists(path_temporary)) {
  239. const std::string partial_size = std::to_string(std::filesystem::file_size(path_temporary));
  240. LOG_INF("%s: server supports range requests, resuming download from byte %s\n", __func__, partial_size.c_str());
  241. const std::string range_str = partial_size + "-";
  242. curl_easy_setopt(curl, CURLOPT_RANGE, range_str.c_str());
  243. }
  244. // Always open file in append mode could be resuming
  245. std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "ab"));
  246. if (!outfile) {
  247. LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str());
  248. return false;
  249. }
  250. common_curl_easy_setopt_get(curl);
  251. curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile.get());
  252. return common_curl_perf(curl) == CURLE_OK;
  253. }
  254. static bool common_download_head(CURL * curl,
  255. curl_slist_ptr & http_headers,
  256. const std::string & url,
  257. const std::string & bearer_token) {
  258. if (!curl) {
  259. LOG_ERR("%s: error initializing libcurl\n", __func__);
  260. return false;
  261. }
  262. http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
  263. // Check if hf-token or bearer-token was specified
  264. if (!bearer_token.empty()) {
  265. std::string auth_header = "Authorization: Bearer " + bearer_token;
  266. http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
  267. }
  268. curl_easy_setopt(curl, CURLOPT_HTTPHEADER, http_headers.ptr);
  269. common_curl_easy_setopt_head(curl, url);
  270. return common_curl_perf(curl) == CURLE_OK;
  271. }
// download one single file from remote URL to local path
// returns status code or -1 on error
//
// Overview of one attempt (up to max_attempts attempts are made):
//   1. read the locally stored ETag if `path` already exists
//   2. send a HEAD request and compare the server ETag with the stored one
//   3. if the file is missing or the ETag changed, download into
//      "<path>.downloadInProgress" and rename it to `path` on success
//
// Return value:
//   304                 the existing local file is used (fake "Not Modified")
//   HTTP status code    status of the GET request once the transfer itself completed
//   -1                  a local file operation failed, or all attempts failed
static int common_download_file_single_online(const std::string & url,
                                              const std::string & path,
                                              const std::string & bearer_token,
                                              const common_header_list & custom_headers) {
    static const int max_attempts = 3;
    static const int retry_delay_seconds = 2;
    for (int i = 0; i < max_attempts; ++i) {
        std::string etag;
        // Check if the file already exists locally
        const auto file_exists = std::filesystem::exists(path);
        if (file_exists) {
            etag = read_etag(path);
        } else {
            LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
        }
        bool head_request_ok = false;
        bool should_download = !file_exists; // by default, we should download if the file does not exist
        // Initialize libcurl
        // (a fresh handle is created for every attempt, so options never leak between attempts)
        curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
        common_load_model_from_url_headers headers;
        curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
        curl_slist_ptr http_headers;
        // caller-provided headers go first; common_download_head appends User-Agent / Authorization
        for (const auto & h : custom_headers) {
            std::string s = h.first + ": " + h.second;
            http_headers.ptr = curl_slist_append(http_headers.ptr, s.c_str());
        }
        const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
        if (!was_perform_successful) {
            head_request_ok = false;
        }
        // the HEAD request only counts as successful for an exact 200 response
        long http_code = 0;
        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
        if (http_code == 200) {
            head_request_ok = true;
        } else {
            LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
            head_request_ok = false;
        }
        // if head_request_ok is false, we don't have the etag or last-modified headers
        // we leave should_download as-is, which is true if the file does not exist
        bool should_download_from_scratch = false;
        if (head_request_ok) {
            // check if ETag or Last-Modified headers are different
            // if it is, we need to download the file again
            // (only the ETag is compared here; last_modified is logged but not compared)
            if (!etag.empty() && etag != headers.etag) {
                LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
                        headers.etag.c_str());
                should_download = true;
                should_download_from_scratch = true;
            }
        }
        const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none";
        if (should_download) {
            if (file_exists &&
                !accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
                LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
                if (remove(path.c_str()) != 0) {
                    LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
                    return -1;
                }
            }
            const std::string path_temporary = path + ".downloadInProgress";
            if (should_download_from_scratch) {
                // the remote file changed: neither a partial download nor the old file can be reused
                if (std::filesystem::exists(path_temporary)) {
                    if (remove(path_temporary.c_str()) != 0) {
                        LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
                        return -1;
                    }
                }
                if (std::filesystem::exists(path)) {
                    if (remove(path.c_str()) != 0) {
                        LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
                        return -1;
                    }
                }
            }
            // the ETag is stored before the transfer starts, not after it completes
            if (head_request_ok) {
                write_etag(path, headers.etag);
            }
            // start the download
            LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
                    __func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(),
                    headers.etag.c_str(), headers.last_modified.c_str());
            const bool was_pull_successful = common_pull_file(curl.get(), path_temporary);
            if (!was_pull_successful) {
                if (i + 1 < max_attempts) {
                    // wait 1000 ms after the first failure and 2000 ms after the second
                    const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
                    LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
                    std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
                } else {
                    LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
                }
                // on the last attempt this ends the loop and the function returns -1
                continue;
            }
            // status of the GET request (this declaration shadows the HEAD status above)
            long http_code = 0;
            curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
            int status = static_cast<int>(http_code);
            if (!is_http_status_ok(http_code)) {
                // NOTE(review): the temporary file is not removed here, so whatever body came
                // with the error response presumably stays in it - confirm this cannot
                // corrupt a later resumed download
                LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
                return status; // TODO: maybe only return on certain codes
            }
            // transfer complete: promote the temporary file to its final name
            if (rename(path_temporary.c_str(), path.c_str()) != 0) {
                LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
                return -1;
            }
            return static_cast<int>(http_code);
        } else {
            // file exists and either the ETag is unchanged or the HEAD request gave no usable answer
            LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
            return 304; // Not Modified - fake cached response
        }
    }
    return -1; // max attempts reached
}
  387. std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
  388. curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
  389. curl_slist_ptr http_headers;
  390. std::vector<char> res_buffer;
  391. curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
  392. curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
  393. curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
  394. curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 0L);
  395. typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
  396. auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
  397. auto data_vec = static_cast<std::vector<char> *>(data);
  398. data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb);
  399. return size * nmemb;
  400. };
  401. curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
  402. curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer);
  403. #if defined(_WIN32)
  404. curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
  405. #endif
  406. if (params.timeout > 0) {
  407. curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout);
  408. }
  409. if (params.max_size > 0) {
  410. curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
  411. }
  412. http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
  413. for (const auto & header : params.headers) {
  414. std::string header_ = header.first + ": " + header.second;
  415. http_headers.ptr = curl_slist_append(http_headers.ptr, header_.c_str());
  416. }
  417. curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
  418. CURLcode res = curl_easy_perform(curl.get());
  419. if (res != CURLE_OK) {
  420. std::string error_msg = curl_easy_strerror(res);
  421. throw std::runtime_error("error: cannot make GET request: " + error_msg);
  422. }
  423. long res_code;
  424. curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
  425. return { res_code, std::move(res_buffer) };
  426. }
  427. #elif defined(LLAMA_USE_HTTPLIB)
  428. class ProgressBar {
  429. static inline std::mutex mutex;
  430. static inline std::map<const ProgressBar *, int> lines;
  431. static inline int max_line = 0;
  432. static void cleanup(const ProgressBar * line) {
  433. lines.erase(line);
  434. if (lines.empty()) {
  435. max_line = 0;
  436. }
  437. }
  438. static bool is_output_a_tty() {
  439. #if defined(_WIN32)
  440. return _isatty(_fileno(stdout));
  441. #else
  442. return isatty(1);
  443. #endif
  444. }
  445. public:
  446. ProgressBar() = default;
  447. ~ProgressBar() {
  448. std::lock_guard<std::mutex> lock(mutex);
  449. cleanup(this);
  450. }
  451. void update(size_t current, size_t total) {
  452. if (!is_output_a_tty()) {
  453. return;
  454. }
  455. if (!total) {
  456. return;
  457. }
  458. std::lock_guard<std::mutex> lock(mutex);
  459. if (lines.find(this) == lines.end()) {
  460. lines[this] = max_line++;
  461. std::cout << "\n";
  462. }
  463. int lines_up = max_line - lines[this];
  464. size_t width = 50;
  465. size_t pct = (100 * current) / total;
  466. size_t pos = (width * current) / total;
  467. std::cout << "\033[s";
  468. if (lines_up > 0) {
  469. std::cout << "\033[" << lines_up << "A";
  470. }
  471. std::cout << "\033[2K\r["
  472. << std::string(pos, '=')
  473. << (pos < width ? ">" : "")
  474. << std::string(width - pos, ' ')
  475. << "] " << std::setw(3) << pct << "% ("
  476. << current / (1024 * 1024) << " MB / "
  477. << total / (1024 * 1024) << " MB) "
  478. << "\033[u";
  479. std::cout.flush();
  480. if (current == total) {
  481. cleanup(this);
  482. }
  483. }
  484. ProgressBar(const ProgressBar &) = delete;
  485. ProgressBar & operator=(const ProgressBar &) = delete;
  486. };
  487. static bool common_pull_file(httplib::Client & cli,
  488. const std::string & resolve_path,
  489. const std::string & path_tmp,
  490. bool supports_ranges,
  491. size_t existing_size,
  492. size_t & total_size) {
  493. std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
  494. if (!ofs.is_open()) {
  495. LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
  496. return false;
  497. }
  498. httplib::Headers headers;
  499. if (supports_ranges && existing_size > 0) {
  500. headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
  501. }
  502. const char * func = __func__; // avoid __func__ inside a lambda
  503. size_t downloaded = existing_size;
  504. size_t progress_step = 0;
  505. ProgressBar bar;
  506. auto res = cli.Get(resolve_path, headers,
  507. [&](const httplib::Response &response) {
  508. if (existing_size > 0 && response.status != 206) {
  509. LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
  510. return false;
  511. }
  512. if (existing_size == 0 && response.status != 200) {
  513. LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
  514. return false;
  515. }
  516. if (total_size == 0 && response.has_header("Content-Length")) {
  517. try {
  518. size_t content_length = std::stoull(response.get_header_value("Content-Length"));
  519. total_size = existing_size + content_length;
  520. } catch (const std::exception &e) {
  521. LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
  522. }
  523. }
  524. return true;
  525. },
  526. [&](const char *data, size_t len) {
  527. ofs.write(data, len);
  528. if (!ofs) {
  529. LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
  530. return false;
  531. }
  532. downloaded += len;
  533. progress_step += len;
  534. if (progress_step >= total_size / 1000 || downloaded == total_size) {
  535. bar.update(downloaded, total_size);
  536. progress_step = 0;
  537. }
  538. return true;
  539. },
  540. nullptr
  541. );
  542. if (!res) {
  543. LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
  544. return false;
  545. }
  546. return true;
  547. }
  548. // download one single file from remote URL to local path
  549. // returns status code or -1 on error
  550. static int common_download_file_single_online(const std::string & url,
  551. const std::string & path,
  552. const std::string & bearer_token,
  553. const common_header_list & custom_headers) {
  554. static const int max_attempts = 3;
  555. static const int retry_delay_seconds = 2;
  556. auto [cli, parts] = common_http_client(url);
  557. httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
  558. if (!bearer_token.empty()) {
  559. default_headers.insert({"Authorization", "Bearer " + bearer_token});
  560. }
  561. for (const auto & h : custom_headers) {
  562. default_headers.emplace(h.first, h.second);
  563. }
  564. cli.set_default_headers(default_headers);
  565. const bool file_exists = std::filesystem::exists(path);
  566. std::string last_etag;
  567. if (file_exists) {
  568. last_etag = read_etag(path);
  569. } else {
  570. LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
  571. }
  572. for (int i = 0; i < max_attempts; ++i) {
  573. auto head = cli.Head(parts.path);
  574. bool head_ok = head && head->status >= 200 && head->status < 300;
  575. if (!head_ok) {
  576. LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
  577. if (file_exists) {
  578. LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
  579. return 304; // 304 Not Modified - fake cached response
  580. }
  581. return head->status; // cannot use cached file, return raw status code
  582. // TODO: maybe retry only on certain codes
  583. }
  584. std::string etag;
  585. if (head_ok && head->has_header("ETag")) {
  586. etag = head->get_header_value("ETag");
  587. }
  588. size_t total_size = 0;
  589. if (head_ok && head->has_header("Content-Length")) {
  590. try {
  591. total_size = std::stoull(head->get_header_value("Content-Length"));
  592. } catch (const std::exception& e) {
  593. LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
  594. }
  595. }
  596. bool supports_ranges = false;
  597. if (head_ok && head->has_header("Accept-Ranges")) {
  598. supports_ranges = head->get_header_value("Accept-Ranges") != "none";
  599. }
  600. bool should_download_from_scratch = false;
  601. if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
  602. LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
  603. last_etag.c_str(), etag.c_str());
  604. should_download_from_scratch = true;
  605. }
  606. if (file_exists) {
  607. if (!should_download_from_scratch) {
  608. LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
  609. return 304; // 304 Not Modified - fake cached response
  610. }
  611. LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
  612. if (remove(path.c_str()) != 0) {
  613. LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
  614. return -1;
  615. }
  616. }
  617. const std::string path_temporary = path + ".downloadInProgress";
  618. size_t existing_size = 0;
  619. if (std::filesystem::exists(path_temporary)) {
  620. if (supports_ranges && !should_download_from_scratch) {
  621. existing_size = std::filesystem::file_size(path_temporary);
  622. } else if (remove(path_temporary.c_str()) != 0) {
  623. LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
  624. return -1;
  625. }
  626. }
  627. // start the download
  628. LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
  629. __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
  630. const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
  631. if (!was_pull_successful) {
  632. if (i + 1 < max_attempts) {
  633. const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
  634. LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
  635. std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
  636. } else {
  637. LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
  638. }
  639. continue;
  640. }
  641. if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
  642. LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
  643. return -1;
  644. }
  645. if (!etag.empty()) {
  646. write_etag(path, etag);
  647. }
  648. return head->status; // TODO: use actual GET status?
  649. }
  650. return -1; // max attempts reached
  651. }
  652. std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
  653. const common_remote_params & params) {
  654. auto [cli, parts] = common_http_client(url);
  655. httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
  656. for (const auto & header : params.headers) {
  657. headers.emplace(header.first, header.second);
  658. }
  659. if (params.timeout > 0) {
  660. cli.set_read_timeout(params.timeout, 0);
  661. cli.set_write_timeout(params.timeout, 0);
  662. }
  663. std::vector<char> buf;
  664. auto res = cli.Get(parts.path, headers,
  665. [&](const char *data, size_t len) {
  666. buf.insert(buf.end(), data, data + len);
  667. return params.max_size == 0 ||
  668. buf.size() <= static_cast<size_t>(params.max_size);
  669. },
  670. nullptr
  671. );
  672. if (!res) {
  673. throw std::runtime_error("error: cannot make GET request");
  674. }
  675. return { res->status, std::move(buf) };
  676. }
  677. #endif // LLAMA_USE_CURL
  678. #if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
  679. int common_download_file_single(const std::string & url,
  680. const std::string & path,
  681. const std::string & bearer_token,
  682. bool offline,
  683. const common_header_list & headers) {
  684. if (!offline) {
  685. return common_download_file_single_online(url, path, bearer_token, headers);
  686. }
  687. if (!std::filesystem::exists(path)) {
  688. LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
  689. return -1;
  690. }
  691. LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
  692. return 304; // Not Modified - fake cached response
  693. }
  694. // download multiple files from remote URLs to local paths
  695. // the input is a vector of pairs <url, path>
  696. static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
  697. const std::string & bearer_token,
  698. bool offline,
  699. const common_header_list & headers) {
  700. // Prepare download in parallel
  701. std::vector<std::future<bool>> futures_download;
  702. futures_download.reserve(urls.size());
  703. for (auto const & item : urls) {
  704. futures_download.push_back(
  705. std::async(
  706. std::launch::async,
  707. [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
  708. const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
  709. return is_http_status_ok(http_status);
  710. },
  711. item
  712. )
  713. );
  714. }
  715. // Wait for all downloads to complete
  716. for (auto & f : futures_download) {
  717. if (!f.get()) {
  718. return false;
  719. }
  720. }
  721. return true;
  722. }
  723. bool common_download_model(const common_params_model & model,
  724. const std::string & bearer_token,
  725. bool offline,
  726. const common_header_list & headers) {
  727. // Basic validation of the model.url
  728. if (model.url.empty()) {
  729. LOG_ERR("%s: invalid model url\n", __func__);
  730. return false;
  731. }
  732. const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
  733. if (!is_http_status_ok(http_status)) {
  734. return false;
  735. }
  736. // check for additional GGUFs split to download
  737. int n_split = 0;
  738. {
  739. struct gguf_init_params gguf_params = {
  740. /*.no_alloc = */ true,
  741. /*.ctx = */ NULL,
  742. };
  743. auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
  744. if (!ctx_gguf) {
  745. LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
  746. return false;
  747. }
  748. auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
  749. if (key_n_split >= 0) {
  750. n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
  751. }
  752. gguf_free(ctx_gguf);
  753. }
  754. if (n_split > 1) {
  755. char split_prefix[PATH_MAX] = {0};
  756. char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
  757. // Verify the first split file format
  758. // and extract split URL and PATH prefixes
  759. {
  760. if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
  761. LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
  762. return false;
  763. }
  764. if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
  765. LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
  766. return false;
  767. }
  768. }
  769. std::vector<std::pair<std::string, std::string>> urls;
  770. for (int idx = 1; idx < n_split; idx++) {
  771. char split_path[PATH_MAX] = {0};
  772. llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
  773. char split_url[LLAMA_MAX_URL_LENGTH] = {0};
  774. llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
  775. if (std::string(split_path) == model.path) {
  776. continue; // skip the already downloaded file
  777. }
  778. urls.push_back({split_url, split_path});
  779. }
  780. // Download in parallel
  781. common_download_file_multiple(urls, bearer_token, offline, headers);
  782. }
  783. return true;
  784. }
  785. common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
  786. const std::string & bearer_token,
  787. bool offline,
  788. const common_header_list & custom_headers) {
  789. // the returned hf_repo is without tag
  790. auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag);
  791. std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
  792. // headers
  793. common_header_list headers = custom_headers;
  794. headers.push_back({"Accept", "application/json"});
  795. if (!bearer_token.empty()) {
  796. headers.push_back({"Authorization", "Bearer " + bearer_token});
  797. }
  798. // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
  799. // User-Agent header is already set in common_remote_get_content, no need to set it here
  800. // make the request
  801. common_remote_params params;
  802. params.headers = headers;
  803. long res_code = 0;
  804. std::string res_str;
  805. bool use_cache = false;
  806. std::string cached_response_path = get_manifest_path(hf_repo, tag);
  807. if (!offline) {
  808. try {
  809. auto res = common_remote_get_content(url, params);
  810. res_code = res.first;
  811. res_str = std::string(res.second.data(), res.second.size());
  812. } catch (const std::exception & e) {
  813. LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
  814. }
  815. }
  816. if (res_code == 0) {
  817. if (std::filesystem::exists(cached_response_path)) {
  818. LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
  819. res_str = read_file(cached_response_path);
  820. res_code = 200;
  821. use_cache = true;
  822. } else {
  823. throw std::runtime_error(
  824. offline ? "error: failed to get manifest (offline mode)"
  825. : "error: failed to get manifest (check your internet connection)");
  826. }
  827. }
  828. std::string ggufFile;
  829. std::string mmprojFile;
  830. if (res_code == 200 || res_code == 304) {
  831. try {
  832. auto j = json::parse(res_str);
  833. if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) {
  834. ggufFile = j["ggufFile"]["rfilename"].get<std::string>();
  835. }
  836. if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
  837. mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
  838. }
  839. } catch (const std::exception & e) {
  840. throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
  841. }
  842. if (!use_cache) {
  843. // if not using cached response, update the cache file
  844. write_file(cached_response_path, res_str);
  845. }
  846. } else if (res_code == 401) {
  847. throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
  848. } else {
  849. throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
  850. }
  851. // check response
  852. if (ggufFile.empty()) {
  853. throw std::runtime_error("error: model does not have ggufFile");
  854. }
  855. return { hf_repo, ggufFile, mmprojFile };
  856. }
  857. //
  858. // Docker registry functions
  859. //
  860. static std::string common_docker_get_token(const std::string & repo) {
  861. std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
  862. common_remote_params params;
  863. auto res = common_remote_get_content(url, params);
  864. if (res.first != 200) {
  865. throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
  866. }
  867. std::string response_str(res.second.begin(), res.second.end());
  868. nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
  869. if (!response.contains("token")) {
  870. throw std::runtime_error("Docker registry token response missing 'token' field");
  871. }
  872. return response["token"].get<std::string>();
  873. }
  874. std::string common_docker_resolve_model(const std::string & docker) {
  875. // Parse ai/smollm2:135M-Q4_0
  876. size_t colon_pos = docker.find(':');
  877. std::string repo, tag;
  878. if (colon_pos != std::string::npos) {
  879. repo = docker.substr(0, colon_pos);
  880. tag = docker.substr(colon_pos + 1);
  881. } else {
  882. repo = docker;
  883. tag = "latest";
  884. }
  885. // ai/ is the default
  886. size_t slash_pos = docker.find('/');
  887. if (slash_pos == std::string::npos) {
  888. repo.insert(0, "ai/");
  889. }
  890. LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
  891. try {
  892. // --- helper: digest validation ---
  893. auto validate_oci_digest = [](const std::string & digest) -> std::string {
  894. // Expected: algo:hex ; start with sha256 (64 hex chars)
  895. // You can extend this map if supporting other algorithms in future.
  896. static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
  897. std::smatch m;
  898. if (!std::regex_match(digest, m, re)) {
  899. throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
  900. }
  901. // normalize hex to lowercase
  902. std::string normalized = digest;
  903. std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
  904. return std::tolower(c);
  905. });
  906. return normalized;
  907. };
  908. std::string token = common_docker_get_token(repo); // Get authentication token
  909. // Get manifest
  910. // TODO: cache the manifest response so that it appears in the model list
  911. const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
  912. std::string manifest_url = url_prefix + "/manifests/" + tag;
  913. common_remote_params manifest_params;
  914. manifest_params.headers.push_back({"Authorization", "Bearer " + token});
  915. manifest_params.headers.push_back({"Accept",
  916. "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
  917. });
  918. auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
  919. if (manifest_res.first != 200) {
  920. throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
  921. }
  922. std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
  923. nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
  924. std::string gguf_digest; // Find the GGUF layer
  925. if (manifest.contains("layers")) {
  926. for (const auto & layer : manifest["layers"]) {
  927. if (layer.contains("mediaType")) {
  928. std::string media_type = layer["mediaType"].get<std::string>();
  929. if (media_type == "application/vnd.docker.ai.gguf.v3" ||
  930. media_type.find("gguf") != std::string::npos) {
  931. gguf_digest = layer["digest"].get<std::string>();
  932. break;
  933. }
  934. }
  935. }
  936. }
  937. if (gguf_digest.empty()) {
  938. throw std::runtime_error("No GGUF layer found in Docker manifest");
  939. }
  940. // Validate & normalize digest
  941. gguf_digest = validate_oci_digest(gguf_digest);
  942. LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
  943. // Prepare local filename
  944. std::string model_filename = repo;
  945. std::replace(model_filename.begin(), model_filename.end(), '/', '_');
  946. model_filename += "_" + tag + ".gguf";
  947. std::string local_path = fs_get_cache_file(model_filename);
  948. const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
  949. const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
  950. if (!is_http_status_ok(http_status)) {
  951. throw std::runtime_error("Failed to download Docker Model");
  952. }
  953. LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
  954. return local_path;
  955. } catch (const std::exception & e) {
  956. LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
  957. throw;
  958. }
  959. }
  960. #else
  961. common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
  962. throw std::runtime_error("download functionality is not enabled in this build");
  963. }
  964. bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
  965. throw std::runtime_error("download functionality is not enabled in this build");
  966. }
  // Stub compiled in the #else branch of the LLAMA_USE_CURL / LLAMA_USE_HTTPLIB check:
  // ignores its argument and always throws std::runtime_error.
  std::string common_docker_resolve_model(const std::string & /* docker */) {
      throw std::runtime_error("download functionality is not enabled in this build");
  }
  970. int common_download_file_single(const std::string &,
  971. const std::string &,
  972. const std::string &,
  973. bool,
  974. const common_header_list &) {
  975. throw std::runtime_error("download functionality is not enabled in this build");
  976. }
  977. #endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB
  978. std::vector<common_cached_model_info> common_list_cached_models() {
  979. std::vector<common_cached_model_info> models;
  980. const std::string cache_dir = fs_get_cache_directory();
  981. const std::vector<common_file_info> files = fs_list(cache_dir, false);
  982. for (const auto & file : files) {
  983. if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
  984. common_cached_model_info model_info;
  985. model_info.manifest_path = file.path;
  986. std::string fname = file.name;
  987. string_replace_all(fname, ".json", ""); // remove extension
  988. auto parts = string_split<std::string>(fname, '=');
  989. if (parts.size() == 4) {
  990. // expect format: manifest=<user>=<model>=<tag>=<other>
  991. model_info.user = parts[1];
  992. model_info.model = parts[2];
  993. model_info.tag = parts[3];
  994. } else {
  995. // invalid format
  996. continue;
  997. }
  998. model_info.size = 0; // TODO: get GGUF size, not manifest size
  999. models.push_back(model_info);
  1000. }
  1001. }
  1002. return models;
  1003. }