|
|
@@ -986,7 +986,12 @@ int main(int argc, char ** argv) {
|
|
|
test t(inst, lmodel, ctx);
|
|
|
|
|
|
// warmup run
|
|
|
- test_gen(ctx, 1, 0, t.n_threads);
|
|
|
+ if (t.n_prompt > 0) {
|
|
|
+ test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
|
|
|
+ }
|
|
|
+ if (t.n_gen > 0) {
|
|
|
+ test_gen(ctx, 1, 0, t.n_threads);
|
|
|
+ }
|
|
|
|
|
|
for (int i = 0; i < params.reps; i++) {
|
|
|
uint64_t t_start = get_time_ns();
|