@@ -11,12 +11,16 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]);
+        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN]\n" , argv[0]);
         return 1 ;
     }
 
+    // number of parallel batches
     int n_parallel = 1;
 
+    // total length of the sequences including the prompt
+    int n_len = 32;
+
     if (argc >= 2) {
         params.model = argv[1];
     }
 
@@ -29,13 +33,14 @@ int main(int argc, char ** argv) {
         n_parallel = std::atoi(argv[3]);
     }
 
+    if (argc >= 5) {
+        n_len = std::atoi(argv[4]);
+    }
+
     if (params.prompt.empty()) {
         params.prompt = "Hello my name is";
     }
 
-    // total length of the sequences including the prompt
-    const int n_len = 32;
-
     // init LLM
 
     llama_backend_init(params.numa);
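
For reference, a rough sketch of how the updated example could be invoked after this change (the binary name and model path below are assumptions, not taken from the diff):

    ./batched ./models/7B/ggml-model-f16.gguf "Hello my name is" 4 64

Here the fourth argument (4) is the number of parallel batches and the fifth (64) is the total sequence length including the prompt; when omitted they fall back to the defaults set above (1 and 32).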