preset.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. #include "arg.h"
  2. #include "preset.h"
  3. #include "peg-parser.h"
  4. #include "log.h"
  5. #include "download.h"
  6. #include <fstream>
  7. #include <sstream>
  8. #include <filesystem>
  // Strip every '-' at the front of an option spelling ("--ctx-size" -> "ctx-size").
  // Dashes inside the name are kept; a string made only of dashes becomes empty.
  static std::string rm_leading_dashes(const std::string & str) {
      const size_t first_kept = str.find_first_not_of('-');
      if (first_kept == std::string::npos) {
          return std::string();
      }
      return str.substr(first_kept);
  }
  16. // only allow a subset of args for remote presets for security reasons
  17. // do not add more args unless absolutely necessary
  18. // args that output to files are strictly prohibited
  19. static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
  20. static const std::set<std::string> allowed_options = {
  21. "model-url",
  22. "hf-repo",
  23. "hf-repo-draft",
  24. "hf-repo-v", // vocoder
  25. "hf-file-v", // vocoder
  26. "mmproj-url",
  27. "pooling",
  28. "jinja",
  29. "batch-size",
  30. "ubatch-size",
  31. "cache-reuse",
  32. "chat-template-kwargs",
  33. "mmap",
  34. // note: sampling params are automatically allowed by default
  35. // negated args will be added automatically if the positive arg is specified above
  36. };
  37. std::set<std::string> allowed_keys;
  38. for (const auto & it : key_to_opt) {
  39. const std::string & key = it.first;
  40. const common_arg & opt = it.second;
  41. if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
  42. allowed_keys.insert(key);
  43. // also add variant keys (args without leading dashes and env vars)
  44. for (const auto & arg : opt.get_args()) {
  45. allowed_keys.insert(rm_leading_dashes(arg));
  46. }
  47. for (const auto & env : opt.get_env()) {
  48. allowed_keys.insert(env);
  49. }
  50. }
  51. }
  52. return allowed_keys;
  53. }
  54. std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
  55. std::vector<std::string> args;
  56. if (!bin_path.empty()) {
  57. args.push_back(bin_path);
  58. }
  59. for (const auto & [opt, value] : options) {
  60. if (opt.is_preset_only) {
  61. continue; // skip preset-only options (they are not CLI args)
  62. }
  63. // use the last arg as the main arg (i.e. --long-form)
  64. args.push_back(opt.args.back());
  65. // handle value(s)
  66. if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
  67. // flag option, no value
  68. if (common_arg_utils::is_falsey(value)) {
  69. // use negative arg if available
  70. if (!opt.args_neg.empty()) {
  71. args.back() = opt.args_neg.back();
  72. } else {
  73. // otherwise, skip the flag
  74. // TODO: maybe throw an error instead?
  75. args.pop_back();
  76. }
  77. }
  78. }
  79. if (opt.value_hint != nullptr) {
  80. // single value
  81. args.push_back(value);
  82. }
  83. if (opt.value_hint != nullptr && opt.value_hint_2 != nullptr) {
  84. throw std::runtime_error(string_format(
  85. "common_preset::to_args(): option '%s' has two values, which is not supported yet",
  86. opt.args.back()
  87. ));
  88. }
  89. }
  90. return args;
  91. }
  92. std::string common_preset::to_ini() const {
  93. std::ostringstream ss;
  94. ss << "[" << name << "]\n";
  95. for (const auto & [opt, value] : options) {
  96. auto espaced_value = value;
  97. string_replace_all(espaced_value, "\n", "\\\n");
  98. ss << rm_leading_dashes(opt.args.back()) << " = ";
  99. ss << espaced_value << "\n";
  100. }
  101. ss << "\n";
  102. return ss.str();
  103. }
  104. void common_preset::set_option(const common_preset_context & ctx, const std::string & env, const std::string & value) {
  105. // try if option exists, update it
  106. for (auto & [opt, val] : options) {
  107. if (opt.env && env == opt.env) {
  108. val = value;
  109. return;
  110. }
  111. }
  112. // if option does not exist, we need to add it
  113. if (ctx.key_to_opt.find(env) == ctx.key_to_opt.end()) {
  114. throw std::runtime_error(string_format(
  115. "%s: option with env '%s' not found in ctx_params",
  116. __func__, env.c_str()
  117. ));
  118. }
  119. options[ctx.key_to_opt.at(env)] = value;
  120. }
  121. void common_preset::unset_option(const std::string & env) {
  122. for (auto it = options.begin(); it != options.end(); ) {
  123. const common_arg & opt = it->first;
  124. if (opt.env && env == opt.env) {
  125. it = options.erase(it);
  126. return;
  127. } else {
  128. ++it;
  129. }
  130. }
  131. }
  132. bool common_preset::get_option(const std::string & env, std::string & value) const {
  133. for (const auto & [opt, val] : options) {
  134. if (opt.env && env == opt.env) {
  135. value = val;
  136. return true;
  137. }
  138. }
  139. return false;
  140. }
  141. void common_preset::merge(const common_preset & other) {
  142. for (const auto & [opt, val] : other.options) {
  143. options[opt] = val; // overwrite existing options
  144. }
  145. }
  146. void common_preset::apply_to_params(common_params & params) const {
  147. for (const auto & [opt, val] : options) {
  148. // apply each option to params
  149. if (opt.handler_string) {
  150. opt.handler_string(params, val);
  151. } else if (opt.handler_int) {
  152. opt.handler_int(params, std::stoi(val));
  153. } else if (opt.handler_bool) {
  154. opt.handler_bool(params, common_arg_utils::is_truthy(val));
  155. } else if (opt.handler_str_str) {
  156. // not supported yet
  157. throw std::runtime_error(string_format(
  158. "%s: option with two values is not supported yet",
  159. __func__
  160. ));
  161. } else if (opt.handler_void) {
  162. opt.handler_void(params);
  163. } else {
  164. GGML_ABORT("unknown handler type");
  165. }
  166. }
  167. }
  168. static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
  169. std::map<std::string, std::map<std::string, std::string>> parsed;
  170. if (!std::filesystem::exists(path)) {
  171. throw std::runtime_error("preset file does not exist: " + path);
  172. }
  173. std::ifstream file(path);
  174. if (!file.good()) {
  175. throw std::runtime_error("failed to open server preset file: " + path);
  176. }
  177. std::string contents((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
  178. static const auto parser = build_peg_parser([](auto & p) {
  179. // newline ::= "\r\n" / "\n" / "\r"
  180. auto newline = p.rule("newline", p.literal("\r\n") | p.literal("\n") | p.literal("\r"));
  181. // ws ::= [ \t]*
  182. auto ws = p.rule("ws", p.chars("[ \t]", 0, -1));
  183. // comment ::= [;#] (!newline .)*
  184. auto comment = p.rule("comment", p.chars("[;#]", 1, 1) + p.zero_or_more(p.negate(newline) + p.any()));
  185. // eol ::= ws comment? (newline / EOF)
  186. auto eol = p.rule("eol", ws + p.optional(comment) + (newline | p.end()));
  187. // ident ::= [a-zA-Z_] [a-zA-Z0-9_.-]*
  188. auto ident = p.rule("ident", p.chars("[a-zA-Z_]", 1, 1) + p.chars("[a-zA-Z0-9_.-]", 0, -1));
  189. // value ::= (!eol-start .)*
  190. auto eol_start = p.rule("eol-start", ws + (p.chars("[;#]", 1, 1) | newline | p.end()));
  191. auto value = p.rule("value", p.zero_or_more(p.negate(eol_start) + p.any()));
  192. // header-line ::= "[" ws ident ws "]" eol
  193. auto header_line = p.rule("header-line", "[" + ws + p.tag("section-name", p.chars("[^]]")) + ws + "]" + eol);
  194. // kv-line ::= ident ws "=" ws value eol
  195. auto kv_line = p.rule("kv-line", p.tag("key", ident) + ws + "=" + ws + p.tag("value", value) + eol);
  196. // comment-line ::= ws comment (newline / EOF)
  197. auto comment_line = p.rule("comment-line", ws + comment + (newline | p.end()));
  198. // blank-line ::= ws (newline / EOF)
  199. auto blank_line = p.rule("blank-line", ws + (newline | p.end()));
  200. // line ::= header-line / kv-line / comment-line / blank-line
  201. auto line = p.rule("line", header_line | kv_line | comment_line | blank_line);
  202. // ini ::= line* EOF
  203. auto ini = p.rule("ini", p.zero_or_more(line) + p.end());
  204. return ini;
  205. });
  206. common_peg_parse_context ctx(contents);
  207. const auto result = parser.parse(ctx);
  208. if (!result.success()) {
  209. throw std::runtime_error("failed to parse server config file: " + path);
  210. }
  211. std::string current_section = COMMON_PRESET_DEFAULT_NAME;
  212. std::string current_key;
  213. ctx.ast.visit(result, [&](const auto & node) {
  214. if (node.tag == "section-name") {
  215. const std::string section = std::string(node.text);
  216. current_section = section;
  217. parsed[current_section] = {};
  218. } else if (node.tag == "key") {
  219. const std::string key = std::string(node.text);
  220. current_key = key;
  221. } else if (node.tag == "value" && !current_key.empty() && !current_section.empty()) {
  222. parsed[current_section][current_key] = std::string(node.text);
  223. current_key.clear();
  224. }
  225. });
  226. return parsed;
  227. }
  228. static std::map<std::string, common_arg> get_map_key_opt(common_params_context & ctx_params) {
  229. std::map<std::string, common_arg> mapping;
  230. for (const auto & opt : ctx_params.options) {
  231. for (const auto & env : opt.get_env()) {
  232. mapping[env] = opt;
  233. }
  234. for (const auto & arg : opt.get_args()) {
  235. mapping[rm_leading_dashes(arg)] = opt;
  236. }
  237. }
  238. return mapping;
  239. }
  240. static bool is_bool_arg(const common_arg & arg) {
  241. return !arg.args_neg.empty();
  242. }
  243. static std::string parse_bool_arg(const common_arg & arg, const std::string & key, const std::string & value) {
  244. // if this is a negated arg, we need to reverse the value
  245. for (const auto & neg_arg : arg.args_neg) {
  246. if (rm_leading_dashes(neg_arg) == key) {
  247. return common_arg_utils::is_truthy(value) ? "false" : "true";
  248. }
  249. }
  250. // otherwise, not negated
  251. return value;
  252. }
  253. common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
  254. : ctx_params(common_params_parser_init(default_params, ex)) {
  255. common_params_add_preset_options(ctx_params.options);
  256. key_to_opt = get_map_key_opt(ctx_params);
  257. // setup allowed keys if only_remote_allowed is true
  258. if (only_remote_allowed) {
  259. filter_allowed_keys = true;
  260. allowed_keys = get_remote_preset_whitelist(key_to_opt);
  261. }
  262. }
  263. common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
  264. common_presets out;
  265. auto ini_data = parse_ini_from_file(path);
  266. for (auto section : ini_data) {
  267. common_preset preset;
  268. if (section.first.empty()) {
  269. preset.name = COMMON_PRESET_DEFAULT_NAME;
  270. } else {
  271. preset.name = section.first;
  272. }
  273. LOG_DBG("loading preset: %s\n", preset.name.c_str());
  274. for (const auto & [key, value] : section.second) {
  275. if (key == "version") {
  276. // skip version key (reserved for future use)
  277. continue;
  278. }
  279. LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
  280. if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
  281. throw std::runtime_error(string_format(
  282. "option '%s' is not allowed in remote presets",
  283. key.c_str()
  284. ));
  285. }
  286. if (key_to_opt.find(key) != key_to_opt.end()) {
  287. const auto & opt = key_to_opt.at(key);
  288. if (is_bool_arg(opt)) {
  289. preset.options[opt] = parse_bool_arg(opt, key, value);
  290. } else {
  291. preset.options[opt] = value;
  292. }
  293. LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str());
  294. } else {
  295. throw std::runtime_error(string_format(
  296. "option '%s' not recognized in preset '%s'",
  297. key.c_str(), preset.name.c_str()
  298. ));
  299. }
  300. }
  301. if (preset.name == "*") {
  302. // handle global preset
  303. global = preset;
  304. } else {
  305. out[preset.name] = preset;
  306. }
  307. }
  308. return out;
  309. }
  310. common_presets common_preset_context::load_from_cache() const {
  311. common_presets out;
  312. auto cached_models = common_list_cached_models();
  313. for (const auto & model : cached_models) {
  314. common_preset preset;
  315. preset.name = model.to_string();
  316. preset.set_option(*this, "LLAMA_ARG_HF_REPO", model.to_string());
  317. out[preset.name] = preset;
  318. }
  319. return out;
  320. }
  // A model discovered on disk by load_from_models_dir.
  struct local_model {
      std::string name{};        // becomes the preset name
      std::string path{};        // main GGUF file (first shard when the model is split)
      std::string path_mmproj{}; // multimodal projector GGUF; may be empty
  };
  326. common_presets common_preset_context::load_from_models_dir(const std::string & models_dir) const {
  327. if (!std::filesystem::exists(models_dir) || !std::filesystem::is_directory(models_dir)) {
  328. throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", models_dir.c_str()));
  329. }
  330. std::vector<local_model> models;
  331. auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
  332. auto files = fs_list(subdir_path, false);
  333. common_file_info model_file;
  334. common_file_info first_shard_file;
  335. common_file_info mmproj_file;
  336. for (const auto & file : files) {
  337. if (string_ends_with(file.name, ".gguf")) {
  338. if (file.name.find("mmproj") != std::string::npos) {
  339. mmproj_file = file;
  340. } else if (file.name.find("-00001-of-") != std::string::npos) {
  341. first_shard_file = file;
  342. } else {
  343. model_file = file;
  344. }
  345. }
  346. }
  347. // single file model
  348. local_model model{
  349. /* name */ name,
  350. /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
  351. /* path_mmproj */ mmproj_file.path // can be empty
  352. };
  353. if (!model.path.empty()) {
  354. models.push_back(model);
  355. }
  356. };
  357. auto files = fs_list(models_dir, true);
  358. for (const auto & file : files) {
  359. if (file.is_dir) {
  360. scan_subdir(file.path, file.name);
  361. } else if (string_ends_with(file.name, ".gguf")) {
  362. // single file model
  363. std::string name = file.name;
  364. string_replace_all(name, ".gguf", "");
  365. local_model model{
  366. /* name */ name,
  367. /* path */ file.path,
  368. /* path_mmproj */ ""
  369. };
  370. models.push_back(model);
  371. }
  372. }
  373. // convert local models to presets
  374. common_presets out;
  375. for (const auto & model : models) {
  376. common_preset preset;
  377. preset.name = model.name;
  378. preset.set_option(*this, "LLAMA_ARG_MODEL", model.path);
  379. if (!model.path_mmproj.empty()) {
  380. preset.set_option(*this, "LLAMA_ARG_MMPROJ", model.path_mmproj);
  381. }
  382. out[preset.name] = preset;
  383. }
  384. return out;
  385. }
  386. common_preset common_preset_context::load_from_args(int argc, char ** argv) const {
  387. common_preset preset;
  388. preset.name = COMMON_PRESET_DEFAULT_NAME;
  389. bool ok = common_params_to_map(argc, argv, ctx_params.ex, preset.options);
  390. if (!ok) {
  391. throw std::runtime_error("failed to parse CLI arguments into preset");
  392. }
  393. return preset;
  394. }
  395. common_presets common_preset_context::cascade(const common_presets & base, const common_presets & added) const {
  396. common_presets out = base; // copy
  397. for (const auto & [name, preset_added] : added) {
  398. if (out.find(name) != out.end()) {
  399. // if exists, merge
  400. common_preset & target = out[name];
  401. target.merge(preset_added);
  402. } else {
  403. // otherwise, add directly
  404. out[name] = preset_added;
  405. }
  406. }
  407. return out;
  408. }
  409. common_presets common_preset_context::cascade(const common_preset & base, const common_presets & presets) const {
  410. common_presets out;
  411. for (const auto & [name, preset] : presets) {
  412. common_preset tmp = base; // copy
  413. tmp.name = name;
  414. tmp.merge(preset);
  415. out[name] = std::move(tmp);
  416. }
  417. return out;
  418. }