download.cpp 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014
  1. #include "arg.h"
  2. #include "common.h"
  3. #include "gguf.h" // for reading GGUF splits
  4. #include "log.h"
  5. #include "download.h"
  6. #define JSON_ASSERT GGML_ASSERT
  7. #include <nlohmann/json.hpp>
  8. #include <algorithm>
  9. #include <filesystem>
  10. #include <fstream>
  11. #include <future>
  12. #include <regex>
  13. #include <string>
  14. #include <thread>
  15. #include <vector>
  16. #if defined(LLAMA_USE_CURL)
  17. #include <curl/curl.h>
  18. #include <curl/easy.h>
  19. #else
  20. #include "http.h"
  21. #endif
  22. #ifdef __linux__
  23. #include <linux/limits.h>
  24. #elif defined(_WIN32)
  25. # if !defined(PATH_MAX)
  26. # define PATH_MAX MAX_PATH
  27. # endif
  28. #elif defined(_AIX)
  29. #include <sys/limits.h>
  30. #else
  31. #include <sys/syslimits.h>
  32. #endif
  33. #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
  34. // isatty
  35. #if defined(_WIN32)
  36. #include <io.h>
  37. #else
  38. #include <unistd.h>
  39. #endif
  40. using json = nlohmann::ordered_json;
  41. //
  42. // downloader
  43. //
  44. static std::string read_file(const std::string & fname) {
  45. std::ifstream file(fname);
  46. if (!file) {
  47. throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
  48. }
  49. std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
  50. file.close();
  51. return content;
  52. }
  53. static void write_file(const std::string & fname, const std::string & content) {
  54. const std::string fname_tmp = fname + ".tmp";
  55. std::ofstream file(fname_tmp);
  56. if (!file) {
  57. throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
  58. }
  59. try {
  60. file << content;
  61. file.close();
  62. // Makes write atomic
  63. if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
  64. LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
  65. // If rename fails, try to delete the temporary file
  66. if (remove(fname_tmp.c_str()) != 0) {
  67. LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
  68. }
  69. }
  70. } catch (...) {
  71. // If anything fails, try to delete the temporary file
  72. if (remove(fname_tmp.c_str()) != 0) {
  73. LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
  74. }
  75. throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
  76. }
  77. }
  78. static void write_etag(const std::string & path, const std::string & etag) {
  79. const std::string etag_path = path + ".etag";
  80. write_file(etag_path, etag);
  81. LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
  82. }
  83. static std::string read_etag(const std::string & path) {
  84. std::string none;
  85. const std::string etag_path = path + ".etag";
  86. if (std::filesystem::exists(etag_path)) {
  87. std::ifstream etag_in(etag_path);
  88. if (!etag_in) {
  89. LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
  90. return none;
  91. }
  92. std::string etag;
  93. std::getline(etag_in, etag);
  94. return etag;
  95. }
  96. // no etag file, but maybe there is an old .json
  97. // remove this code later
  98. const std::string metadata_path = path + ".json";
  99. if (std::filesystem::exists(metadata_path)) {
  100. std::ifstream metadata_in(metadata_path);
  101. try {
  102. nlohmann::json metadata_json;
  103. metadata_in >> metadata_json;
  104. LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
  105. metadata_json.dump().c_str());
  106. if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
  107. std::string etag = metadata_json.at("etag");
  108. write_etag(path, etag);
  109. if (!std::filesystem::remove(metadata_path)) {
  110. LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
  111. }
  112. return etag;
  113. }
  114. } catch (const nlohmann::json::exception & e) {
  115. LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
  116. }
  117. }
  118. return none;
  119. }
  120. #ifdef LLAMA_USE_CURL
  121. //
  122. // CURL utils
  123. //
  124. using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
  125. // cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
  126. struct curl_slist_ptr {
  127. struct curl_slist * ptr = nullptr;
  128. ~curl_slist_ptr() {
  129. if (ptr) {
  130. curl_slist_free_all(ptr);
  131. }
  132. }
  133. };
  134. static CURLcode common_curl_perf(CURL * curl) {
  135. CURLcode res = curl_easy_perform(curl);
  136. if (res != CURLE_OK) {
  137. LOG_ERR("%s: curl_easy_perform() failed\n", __func__);
  138. }
  139. return res;
  140. }
  // Response header values collected by common_header_callback.
  // A field stays empty when the corresponding header was not seen.
  struct common_load_model_from_url_headers {
      std::string etag;           // value of the "ETag" header
      std::string last_modified;  // value of the "Last-Modified" header
      std::string accept_ranges;  // value of the "Accept-Ranges" header
  };
  // Deleter that lets std::unique_ptr own a C stdio stream: closes it with fclose().
  struct FILE_deleter {
      void operator()(FILE * stream) const {
          fclose(stream);
      }
  };
  150. static size_t common_header_callback(char * buffer, size_t, size_t n_items, void * userdata) {
  151. common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
  152. static std::regex header_regex("([^:]+): (.*)\r\n");
  153. static std::regex etag_regex("ETag", std::regex_constants::icase);
  154. static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
  155. static std::regex accept_ranges_regex("Accept-Ranges", std::regex_constants::icase);
  156. std::string header(buffer, n_items);
  157. std::smatch match;
  158. if (std::regex_match(header, match, header_regex)) {
  159. const std::string & key = match[1];
  160. const std::string & value = match[2];
  161. if (std::regex_match(key, match, etag_regex)) {
  162. headers->etag = value;
  163. } else if (std::regex_match(key, match, last_modified_regex)) {
  164. headers->last_modified = value;
  165. } else if (std::regex_match(key, match, accept_ranges_regex)) {
  166. headers->accept_ranges = value;
  167. }
  168. }
  169. return n_items;
  170. }
  // Write callback (installed through CURLOPT_WRITEFUNCTION): forwards the received
  // chunk to the FILE * passed as `fd` and returns what std::fwrite reports.
  static size_t common_write_callback(void * data, size_t size, size_t nmemb, void * fd) {
      FILE * sink = static_cast<FILE *>(fd);
      const size_t written = std::fwrite(data, size, nmemb, sink);
      return written;
  }
  // Return `url` with the credentials of its authority part replaced by "********".
  //
  //   scheme://user[:password]@host/path  ->  scheme://********@host/path
  //
  // URLs without credentials, and strings that do not look like "scheme://...",
  // are returned unchanged.
  static std::string llama_download_hide_password_in_url(const std::string & url) {
      // group 1: scheme including "://"
      // group 2: optional userinfo including the trailing '@'; it cannot extend past the
      //          first '/', '?' or '#', so an '@' in the path or query is never taken for
      //          credentials, and being greedy it ends at the LAST '@' of the authority,
      //          so a password that itself contains '@' is hidden completely
      // group 3: rest of the URL (host, path, ...)
      static const std::regex url_regex(R"(^([A-Za-z][A-Za-z0-9+.-]*://)([^/?#]*@)?(.*)$)");

      std::smatch match;
      if (std::regex_match(url, match, url_regex) && match[2].matched) {
          return match[1].str() + "********@" + match[3].str();
      }

      return url; // No credentials found or malformed URL
  }
  188. static void common_curl_easy_setopt_head(CURL * curl, const std::string & url) {
  189. // Set the URL, allow to follow http redirection
  190. curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
  191. curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
  192. # if defined(_WIN32)
  193. // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
  194. // operating system. Currently implemented under MS-Windows.
  195. curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
  196. # endif
  197. curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
  198. curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress
  199. curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, common_header_callback);
  200. }
  201. static void common_curl_easy_setopt_get(CURL * curl) {
  202. curl_easy_setopt(curl, CURLOPT_NOBODY, 0L);
  203. curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, common_write_callback);
  204. // display download progress
  205. curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
  206. }
  207. static bool common_pull_file(CURL * curl, const std::string & path_temporary) {
  208. if (std::filesystem::exists(path_temporary)) {
  209. const std::string partial_size = std::to_string(std::filesystem::file_size(path_temporary));
  210. LOG_INF("%s: server supports range requests, resuming download from byte %s\n", __func__, partial_size.c_str());
  211. const std::string range_str = partial_size + "-";
  212. curl_easy_setopt(curl, CURLOPT_RANGE, range_str.c_str());
  213. }
  214. // Always open file in append mode could be resuming
  215. std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "ab"));
  216. if (!outfile) {
  217. LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str());
  218. return false;
  219. }
  220. common_curl_easy_setopt_get(curl);
  221. curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile.get());
  222. return common_curl_perf(curl) == CURLE_OK;
  223. }
  224. static bool common_download_head(CURL * curl,
  225. curl_slist_ptr & http_headers,
  226. const std::string & url,
  227. const std::string & bearer_token) {
  228. if (!curl) {
  229. LOG_ERR("%s: error initializing libcurl\n", __func__);
  230. return false;
  231. }
  232. http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
  233. // Check if hf-token or bearer-token was specified
  234. if (!bearer_token.empty()) {
  235. std::string auth_header = "Authorization: Bearer " + bearer_token;
  236. http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
  237. }
  238. curl_easy_setopt(curl, CURLOPT_HTTPHEADER, http_headers.ptr);
  239. common_curl_easy_setopt_head(curl, url);
  240. return common_curl_perf(curl) == CURLE_OK;
  241. }
  242. // download one single file from remote URL to local path
  243. static bool common_download_file_single_online(const std::string & url,
  244. const std::string & path,
  245. const std::string & bearer_token) {
  246. static const int max_attempts = 3;
  247. static const int retry_delay_seconds = 2;
  248. for (int i = 0; i < max_attempts; ++i) {
  249. std::string etag;
  250. // Check if the file already exists locally
  251. const auto file_exists = std::filesystem::exists(path);
  252. if (file_exists) {
  253. etag = read_etag(path);
  254. } else {
  255. LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
  256. }
  257. bool head_request_ok = false;
  258. bool should_download = !file_exists; // by default, we should download if the file does not exist
  259. // Initialize libcurl
  260. curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
  261. common_load_model_from_url_headers headers;
  262. curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
  263. curl_slist_ptr http_headers;
  264. const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
  265. if (!was_perform_successful) {
  266. head_request_ok = false;
  267. }
  268. long http_code = 0;
  269. curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
  270. if (http_code == 200) {
  271. head_request_ok = true;
  272. } else {
  273. LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
  274. head_request_ok = false;
  275. }
  276. // if head_request_ok is false, we don't have the etag or last-modified headers
  277. // we leave should_download as-is, which is true if the file does not exist
  278. bool should_download_from_scratch = false;
  279. if (head_request_ok) {
  280. // check if ETag or Last-Modified headers are different
  281. // if it is, we need to download the file again
  282. if (!etag.empty() && etag != headers.etag) {
  283. LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
  284. headers.etag.c_str());
  285. should_download = true;
  286. should_download_from_scratch = true;
  287. }
  288. }
  289. const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none";
  290. if (should_download) {
  291. if (file_exists &&
  292. !accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
  293. LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
  294. if (remove(path.c_str()) != 0) {
  295. LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
  296. return false;
  297. }
  298. }
  299. const std::string path_temporary = path + ".downloadInProgress";
  300. if (should_download_from_scratch) {
  301. if (std::filesystem::exists(path_temporary)) {
  302. if (remove(path_temporary.c_str()) != 0) {
  303. LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
  304. return false;
  305. }
  306. }
  307. if (std::filesystem::exists(path)) {
  308. if (remove(path.c_str()) != 0) {
  309. LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
  310. return false;
  311. }
  312. }
  313. }
  314. if (head_request_ok) {
  315. write_etag(path, headers.etag);
  316. }
  317. // start the download
  318. LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
  319. __func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(),
  320. headers.etag.c_str(), headers.last_modified.c_str());
  321. const bool was_pull_successful = common_pull_file(curl.get(), path_temporary);
  322. if (!was_pull_successful) {
  323. if (i + 1 < max_attempts) {
  324. const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
  325. LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
  326. std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
  327. } else {
  328. LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
  329. }
  330. continue;
  331. }
  332. long http_code = 0;
  333. curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
  334. if (http_code < 200 || http_code >= 400) {
  335. LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
  336. return false;
  337. }
  338. if (rename(path_temporary.c_str(), path.c_str()) != 0) {
  339. LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
  340. return false;
  341. }
  342. } else {
  343. LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
  344. }
  345. break;
  346. }
  347. return true;
  348. }
  349. std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
  350. curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
  351. curl_slist_ptr http_headers;
  352. std::vector<char> res_buffer;
  353. curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
  354. curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
  355. curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
  356. curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 1L);
  357. typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
  358. auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
  359. auto data_vec = static_cast<std::vector<char> *>(data);
  360. data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb);
  361. return size * nmemb;
  362. };
  363. curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
  364. curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer);
  365. #if defined(_WIN32)
  366. curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
  367. #endif
  368. if (params.timeout > 0) {
  369. curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout);
  370. }
  371. if (params.max_size > 0) {
  372. curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
  373. }
  374. http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
  375. for (const auto & header : params.headers) {
  376. http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
  377. }
  378. curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
  379. CURLcode res = curl_easy_perform(curl.get());
  380. if (res != CURLE_OK) {
  381. std::string error_msg = curl_easy_strerror(res);
  382. throw std::runtime_error("error: cannot make GET request: " + error_msg);
  383. }
  384. long res_code;
  385. curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
  386. return { res_code, std::move(res_buffer) };
  387. }
  388. #else
  // Report whether standard output is attached to a terminal.
  static bool is_output_a_tty() {
  #if defined(_WIN32)
      const int fd = _fileno(stdout);
      return _isatty(fd) != 0;
  #else
      // file descriptor 1 is standard output
      const int fd = 1;
      return isatty(fd) != 0;
  #endif
  }
  396. static void print_progress(size_t current, size_t total) {
  397. if (!is_output_a_tty()) {
  398. return;
  399. }
  400. if (!total) {
  401. return;
  402. }
  403. size_t width = 50;
  404. size_t pct = (100 * current) / total;
  405. size_t pos = (width * current) / total;
  406. std::cout << "["
  407. << std::string(pos, '=')
  408. << (pos < width ? ">" : "")
  409. << std::string(width - pos, ' ')
  410. << "] " << std::setw(3) << pct << "% ("
  411. << current / (1024 * 1024) << " MB / "
  412. << total / (1024 * 1024) << " MB)\r";
  413. std::cout.flush();
  414. }
  415. static bool common_pull_file(httplib::Client & cli,
  416. const std::string & resolve_path,
  417. const std::string & path_tmp,
  418. bool supports_ranges,
  419. size_t existing_size,
  420. size_t & total_size) {
  421. std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
  422. if (!ofs.is_open()) {
  423. LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
  424. return false;
  425. }
  426. httplib::Headers headers;
  427. if (supports_ranges && existing_size > 0) {
  428. headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
  429. }
  430. std::atomic<size_t> downloaded{existing_size};
  431. auto res = cli.Get(resolve_path, headers,
  432. [&](const httplib::Response &response) {
  433. if (existing_size > 0 && response.status != 206) {
  434. LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", __func__, response.status);
  435. return false;
  436. }
  437. if (existing_size == 0 && response.status != 200) {
  438. LOG_WRN("%s: download received non-successful status code: %d\n", __func__, response.status);
  439. return false;
  440. }
  441. if (total_size == 0 && response.has_header("Content-Length")) {
  442. try {
  443. size_t content_length = std::stoull(response.get_header_value("Content-Length"));
  444. total_size = existing_size + content_length;
  445. } catch (const std::exception &e) {
  446. LOG_WRN("%s: invalid Content-Length header: %s\n", __func__, e.what());
  447. }
  448. }
  449. return true;
  450. },
  451. [&](const char *data, size_t len) {
  452. ofs.write(data, len);
  453. if (!ofs) {
  454. LOG_ERR("%s: error writing to file: %s\n", __func__, path_tmp.c_str());
  455. return false;
  456. }
  457. downloaded += len;
  458. print_progress(downloaded, total_size);
  459. return true;
  460. },
  461. nullptr
  462. );
  463. std::cout << "\n";
  464. if (!res) {
  465. LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
  466. return false;
  467. }
  468. return true;
  469. }
  470. // download one single file from remote URL to local path
  471. static bool common_download_file_single_online(const std::string & url,
  472. const std::string & path,
  473. const std::string & bearer_token) {
  474. static const int max_attempts = 3;
  475. static const int retry_delay_seconds = 2;
  476. auto [cli, parts] = common_http_client(url);
  477. httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
  478. if (!bearer_token.empty()) {
  479. default_headers.insert({"Authorization", "Bearer " + bearer_token});
  480. }
  481. cli.set_default_headers(default_headers);
  482. const bool file_exists = std::filesystem::exists(path);
  483. std::string last_etag;
  484. if (file_exists) {
  485. last_etag = read_etag(path);
  486. } else {
  487. LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
  488. }
  489. for (int i = 0; i < max_attempts; ++i) {
  490. auto head = cli.Head(parts.path);
  491. bool head_ok = head && head->status >= 200 && head->status < 300;
  492. if (!head_ok) {
  493. LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
  494. if (file_exists) {
  495. LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
  496. return true;
  497. }
  498. }
  499. std::string etag;
  500. if (head_ok && head->has_header("ETag")) {
  501. etag = head->get_header_value("ETag");
  502. }
  503. size_t total_size = 0;
  504. if (head_ok && head->has_header("Content-Length")) {
  505. try {
  506. total_size = std::stoull(head->get_header_value("Content-Length"));
  507. } catch (const std::exception& e) {
  508. LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
  509. }
  510. }
  511. bool supports_ranges = false;
  512. if (head_ok && head->has_header("Accept-Ranges")) {
  513. supports_ranges = head->get_header_value("Accept-Ranges") != "none";
  514. }
  515. bool should_download_from_scratch = false;
  516. if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
  517. LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
  518. last_etag.c_str(), etag.c_str());
  519. should_download_from_scratch = true;
  520. }
  521. if (file_exists) {
  522. if (!should_download_from_scratch) {
  523. LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
  524. return true;
  525. }
  526. LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
  527. if (remove(path.c_str()) != 0) {
  528. LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
  529. return false;
  530. }
  531. }
  532. const std::string path_temporary = path + ".downloadInProgress";
  533. size_t existing_size = 0;
  534. if (std::filesystem::exists(path_temporary)) {
  535. if (supports_ranges && !should_download_from_scratch) {
  536. existing_size = std::filesystem::file_size(path_temporary);
  537. } else if (remove(path_temporary.c_str()) != 0) {
  538. LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
  539. return false;
  540. }
  541. }
  542. // start the download
  543. LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
  544. __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
  545. const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
  546. if (!was_pull_successful) {
  547. if (i + 1 < max_attempts) {
  548. const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
  549. LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
  550. std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
  551. } else {
  552. LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
  553. }
  554. continue;
  555. }
  556. if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
  557. LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
  558. return false;
  559. }
  560. if (!etag.empty()) {
  561. write_etag(path, etag);
  562. }
  563. break;
  564. }
  565. return true;
  566. }
  567. std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
  568. const common_remote_params & params) {
  569. auto [cli, parts] = common_http_client(url);
  570. httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
  571. for (const auto & header : params.headers) {
  572. size_t pos = header.find(':');
  573. if (pos != std::string::npos) {
  574. headers.emplace(header.substr(0, pos), header.substr(pos + 1));
  575. } else {
  576. headers.emplace(header, "");
  577. }
  578. }
  579. if (params.timeout > 0) {
  580. cli.set_read_timeout(params.timeout, 0);
  581. cli.set_write_timeout(params.timeout, 0);
  582. }
  583. std::vector<char> buf;
  584. auto res = cli.Get(parts.path, headers,
  585. [&](const char *data, size_t len) {
  586. buf.insert(buf.end(), data, data + len);
  587. return params.max_size == 0 ||
  588. buf.size() <= static_cast<size_t>(params.max_size);
  589. },
  590. nullptr
  591. );
  592. if (!res) {
  593. throw std::runtime_error("error: cannot make GET request");
  594. }
  595. return { res->status, std::move(buf) };
  596. }
  597. #endif // LLAMA_USE_CURL
  598. static bool common_download_file_single(const std::string & url,
  599. const std::string & path,
  600. const std::string & bearer_token,
  601. bool offline) {
  602. if (!offline) {
  603. return common_download_file_single_online(url, path, bearer_token);
  604. }
  605. if (!std::filesystem::exists(path)) {
  606. LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
  607. return false;
  608. }
  609. LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
  610. return true;
  611. }
  612. // download multiple files from remote URLs to local paths
  613. // the input is a vector of pairs <url, path>
  614. static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
  615. // Prepare download in parallel
  616. std::vector<std::future<bool>> futures_download;
  617. for (auto const & item : urls) {
  618. futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
  619. return common_download_file_single(it.first, it.second, bearer_token, offline);
  620. }, item));
  621. }
  622. // Wait for all downloads to complete
  623. for (auto & f : futures_download) {
  624. if (!f.get()) {
  625. return false;
  626. }
  627. }
  628. return true;
  629. }
  630. bool common_download_model(
  631. const common_params_model & model,
  632. const std::string & bearer_token,
  633. bool offline) {
  634. // Basic validation of the model.url
  635. if (model.url.empty()) {
  636. LOG_ERR("%s: invalid model url\n", __func__);
  637. return false;
  638. }
  639. if (!common_download_file_single(model.url, model.path, bearer_token, offline)) {
  640. return false;
  641. }
  642. // check for additional GGUFs split to download
  643. int n_split = 0;
  644. {
  645. struct gguf_init_params gguf_params = {
  646. /*.no_alloc = */ true,
  647. /*.ctx = */ NULL,
  648. };
  649. auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
  650. if (!ctx_gguf) {
  651. LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
  652. return false;
  653. }
  654. auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
  655. if (key_n_split >= 0) {
  656. n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
  657. }
  658. gguf_free(ctx_gguf);
  659. }
  660. if (n_split > 1) {
  661. char split_prefix[PATH_MAX] = {0};
  662. char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
  663. // Verify the first split file format
  664. // and extract split URL and PATH prefixes
  665. {
  666. if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
  667. LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
  668. return false;
  669. }
  670. if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
  671. LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
  672. return false;
  673. }
  674. }
  675. std::vector<std::pair<std::string, std::string>> urls;
  676. for (int idx = 1; idx < n_split; idx++) {
  677. char split_path[PATH_MAX] = {0};
  678. llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
  679. char split_url[LLAMA_MAX_URL_LENGTH] = {0};
  680. llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
  681. if (std::string(split_path) == model.path) {
  682. continue; // skip the already downloaded file
  683. }
  684. urls.push_back({split_url, split_path});
  685. }
  686. // Download in parallel
  687. common_download_file_multiple(urls, bearer_token, offline);
  688. }
  689. return true;
  690. }
  691. common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) {
  692. auto parts = string_split<std::string>(hf_repo_with_tag, ':');
  693. std::string tag = parts.size() > 1 ? parts.back() : "latest";
  694. std::string hf_repo = parts[0];
  695. if (string_split<std::string>(hf_repo, '/').size() != 2) {
  696. throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
  697. }
  698. std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
  699. // headers
  700. std::vector<std::string> headers;
  701. headers.push_back("Accept: application/json");
  702. if (!bearer_token.empty()) {
  703. headers.push_back("Authorization: Bearer " + bearer_token);
  704. }
  705. // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
  706. // User-Agent header is already set in common_remote_get_content, no need to set it here
  707. // we use "=" to avoid clashing with other component, while still being allowed on windows
  708. std::string cached_response_fname = "manifest=" + hf_repo + "=" + tag + ".json";
  709. string_replace_all(cached_response_fname, "/", "_");
  710. std::string cached_response_path = fs_get_cache_file(cached_response_fname);
  711. // make the request
  712. common_remote_params params;
  713. params.headers = headers;
  714. long res_code = 0;
  715. std::string res_str;
  716. bool use_cache = false;
  717. if (!offline) {
  718. try {
  719. auto res = common_remote_get_content(url, params);
  720. res_code = res.first;
  721. res_str = std::string(res.second.data(), res.second.size());
  722. } catch (const std::exception & e) {
  723. LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
  724. }
  725. }
  726. if (res_code == 0) {
  727. if (std::filesystem::exists(cached_response_path)) {
  728. LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
  729. res_str = read_file(cached_response_path);
  730. res_code = 200;
  731. use_cache = true;
  732. } else {
  733. throw std::runtime_error(
  734. offline ? "error: failed to get manifest (offline mode)"
  735. : "error: failed to get manifest (check your internet connection)");
  736. }
  737. }
  738. std::string ggufFile;
  739. std::string mmprojFile;
  740. if (res_code == 200 || res_code == 304) {
  741. try {
  742. auto j = json::parse(res_str);
  743. if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) {
  744. ggufFile = j["ggufFile"]["rfilename"].get<std::string>();
  745. }
  746. if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
  747. mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
  748. }
  749. } catch (const std::exception & e) {
  750. throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
  751. }
  752. if (!use_cache) {
  753. // if not using cached response, update the cache file
  754. write_file(cached_response_path, res_str);
  755. }
  756. } else if (res_code == 401) {
  757. throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
  758. } else {
  759. throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
  760. }
  761. // check response
  762. if (ggufFile.empty()) {
  763. throw std::runtime_error("error: model does not have ggufFile");
  764. }
  765. return { hf_repo, ggufFile, mmprojFile };
  766. }
  767. //
  768. // Docker registry functions
  769. //
  770. static std::string common_docker_get_token(const std::string & repo) {
  771. std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
  772. common_remote_params params;
  773. auto res = common_remote_get_content(url, params);
  774. if (res.first != 200) {
  775. throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
  776. }
  777. std::string response_str(res.second.begin(), res.second.end());
  778. nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
  779. if (!response.contains("token")) {
  780. throw std::runtime_error("Docker registry token response missing 'token' field");
  781. }
  782. return response["token"].get<std::string>();
  783. }
  784. std::string common_docker_resolve_model(const std::string & docker) {
  785. // Parse ai/smollm2:135M-Q4_0
  786. size_t colon_pos = docker.find(':');
  787. std::string repo, tag;
  788. if (colon_pos != std::string::npos) {
  789. repo = docker.substr(0, colon_pos);
  790. tag = docker.substr(colon_pos + 1);
  791. } else {
  792. repo = docker;
  793. tag = "latest";
  794. }
  795. // ai/ is the default
  796. size_t slash_pos = docker.find('/');
  797. if (slash_pos == std::string::npos) {
  798. repo.insert(0, "ai/");
  799. }
  800. LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
  801. try {
  802. // --- helper: digest validation ---
  803. auto validate_oci_digest = [](const std::string & digest) -> std::string {
  804. // Expected: algo:hex ; start with sha256 (64 hex chars)
  805. // You can extend this map if supporting other algorithms in future.
  806. static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
  807. std::smatch m;
  808. if (!std::regex_match(digest, m, re)) {
  809. throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
  810. }
  811. // normalize hex to lowercase
  812. std::string normalized = digest;
  813. std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
  814. return std::tolower(c);
  815. });
  816. return normalized;
  817. };
  818. std::string token = common_docker_get_token(repo); // Get authentication token
  819. // Get manifest
  820. const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
  821. std::string manifest_url = url_prefix + "/manifests/" + tag;
  822. common_remote_params manifest_params;
  823. manifest_params.headers.push_back("Authorization: Bearer " + token);
  824. manifest_params.headers.push_back(
  825. "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
  826. auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
  827. if (manifest_res.first != 200) {
  828. throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
  829. }
  830. std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
  831. nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
  832. std::string gguf_digest; // Find the GGUF layer
  833. if (manifest.contains("layers")) {
  834. for (const auto & layer : manifest["layers"]) {
  835. if (layer.contains("mediaType")) {
  836. std::string media_type = layer["mediaType"].get<std::string>();
  837. if (media_type == "application/vnd.docker.ai.gguf.v3" ||
  838. media_type.find("gguf") != std::string::npos) {
  839. gguf_digest = layer["digest"].get<std::string>();
  840. break;
  841. }
  842. }
  843. }
  844. }
  845. if (gguf_digest.empty()) {
  846. throw std::runtime_error("No GGUF layer found in Docker manifest");
  847. }
  848. // Validate & normalize digest
  849. gguf_digest = validate_oci_digest(gguf_digest);
  850. LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
  851. // Prepare local filename
  852. std::string model_filename = repo;
  853. std::replace(model_filename.begin(), model_filename.end(), '/', '_');
  854. model_filename += "_" + tag + ".gguf";
  855. std::string local_path = fs_get_cache_file(model_filename);
  856. const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
  857. if (!common_download_file_single(blob_url, local_path, token, false)) {
  858. throw std::runtime_error("Failed to download Docker Model");
  859. }
  860. LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
  861. return local_path;
  862. } catch (const std::exception & e) {
  863. LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
  864. throw;
  865. }
  866. }