@@ -70,6 +70,7 @@ struct mtmd_cli_context {
llama_model * model;
llama_context * lctx;
const llama_vocab * vocab;
+ common_sampler * smpl;
llama_batch batch;
int n_batch;
@@ -89,8 +90,9 @@ struct mtmd_cli_context {
model = llama_init.model.get();
lctx = llama_init.context.get();
vocab = llama_model_get_vocab(model);
+ smpl = common_sampler_init(model, params.sampling);
n_threads = params.cpuparams.n_threads;
- batch = llama_batch_init(params.n_batch, 0, 1);
+ batch = llama_batch_init(1, 0, 1); // batch for next token generation
n_batch = params.n_batch;
if (!model || !lctx) {
@@ -118,6 +120,11 @@ struct mtmd_cli_context {
}
}
+ ~mtmd_cli_context() {
+ llama_batch_free(batch);
+ common_sampler_free(smpl);
+ }
+
void init_vision_context(common_params & params) {
const char * clip_path = params.mmproj.path.c_str();
mtmd_context_params mparams = mtmd_context_params_default();
@@ -153,7 +160,7 @@ struct mtmd_cli_context {
}
};
-static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
+static int generate_response(mtmd_cli_context & ctx, int n_predict) {
llama_tokens generated_tokens;
for (int i = 0; i < n_predict; i++) {
if (i > n_predict || !g_is_generating || g_is_interrupted) {
@@ -161,9 +168,9 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
break;
}
- llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
+ llama_token token_id = common_sampler_sample(ctx.smpl, ctx.lctx, -1);
generated_tokens.push_back(token_id);
- common_sampler_accept(smpl, token_id, true);
+ common_sampler_accept(ctx.smpl, token_id, true);
if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
LOG("\n");
@@ -261,7 +268,6 @@ int main(int argc, char ** argv) {
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
- struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
// Ctrl+C handling
@@ -300,7 +306,7 @@ int main(int argc, char ** argv) {
if (eval_message(ctx, msg, true)) {
return 1;
}
- if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
+ if (!g_is_interrupted && generate_response(ctx, n_predict)) {
return 1;
}
@@ -366,7 +372,7 @@ int main(int argc, char ** argv) {
return 1;
}
if (g_is_interrupted) break;
- if (generate_response(ctx, smpl, n_predict)) {
+ if (generate_response(ctx, n_predict)) {
return 1;
}
content.clear();
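
For context, here is a minimal sketch of the ownership pattern this diff introduces: the sampler and the single-token batch are created alongside the context object and released in its destructor, instead of being allocated in `main()` and threaded through `generate_response()`. The struct name `sampler_ctx` and the `next_token()` helper below are illustrative only; the `common_sampler_*` and `llama_batch_*` calls are the ones used in the diff.

```cpp
// Sketch of the resource-ownership pattern from the diff above, assuming the
// llama.cpp common API it uses (common_sampler_* and llama_batch_*). The
// struct and helper names here are hypothetical, not part of the change.
#include "common.h"
#include "sampling.h"
#include "llama.h"

struct sampler_ctx {
    llama_context  * lctx;
    common_sampler * smpl;
    llama_batch      batch;

    sampler_ctx(llama_model * model, llama_context * lctx, const common_params & params)
        : lctx(lctx) {
        smpl  = common_sampler_init(model, params.sampling);
        batch = llama_batch_init(1, 0, 1); // single-token batch, as in the diff
    }

    // Mirrors ~mtmd_cli_context(): resources go away with the object, so the
    // caller no longer frees the sampler by hand.
    ~sampler_ctx() {
        llama_batch_free(batch);
        common_sampler_free(smpl);
    }

    // Sample one token and record it with the sampler, matching the
    // common_sampler_sample / common_sampler_accept pair in generate_response().
    llama_token next_token() {
        llama_token token_id = common_sampler_sample(smpl, lctx, -1);
        common_sampler_accept(smpl, token_id, true);
        return token_id;
    }
};
```

The destructor frees the batch before the sampler, matching the order in the diff; the two resources are independent, so either order would work.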