Browse Source

Add --cfg-negative-prompt-file option for examples (#2591)

Add --cfg-negative-prompt-file option for examples
Kerfuffle 2 years ago
parent
commit
8dae7ce684
1 changed files with 18 additions and 1 deletions
  1. 18 1
      examples/common.cpp

+ 18 - 1
examples/common.cpp

@@ -274,6 +274,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
                 break;
             }
             }
             params.cfg_negative_prompt = argv[i];
             params.cfg_negative_prompt = argv[i];
+        } else if (arg == "--cfg-negative-prompt-file") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::ifstream file(argv[i]);
+            if (!file) {
+                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.cfg_negative_prompt));
+            if (params.cfg_negative_prompt.back() == '\n') {
+                params.cfg_negative_prompt.pop_back();
+            }
         } else if (arg == "--cfg-scale") {
         } else if (arg == "--cfg-scale") {
             if (++i >= argc) {
             if (++i >= argc) {
                 invalid_param = true;
                 invalid_param = true;
@@ -567,8 +582,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     fprintf(stdout, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     fprintf(stdout, "  --grammar GRAMMAR     BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
     fprintf(stdout, "  --grammar GRAMMAR     BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
     fprintf(stdout, "  --grammar-file FNAME  file to read grammar from\n");
     fprintf(stdout, "  --grammar-file FNAME  file to read grammar from\n");
-    fprintf(stdout, "  --cfg-negative-prompt PROMPT \n");
+    fprintf(stdout, "  --cfg-negative-prompt PROMPT\n");
     fprintf(stdout, "                        negative prompt to use for guidance. (default: empty)\n");
     fprintf(stdout, "                        negative prompt to use for guidance. (default: empty)\n");
+    fprintf(stdout, "  --cfg-negative-prompt-file FNAME\n");
+    fprintf(stdout, "                        negative prompt file to use for guidance. (default: empty)\n");
     fprintf(stdout, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
     fprintf(stdout, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
     fprintf(stdout, "  --rope-scale N        RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
     fprintf(stdout, "  --rope-scale N        RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
     fprintf(stdout, "  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, "  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);