Explorar o código

arg: fix common_params_parse not accepting negated arg (#17991)

Xuan-Son Nguyen hai 1 mes
pai
achega
4d5ae24c0a
Modificáronse 4 ficheiros con 10 adicións e 3 borrados
  1. 4 1
      common/arg.cpp
  2. 1 1
      common/arg.h
  3. 4 0
      tests/test-arg-parser.cpp
  4. 1 1
      tools/server/server-models.cpp

+ 4 - 1
common/arg.cpp

@@ -724,7 +724,7 @@ static void add_rpc_devices(const std::string & servers) {
     }
 }
 
-bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map) {
+bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map) {
     common_params dummy_params;
     common_params_context ctx_arg = common_params_parser_init(dummy_params, ex, nullptr);
 
@@ -733,6 +733,9 @@ bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<comm
         for (const auto & arg : opt.args) {
             arg_to_options[arg] = &opt;
         }
+        for (const auto & arg : opt.args_neg) {
+            arg_to_options[arg] = &opt;
+        }
     }
 
     // TODO @ngxson : find a way to deduplicate this code

+ 1 - 1
common/arg.h

@@ -115,7 +115,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
 
 // parse input arguments from CLI into a map
 // TODO: support repeated args in the future
-bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
+bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
 
 // initialize argument parser context - used by test-arg-parser and preset
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

+ 4 - 0
tests/test-arg-parser.cpp

@@ -72,6 +72,10 @@ int main(void) {
     argv = {"binary_name", "--draft", "123"};
     assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING));
 
+    // negated arg
+    argv = {"binary_name", "--no-mmap"};
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+
 
     printf("test-arg-parser: test valid usage\n\n");
 

+ 1 - 1
tools/server/server-models.cpp

@@ -171,7 +171,7 @@ server_presets::server_presets(int argc, char ** argv, common_params & base_para
     }
 
     // read base args from router's argv
-    common_params_parse(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
+    common_params_to_map(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
 
     // remove any router-controlled args from base_args
     for (const auto & cargs : control_args) {