Selaa lähdekoodia

llava-cli: fix base64 prompt (#7248)

k.h.lai 1 vuosi sitten
vanhempi
sitoutus
30e70334f7
1 muutettua tiedostoa jossa 21 lisäystä ja 6 poistoa
  1. 21 6
      examples/llava/llava-cli.cpp

+ 21 - 6
examples/llava/llava-cli.cpp

@@ -300,14 +300,10 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    for (auto & image : params.image) {
+    if (prompt_contains_image(params.prompt)) {
         auto ctx_llava = llava_init_context(&params, model);
 
-        auto image_embed = load_image(ctx_llava, &params, image);
-        if (!image_embed) {
-            std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
-            return 1;
-        }
+        auto image_embed = load_image(ctx_llava, &params, "");
 
         // process the prompt
         process_prompt(ctx_llava, image_embed, &params, params.prompt);
@@ -316,7 +312,26 @@ int main(int argc, char ** argv) {
         llava_image_embed_free(image_embed);
         ctx_llava->model = NULL;
         llava_free(ctx_llava);
+    } else {
+        for (auto & image : params.image) {
+            auto ctx_llava = llava_init_context(&params, model);
+
+            auto image_embed = load_image(ctx_llava, &params, image);
+            if (!image_embed) {
+                std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
+                return 1;
+            }
+
+            // process the prompt
+            process_prompt(ctx_llava, image_embed, &params, params.prompt);
+
+            llama_print_timings(ctx_llava->ctx_llama);
+            llava_image_embed_free(image_embed);
+            ctx_llava->model = NULL;
+            llava_free(ctx_llava);
+        }
     }
+
     llama_free_model(model);
 
     return 0;