// download.cpp
  1. #include "arg.h"
  2. #include "common.h"
  3. #include "gguf.h" // for reading GGUF splits
  4. #include "log.h"
  5. #include "download.h"
  6. #define JSON_ASSERT GGML_ASSERT
  7. #include <nlohmann/json.hpp>
  8. #include <algorithm>
  9. #include <filesystem>
  10. #include <fstream>
  11. #include <future>
  12. #include <map>
  13. #include <mutex>
  14. #include <regex>
  15. #include <string>
  16. #include <thread>
  17. #include <vector>
  18. #if defined(LLAMA_USE_HTTPLIB)
  19. #include "http.h"
  20. #endif
  21. #ifndef __EMSCRIPTEN__
  22. #ifdef __linux__
  23. #include <linux/limits.h>
  24. #elif defined(_WIN32)
  25. # if !defined(PATH_MAX)
  26. # define PATH_MAX MAX_PATH
  27. # endif
  28. #elif defined(_AIX)
  29. #include <sys/limits.h>
  30. #else
  31. #include <sys/syslimits.h>
  32. #endif
  33. #endif
  34. #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
  35. // isatty
  36. #if defined(_WIN32)
  37. #include <io.h>
  38. #else
  39. #include <unistd.h>
  40. #endif
  41. using json = nlohmann::ordered_json;
  42. //
  43. // downloader
  44. //
  // Check that `repo` has the shape "owner/repo": exactly two non-empty
  // segments separated by one '/', each made only of letters, digits,
  // '_', '.' and '-'.
  static bool validate_repo_name(const std::string & repo) {
      static const std::regex owner_slash_name(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");
      const bool well_formed = std::regex_match(repo.cbegin(), repo.cend(), owner_slash_name);
      return well_formed;
  }
  50. static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
  51. // we use "=" to avoid clashing with other component, while still being allowed on windows
  52. std::string fname = "manifest=" + repo + "=" + tag + ".json";
  53. if (!validate_repo_name(repo)) {
  54. throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
  55. }
  56. string_replace_all(fname, "/", "=");
  57. return fs_get_cache_file(fname);
  58. }
  59. static std::string read_file(const std::string & fname) {
  60. std::ifstream file(fname);
  61. if (!file) {
  62. throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
  63. }
  64. std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
  65. file.close();
  66. return content;
  67. }
  68. static void write_file(const std::string & fname, const std::string & content) {
  69. const std::string fname_tmp = fname + ".tmp";
  70. std::ofstream file(fname_tmp);
  71. if (!file) {
  72. throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
  73. }
  74. try {
  75. file << content;
  76. file.close();
  77. // Makes write atomic
  78. if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
  79. LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
  80. // If rename fails, try to delete the temporary file
  81. if (remove(fname_tmp.c_str()) != 0) {
  82. LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
  83. }
  84. }
  85. } catch (...) {
  86. // If anything fails, try to delete the temporary file
  87. if (remove(fname_tmp.c_str()) != 0) {
  88. LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
  89. }
  90. throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
  91. }
  92. }
  93. static void write_etag(const std::string & path, const std::string & etag) {
  94. const std::string etag_path = path + ".etag";
  95. write_file(etag_path, etag);
  96. LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
  97. }
  98. static std::string read_etag(const std::string & path) {
  99. std::string none;
  100. const std::string etag_path = path + ".etag";
  101. if (std::filesystem::exists(etag_path)) {
  102. std::ifstream etag_in(etag_path);
  103. if (!etag_in) {
  104. LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
  105. return none;
  106. }
  107. std::string etag;
  108. std::getline(etag_in, etag);
  109. return etag;
  110. }
  111. // no etag file, but maybe there is an old .json
  112. // remove this code later
  113. const std::string metadata_path = path + ".json";
  114. if (std::filesystem::exists(metadata_path)) {
  115. std::ifstream metadata_in(metadata_path);
  116. try {
  117. nlohmann::json metadata_json;
  118. metadata_in >> metadata_json;
  119. LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
  120. metadata_json.dump().c_str());
  121. if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
  122. std::string etag = metadata_json.at("etag");
  123. write_etag(path, etag);
  124. if (!std::filesystem::remove(metadata_path)) {
  125. LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
  126. }
  127. return etag;
  128. }
  129. } catch (const nlohmann::json::exception & e) {
  130. LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
  131. }
  132. }
  133. return none;
  134. }
  // A status counts as OK when it lies in [200, 400): 2xx success and 3xx
  // codes, which includes the 304 used elsewhere in this file for cache hits.
  static bool is_http_status_ok(int status) {
      if (status < 200) {
          return false;
      }
      return status < 400;
  }
  138. std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag) {
  139. auto parts = string_split<std::string>(hf_repo_with_tag, ':');
  140. std::string tag = parts.size() > 1 ? parts.back() : "latest";
  141. std::string hf_repo = parts[0];
  142. if (string_split<std::string>(hf_repo, '/').size() != 2) {
  143. throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
  144. }
  145. return {hf_repo, tag};
  146. }
  147. #if defined(LLAMA_USE_HTTPLIB)
  // Terminal progress bar that lets several bars share the screen.
  //
  // All instances share one mutex-protected registry (`lines`) that maps each
  // active bar to the index of the terminal line it was given on its first
  // update(). Bars are drawn with ANSI escape sequences (save cursor, move up,
  // erase line, restore cursor). Nothing is drawn when stdout is not a TTY.
  class ProgressBar {
      static inline std::mutex mutex;                         // guards `lines`, `max_line` and the output
      static inline std::map<const ProgressBar *, int> lines; // bar -> line index assigned on first draw
      static inline int max_line = 0;                         // number of lines handed out so far

      // Forget `line`; once no bar is registered the line counter restarts at 0.
      // Callers must hold `mutex`.
      static void cleanup(const ProgressBar * line) {
          lines.erase(line);
          if (lines.empty()) {
              max_line = 0;
          }
      }

      static bool is_output_a_tty() {
  #if defined(_WIN32)
          return _isatty(_fileno(stdout));
  #else
          return isatty(1);
  #endif
      }

  public:
      ProgressBar() = default;

      ~ProgressBar() {
          std::lock_guard<std::mutex> lock(mutex);
          cleanup(this);
      }

      // Redraw this bar for `current` out of `total` bytes.
      // Does nothing when stdout is not a TTY or when `total` is 0.
      void update(size_t current, size_t total) {
          if (!is_output_a_tty()) {
              return;
          }
          if (!total) {
              return;
          }

          std::lock_guard<std::mutex> lock(mutex);

          if (lines.find(this) == lines.end()) {
              lines[this] = max_line++;
              std::cout << "\n";
          }
          const int lines_up = max_line - lines[this];

          const size_t width = 50;
          // The caller's `total` may be smaller than what is actually received.
          // Clamp the drawn progress so that `width - pos` below cannot wrap
          // around and request an absurdly long padding string.
          const size_t shown = std::min(current, total);
          const size_t pct   = (100 * shown) / total;
          const size_t pos   = (width * shown) / total;

          std::cout << "\033[s";
          if (lines_up > 0) {
              std::cout << "\033[" << lines_up << "A";
          }
          std::cout << "\033[2K\r["
                    << std::string(pos, '=')
                    << (pos < width ? ">" : "")
                    << std::string(width - pos, ' ')
                    << "] " << std::setw(3) << pct << "% ("
                    << current / (1024 * 1024) << " MB / "
                    << total / (1024 * 1024) << " MB) "
                    << "\033[u";
          std::cout.flush();

          if (current == total) {
              cleanup(this);
          }
      }

      ProgressBar(const ProgressBar &) = delete;
      ProgressBar & operator=(const ProgressBar &) = delete;
  };
  207. static bool common_pull_file(httplib::Client & cli,
  208. const std::string & resolve_path,
  209. const std::string & path_tmp,
  210. bool supports_ranges,
  211. size_t existing_size,
  212. size_t & total_size) {
  213. std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
  214. if (!ofs.is_open()) {
  215. LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
  216. return false;
  217. }
  218. httplib::Headers headers;
  219. if (supports_ranges && existing_size > 0) {
  220. headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
  221. }
  222. const char * func = __func__; // avoid __func__ inside a lambda
  223. size_t downloaded = existing_size;
  224. size_t progress_step = 0;
  225. ProgressBar bar;
  226. auto res = cli.Get(resolve_path, headers,
  227. [&](const httplib::Response &response) {
  228. if (existing_size > 0 && response.status != 206) {
  229. LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
  230. return false;
  231. }
  232. if (existing_size == 0 && response.status != 200) {
  233. LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
  234. return false;
  235. }
  236. if (total_size == 0 && response.has_header("Content-Length")) {
  237. try {
  238. size_t content_length = std::stoull(response.get_header_value("Content-Length"));
  239. total_size = existing_size + content_length;
  240. } catch (const std::exception &e) {
  241. LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
  242. }
  243. }
  244. return true;
  245. },
  246. [&](const char *data, size_t len) {
  247. ofs.write(data, len);
  248. if (!ofs) {
  249. LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
  250. return false;
  251. }
  252. downloaded += len;
  253. progress_step += len;
  254. if (progress_step >= total_size / 1000 || downloaded == total_size) {
  255. bar.update(downloaded, total_size);
  256. progress_step = 0;
  257. }
  258. return true;
  259. },
  260. nullptr
  261. );
  262. if (!res) {
  263. LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
  264. return false;
  265. }
  266. return true;
  267. }
  268. // download one single file from remote URL to local path
  269. // returns status code or -1 on error
  270. static int common_download_file_single_online(const std::string & url,
  271. const std::string & path,
  272. const std::string & bearer_token,
  273. const common_header_list & custom_headers) {
  274. static const int max_attempts = 3;
  275. static const int retry_delay_seconds = 2;
  276. auto [cli, parts] = common_http_client(url);
  277. httplib::Headers headers;
  278. for (const auto & h : custom_headers) {
  279. headers.emplace(h.first, h.second);
  280. }
  281. if (headers.find("User-Agent") == headers.end()) {
  282. headers.emplace("User-Agent", "llama-cpp/" + build_info);
  283. }
  284. if (!bearer_token.empty()) {
  285. headers.emplace("Authorization", "Bearer " + bearer_token);
  286. }
  287. cli.set_default_headers(headers);
  288. const bool file_exists = std::filesystem::exists(path);
  289. std::string last_etag;
  290. if (file_exists) {
  291. last_etag = read_etag(path);
  292. } else {
  293. LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
  294. }
  295. for (int i = 0; i < max_attempts; ++i) {
  296. auto head = cli.Head(parts.path);
  297. bool head_ok = head && head->status >= 200 && head->status < 300;
  298. if (!head_ok) {
  299. LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
  300. if (file_exists) {
  301. LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
  302. return 304; // 304 Not Modified - fake cached response
  303. }
  304. return head->status; // cannot use cached file, return raw status code
  305. // TODO: maybe retry only on certain codes
  306. }
  307. std::string etag;
  308. if (head_ok && head->has_header("ETag")) {
  309. etag = head->get_header_value("ETag");
  310. }
  311. size_t total_size = 0;
  312. if (head_ok && head->has_header("Content-Length")) {
  313. try {
  314. total_size = std::stoull(head->get_header_value("Content-Length"));
  315. } catch (const std::exception& e) {
  316. LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
  317. }
  318. }
  319. bool supports_ranges = false;
  320. if (head_ok && head->has_header("Accept-Ranges")) {
  321. supports_ranges = head->get_header_value("Accept-Ranges") != "none";
  322. }
  323. bool should_download_from_scratch = false;
  324. if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
  325. LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
  326. last_etag.c_str(), etag.c_str());
  327. should_download_from_scratch = true;
  328. }
  329. if (file_exists) {
  330. if (!should_download_from_scratch) {
  331. LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
  332. return 304; // 304 Not Modified - fake cached response
  333. }
  334. LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
  335. if (remove(path.c_str()) != 0) {
  336. LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
  337. return -1;
  338. }
  339. }
  340. const std::string path_temporary = path + ".downloadInProgress";
  341. size_t existing_size = 0;
  342. if (std::filesystem::exists(path_temporary)) {
  343. if (supports_ranges && !should_download_from_scratch) {
  344. existing_size = std::filesystem::file_size(path_temporary);
  345. } else if (remove(path_temporary.c_str()) != 0) {
  346. LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
  347. return -1;
  348. }
  349. }
  350. // start the download
  351. LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
  352. __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
  353. const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
  354. if (!was_pull_successful) {
  355. if (i + 1 < max_attempts) {
  356. const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
  357. LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
  358. std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
  359. } else {
  360. LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
  361. }
  362. continue;
  363. }
  364. if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
  365. LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
  366. return -1;
  367. }
  368. if (!etag.empty()) {
  369. write_etag(path, etag);
  370. }
  371. return head->status; // TODO: use actual GET status?
  372. }
  373. return -1; // max attempts reached
  374. }
  375. std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
  376. const common_remote_params & params) {
  377. auto [cli, parts] = common_http_client(url);
  378. httplib::Headers headers;
  379. for (const auto & h : params.headers) {
  380. headers.emplace(h.first, h.second);
  381. }
  382. if (headers.find("User-Agent") == headers.end()) {
  383. headers.emplace("User-Agent", "llama-cpp/" + build_info);
  384. }
  385. if (params.timeout > 0) {
  386. cli.set_read_timeout(params.timeout, 0);
  387. cli.set_write_timeout(params.timeout, 0);
  388. }
  389. std::vector<char> buf;
  390. auto res = cli.Get(parts.path, headers,
  391. [&](const char *data, size_t len) {
  392. buf.insert(buf.end(), data, data + len);
  393. return params.max_size == 0 ||
  394. buf.size() <= static_cast<size_t>(params.max_size);
  395. },
  396. nullptr
  397. );
  398. if (!res) {
  399. throw std::runtime_error("error: cannot make GET request");
  400. }
  401. return { res->status, std::move(buf) };
  402. }
  403. int common_download_file_single(const std::string & url,
  404. const std::string & path,
  405. const std::string & bearer_token,
  406. bool offline,
  407. const common_header_list & headers) {
  408. if (!offline) {
  409. return common_download_file_single_online(url, path, bearer_token, headers);
  410. }
  411. if (!std::filesystem::exists(path)) {
  412. LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
  413. return -1;
  414. }
  415. LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
  416. return 304; // Not Modified - fake cached response
  417. }
  418. // download multiple files from remote URLs to local paths
  419. // the input is a vector of pairs <url, path>
  420. static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
  421. const std::string & bearer_token,
  422. bool offline,
  423. const common_header_list & headers) {
  424. // Prepare download in parallel
  425. std::vector<std::future<bool>> futures_download;
  426. futures_download.reserve(urls.size());
  427. for (auto const & item : urls) {
  428. futures_download.push_back(
  429. std::async(
  430. std::launch::async,
  431. [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
  432. const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
  433. return is_http_status_ok(http_status);
  434. },
  435. item
  436. )
  437. );
  438. }
  439. // Wait for all downloads to complete
  440. for (auto & f : futures_download) {
  441. if (!f.get()) {
  442. return false;
  443. }
  444. }
  445. return true;
  446. }
  447. bool common_download_model(const common_params_model & model,
  448. const std::string & bearer_token,
  449. bool offline,
  450. const common_header_list & headers) {
  451. // Basic validation of the model.url
  452. if (model.url.empty()) {
  453. LOG_ERR("%s: invalid model url\n", __func__);
  454. return false;
  455. }
  456. const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
  457. if (!is_http_status_ok(http_status)) {
  458. return false;
  459. }
  460. // check for additional GGUFs split to download
  461. int n_split = 0;
  462. {
  463. struct gguf_init_params gguf_params = {
  464. /*.no_alloc = */ true,
  465. /*.ctx = */ NULL,
  466. };
  467. auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
  468. if (!ctx_gguf) {
  469. LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
  470. return false;
  471. }
  472. auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
  473. if (key_n_split >= 0) {
  474. n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
  475. }
  476. gguf_free(ctx_gguf);
  477. }
  478. if (n_split > 1) {
  479. char split_prefix[PATH_MAX] = {0};
  480. char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
  481. // Verify the first split file format
  482. // and extract split URL and PATH prefixes
  483. {
  484. if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
  485. LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
  486. return false;
  487. }
  488. if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
  489. LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
  490. return false;
  491. }
  492. }
  493. std::vector<std::pair<std::string, std::string>> urls;
  494. for (int idx = 1; idx < n_split; idx++) {
  495. char split_path[PATH_MAX] = {0};
  496. llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
  497. char split_url[LLAMA_MAX_URL_LENGTH] = {0};
  498. llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
  499. if (std::string(split_path) == model.path) {
  500. continue; // skip the already downloaded file
  501. }
  502. urls.push_back({split_url, split_path});
  503. }
  504. // Download in parallel
  505. common_download_file_multiple(urls, bearer_token, offline, headers);
  506. }
  507. return true;
  508. }
  509. common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
  510. const std::string & bearer_token,
  511. bool offline,
  512. const common_header_list & custom_headers) {
  513. // the returned hf_repo is without tag
  514. auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag);
  515. std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
  516. // headers
  517. common_header_list headers = custom_headers;
  518. headers.push_back({"Accept", "application/json"});
  519. if (!bearer_token.empty()) {
  520. headers.push_back({"Authorization", "Bearer " + bearer_token});
  521. }
  522. // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
  523. // User-Agent header is already set in common_remote_get_content, no need to set it here
  524. // make the request
  525. common_remote_params params;
  526. params.headers = headers;
  527. long res_code = 0;
  528. std::string res_str;
  529. bool use_cache = false;
  530. std::string cached_response_path = get_manifest_path(hf_repo, tag);
  531. if (!offline) {
  532. try {
  533. auto res = common_remote_get_content(url, params);
  534. res_code = res.first;
  535. res_str = std::string(res.second.data(), res.second.size());
  536. } catch (const std::exception & e) {
  537. LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
  538. }
  539. }
  540. if (res_code == 0) {
  541. if (std::filesystem::exists(cached_response_path)) {
  542. LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
  543. res_str = read_file(cached_response_path);
  544. res_code = 200;
  545. use_cache = true;
  546. } else {
  547. throw std::runtime_error(
  548. offline ? "error: failed to get manifest (offline mode)"
  549. : "error: failed to get manifest (check your internet connection)");
  550. }
  551. }
  552. std::string ggufFile;
  553. std::string mmprojFile;
  554. if (res_code == 200 || res_code == 304) {
  555. try {
  556. auto j = json::parse(res_str);
  557. if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) {
  558. ggufFile = j["ggufFile"]["rfilename"].get<std::string>();
  559. }
  560. if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
  561. mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
  562. }
  563. } catch (const std::exception & e) {
  564. throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
  565. }
  566. if (!use_cache) {
  567. // if not using cached response, update the cache file
  568. write_file(cached_response_path, res_str);
  569. }
  570. } else if (res_code == 401) {
  571. throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
  572. } else {
  573. throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
  574. }
  575. // check response
  576. if (ggufFile.empty()) {
  577. throw std::runtime_error("error: model does not have ggufFile");
  578. }
  579. return { hf_repo, ggufFile, mmprojFile };
  580. }
  581. //
  582. // Docker registry functions
  583. //
  584. static std::string common_docker_get_token(const std::string & repo) {
  585. std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
  586. common_remote_params params;
  587. auto res = common_remote_get_content(url, params);
  588. if (res.first != 200) {
  589. throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
  590. }
  591. std::string response_str(res.second.begin(), res.second.end());
  592. nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
  593. if (!response.contains("token")) {
  594. throw std::runtime_error("Docker registry token response missing 'token' field");
  595. }
  596. return response["token"].get<std::string>();
  597. }
  598. std::string common_docker_resolve_model(const std::string & docker) {
  599. // Parse ai/smollm2:135M-Q4_0
  600. size_t colon_pos = docker.find(':');
  601. std::string repo, tag;
  602. if (colon_pos != std::string::npos) {
  603. repo = docker.substr(0, colon_pos);
  604. tag = docker.substr(colon_pos + 1);
  605. } else {
  606. repo = docker;
  607. tag = "latest";
  608. }
  609. // ai/ is the default
  610. size_t slash_pos = docker.find('/');
  611. if (slash_pos == std::string::npos) {
  612. repo.insert(0, "ai/");
  613. }
  614. LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
  615. try {
  616. // --- helper: digest validation ---
  617. auto validate_oci_digest = [](const std::string & digest) -> std::string {
  618. // Expected: algo:hex ; start with sha256 (64 hex chars)
  619. // You can extend this map if supporting other algorithms in future.
  620. static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
  621. std::smatch m;
  622. if (!std::regex_match(digest, m, re)) {
  623. throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
  624. }
  625. // normalize hex to lowercase
  626. std::string normalized = digest;
  627. std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
  628. return std::tolower(c);
  629. });
  630. return normalized;
  631. };
  632. std::string token = common_docker_get_token(repo); // Get authentication token
  633. // Get manifest
  634. // TODO: cache the manifest response so that it appears in the model list
  635. const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
  636. std::string manifest_url = url_prefix + "/manifests/" + tag;
  637. common_remote_params manifest_params;
  638. manifest_params.headers.push_back({"Authorization", "Bearer " + token});
  639. manifest_params.headers.push_back({"Accept",
  640. "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
  641. });
  642. auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
  643. if (manifest_res.first != 200) {
  644. throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
  645. }
  646. std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
  647. nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
  648. std::string gguf_digest; // Find the GGUF layer
  649. if (manifest.contains("layers")) {
  650. for (const auto & layer : manifest["layers"]) {
  651. if (layer.contains("mediaType")) {
  652. std::string media_type = layer["mediaType"].get<std::string>();
  653. if (media_type == "application/vnd.docker.ai.gguf.v3" ||
  654. media_type.find("gguf") != std::string::npos) {
  655. gguf_digest = layer["digest"].get<std::string>();
  656. break;
  657. }
  658. }
  659. }
  660. }
  661. if (gguf_digest.empty()) {
  662. throw std::runtime_error("No GGUF layer found in Docker manifest");
  663. }
  664. // Validate & normalize digest
  665. gguf_digest = validate_oci_digest(gguf_digest);
  666. LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
  667. // Prepare local filename
  668. std::string model_filename = repo;
  669. std::replace(model_filename.begin(), model_filename.end(), '/', '_');
  670. model_filename += "_" + tag + ".gguf";
  671. std::string local_path = fs_get_cache_file(model_filename);
  672. const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
  673. const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
  674. if (!is_http_status_ok(http_status)) {
  675. throw std::runtime_error("Failed to download Docker Model");
  676. }
  677. LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
  678. return local_path;
  679. } catch (const std::exception & e) {
  680. LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
  681. throw;
  682. }
  683. }
  684. #else
  685. common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
  686. throw std::runtime_error("download functionality is not enabled in this build");
  687. }
  688. bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
  689. throw std::runtime_error("download functionality is not enabled in this build");
  690. }
  691. std::string common_docker_resolve_model(const std::string &) {
  692. throw std::runtime_error("download functionality is not enabled in this build");
  693. }
  694. int common_download_file_single(const std::string &,
  695. const std::string &,
  696. const std::string &,
  697. bool,
  698. const common_header_list &) {
  699. throw std::runtime_error("download functionality is not enabled in this build");
  700. }
  701. #endif // defined(LLAMA_USE_HTTPLIB)
  702. std::vector<common_cached_model_info> common_list_cached_models() {
  703. std::vector<common_cached_model_info> models;
  704. const std::string cache_dir = fs_get_cache_directory();
  705. const std::vector<common_file_info> files = fs_list(cache_dir, false);
  706. for (const auto & file : files) {
  707. if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
  708. common_cached_model_info model_info;
  709. model_info.manifest_path = file.path;
  710. std::string fname = file.name;
  711. string_replace_all(fname, ".json", ""); // remove extension
  712. auto parts = string_split<std::string>(fname, '=');
  713. if (parts.size() == 4) {
  714. // expect format: manifest=<user>=<model>=<tag>=<other>
  715. model_info.user = parts[1];
  716. model_info.model = parts[2];
  717. model_info.tag = parts[3];
  718. } else {
  719. // invalid format
  720. continue;
  721. }
  722. model_info.size = 0; // TODO: get GGUF size, not manifest size
  723. models.push_back(model_info);
  724. }
  725. }
  726. return models;
  727. }