Ben Garney 2 лет назад
Родитель
Сommit
f385f8dee8
1 измененных файлов с 14 добавлено и 0 удалено
  1. 14 0
      utils.cpp

+ 14 - 0
utils.cpp

@@ -4,6 +4,10 @@
 #include <cstring>
 #include <cstring>
 #include <fstream>
 #include <fstream>
 #include <regex>
 #include <regex>
+#include <iostream>
+#include <iterator>
+#include <string>
+#include <math.h>
 
 
  #if defined(_MSC_VER) || defined(__MINGW32__)
  #if defined(_MSC_VER) || defined(__MINGW32__)
  #include <malloc.h> // using malloc.h with MSC/MINGW
  #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -21,6 +25,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.n_threads = std::stoi(argv[++i]);
             params.n_threads = std::stoi(argv[++i]);
         } else if (arg == "-p" || arg == "--prompt") {
         } else if (arg == "-p" || arg == "--prompt") {
             params.prompt = argv[++i];
             params.prompt = argv[++i];
+        } else if (arg == "-f" || arg == "--file") {
+
+            std::ifstream file(argv[++i]);
+
+            std::copy(std::istreambuf_iterator<char>(file),
+                    std::istreambuf_iterator<char>(),
+                    back_inserter(params.prompt));
+                
         } else if (arg == "-n" || arg == "--n_predict") {
         } else if (arg == "-n" || arg == "--n_predict") {
             params.n_predict = std::stoi(argv[++i]);
             params.n_predict = std::stoi(argv[++i]);
         } else if (arg == "--top_k") {
         } else if (arg == "--top_k") {
@@ -59,6 +71,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, "                        prompt to start generation with (default: random)\n");
     fprintf(stderr, "                        prompt to start generation with (default: random)\n");
+    fprintf(stderr, "  -f FNAME, --file FNAME\n");
+    fprintf(stderr, "                        prompt file to start generation.\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);