|
|
@@ -39,6 +39,11 @@ int main(int argc, char ** argv) {
|
|
|
return 1;
|
|
|
}
|
|
|
|
|
|
+ if (params.n_predict < -1) {
|
|
|
+ LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+
|
|
|
common_init();
|
|
|
|
|
|
if (params.model_draft.empty()) {
|
|
|
@@ -190,8 +195,8 @@ int main(int argc, char ** argv) {
|
|
|
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
|
|
|
}
|
|
|
|
|
|
- llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
|
|
|
- llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
|
|
|
+ llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
|
|
+ llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
|
|
|
|
|
|
const auto t_dec_start = ggml_time_us();
|
|
|
|
|
|
@@ -441,7 +446,7 @@ int main(int argc, char ** argv) {
|
|
|
++n_past_dft;
|
|
|
}
|
|
|
|
|
|
- if (n_predict > params.n_predict || has_eos) {
|
|
|
+ if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
|
|
break;
|
|
|
}
|
|
|
|