Răsfoiți Sursa

server: add rms_norm_eps parameter (#2380)

slaren 2 ani în urmă
părinte
comite
d5512b782b
1 a modificat fișierele cu 9 adăugiri și 0 ștergeri
  1. 9 0
      examples/server/server.cpp

+ 9 - 0
examples/server/server.cpp

@@ -609,6 +609,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stdout, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, "  -gqa N, --gqa N       grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+    fprintf(stdout, "  -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
     fprintf(stdout, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stdout, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
@@ -734,6 +735,14 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_gqa = std::stoi(argv[i]);
         }
+        else if (arg == "-eps" || arg == "--rms-norm-eps") {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.rms_norm_eps = std::stof(argv[i]);
+        }
         else if (arg == "--rope-freq-base")
         {
             if (++i >= argc)