Browse Source

cmdline option for custom amount of model parts (--n_parts N) (#348)

* cmdline option for custom amount of model parts (--n_parts N)

* Update main.cpp

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
anzz1 2 years ago
parent
commit
975d2cebf9
3 changed files with 12 additions and 5 deletions
  1. 7 4
      main.cpp
  2. 3 0
      utils.cpp
  3. 2 1
      utils.h

+ 7 - 4
main.cpp

@@ -90,7 +90,8 @@ struct llama_model {
 };
 };
 
 
 // load the model's weights from a file
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
+
+bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, int n_parts, ggml_type memory_type = GGML_TYPE_F32) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
 
     std::vector<char> f_buf(1024*1024);
     std::vector<char> f_buf(1024*1024);
@@ -127,7 +128,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
     }
     }
 
 
     int n_ff = 0;
     int n_ff = 0;
-    int n_parts = 0;
 
 
     // load hparams
     // load hparams
     {
     {
@@ -145,7 +145,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
         hparams.n_ctx = n_ctx;
         hparams.n_ctx = n_ctx;
 
 
         n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
         n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
-        n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
+
+        if (n_parts < 1) {
+            n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
+        }
 
 
         fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
         fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
         fprintf(stderr, "%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
         fprintf(stderr, "%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
@@ -839,7 +842,7 @@ int main(int argc, char ** argv) {
     {
     {
         const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
         const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
         const int64_t t_start_us = ggml_time_us();
         const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx, params.n_parts, memory_type)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
             return 1;
         }
         }

+ 3 - 0
utils.cpp

@@ -74,6 +74,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.antiprompt.push_back(argv[++i]);
             params.antiprompt.push_back(argv[++i]);
         } else if (arg == "--ignore-eos") {
         } else if (arg == "--ignore-eos") {
             params.ignore_eos = true;
             params.ignore_eos = true;
+        } else if (arg == "--n_parts") {
+            params.n_parts = std::stoi(argv[++i]);
         } else if (arg == "-h" || arg == "--help") {
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             gpt_print_usage(argc, argv, params);
             exit(0);
             exit(0);
@@ -116,6 +118,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating\n");
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating\n");
     fprintf(stderr, "  --memory_f16          use f16 instead of f32 for memory key+value\n");
     fprintf(stderr, "  --memory_f16          use f16 instead of f32 for memory key+value\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());

+ 2 - 1
utils.h

@@ -13,10 +13,11 @@
 //
 //
 
 
 struct gpt_params {
 struct gpt_params {
-    int32_t seed          = -1; // RNG seed
+    int32_t seed          = -1;  // RNG seed
     int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict     = 128; // new tokens to predict
     int32_t n_predict     = 128; // new tokens to predict
     int32_t repeat_last_n = 64;  // last n tokens to penalize
     int32_t repeat_last_n = 64;  // last n tokens to penalize
+    int32_t n_parts       = -1;  // amount of model parts (-1 = determine from model dimensions)
     int32_t n_ctx         = 512; //context size
     int32_t n_ctx         = 512; //context size
 
 
     // sampling parameters
     // sampling parameters