|
|
@@ -145,6 +145,20 @@ int main(int argc, char ** argv) {
|
|
|
|
|
|
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
|
|
|
|
|
+ if (llama_model_has_encoder(model)) {
|
|
|
+ if (llama_encode(ctx, batch)) {
|
|
|
+ fprintf(stderr, "%s : failed to eval\n", __func__);
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+
|
|
|
+ llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
|
|
+ if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
|
|
|
+ decoder_start_token_id = llama_vocab_bos(vocab);
|
|
|
+ }
|
|
|
+
|
|
|
+ batch = llama_batch_get_one(&decoder_start_token_id, 1);
|
|
|
+ }
|
|
|
+
|
|
|
// main loop
|
|
|
|
|
|
const auto t_main_start = ggml_time_us();
|