
mtmd: fix --no-warmup (#17695)

Xuan-Son Nguyen, 1 month ago
parent
commit
a96283adc4
2 files changed, 31 insertions and 20 deletions
  1. common/arg.cpp (+1, -1)
  2. tools/mtmd/clip.cpp (+30, -19)

+ 1 - 1
common/arg.cpp

@@ -1226,7 +1226,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.warmup = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg(
         {"--spm-infill"},
         string_format(
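
The only change in arg.cpp is adding LLAMA_EXAMPLE_MTMD to the set_examples list, so the mtmd tools now accept --no-warmup and receive params.warmup = false like the other examples. A minimal caller-side sketch, assuming a tool that parses common_params the usual way (run_warmup_pass and init_after_arg_parse are illustrative placeholders, not the actual mtmd-cli code):

    #include "common.h"

    // hypothetical helper: encode a dummy batch once so backend buffers are reserved eagerly
    static void run_warmup_pass() { /* ... */ }

    static void init_after_arg_parse(const common_params & params) {
        if (params.warmup) {
            run_warmup_pass();   // default behaviour: eager warmup
        }
        // with --no-warmup this is skipped; clip.cpp (below) now allocates
        // compute buffers lazily on the first real encode instead
    }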

+ 30 - 19
tools/mtmd/clip.cpp

@@ -428,6 +428,7 @@ struct clip_ctx {
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
     clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
+    bool is_allocated = false;
 
     // for debugging
     bool debug_graph = false;
@@ -3305,12 +3306,30 @@ struct clip_model_loader {
     };
 
     static void warmup(clip_ctx & ctx_clip) {
+        // create a fake batch
+        const auto & hparams = ctx_clip.model.hparams;
+        clip_image_f32_batch batch;
+        clip_image_f32_ptr img(clip_image_f32_init());
+        if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
+            img->nx = hparams.warmup_image_size;
+            img->ny = hparams.warmup_image_size;
+            LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
+        } else {
+            img->nx = hparams.warmup_audio_size;
+            img->ny = hparams.n_mel_bins;
+            LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
+        }
+        batch.entries.push_back(std::move(img));
+        warmup(ctx_clip, batch);
+    }
+
+    static void warmup(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
         support_info_graph info;
 
         if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
             // try to enable flash attention to see if it's supported
             ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
-            info = alloc_compute_meta(ctx_clip);
+            info = alloc_compute_meta(ctx_clip, batch);
             if (!info.fattn && info.fattn_op) {
                 auto op = info.fattn_op;
                 LOG_WRN("%s: *****************************************************************\n", __func__);
@@ -3329,15 +3348,17 @@ struct clip_model_loader {
                 LOG_WRN("%s: please report this on github as an issue\n", __func__);
                 LOG_WRN("%s: *****************************************************************\n", __func__);
                 ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
-                alloc_compute_meta(ctx_clip);
+                alloc_compute_meta(ctx_clip, batch);
             }
         } else {
-            info = alloc_compute_meta(ctx_clip);
+            info = alloc_compute_meta(ctx_clip, batch);
             if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
                 LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
             }
         }
 
+        ctx_clip.is_allocated = true; // mark buffers as allocated
+
         LOG_INF("%s: flash attention is %s\n", __func__,
             (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
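
The warmup logic is now split into two overloads: the parameterless warmup(ctx_clip) only builds the fake image/audio batch and forwards to warmup(ctx_clip, batch), which probes flash attention, reserves the compute buffers for that batch, and finally sets is_allocated. A stripped-down sketch of the resulting allocate-once pattern, with placeholder types standing in for clip_ctx and clip_image_f32_batch:

    // simplified illustration only; ctx_t/encode() are placeholders, not the clip.cpp API
    struct ctx_t {
        bool is_allocated = false;
    };

    static void warmup(ctx_t & ctx) {
        // ... build the graph for the batch and reserve backend buffers ...
        ctx.is_allocated = true;   // remembered so the guard below fires only once
    }

    static void encode(ctx_t & ctx) {
        if (!ctx.is_allocated) {
            warmup(ctx);           // lazy path, taken when --no-warmup skipped the eager warmup
        }
        // ... run the real inference graph ...
    }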
 
@@ -3369,24 +3390,9 @@ struct clip_model_loader {
         }
     }
 
-    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
-        const auto & hparams = ctx_clip.model.hparams;
+    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
 
-        // create a fake batch
-        clip_image_f32_batch batch;
-        clip_image_f32_ptr img(clip_image_f32_init());
-        if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
-            img->nx = hparams.warmup_image_size;
-            img->ny = hparams.warmup_image_size;
-            LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
-        } else {
-            img->nx = hparams.warmup_audio_size;
-            img->ny = hparams.n_mel_bins;
-            LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
-        }
-        batch.entries.push_back(std::move(img));
-
         ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
         ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
 
@@ -4630,6 +4636,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         return false; // only support batch size of 1
     }
 
+    // if buffers are not allocated, we need to do a warmup run to allocate them
+    if (!ctx->is_allocated) {
+        clip_model_loader::warmup(*ctx, *imgs_c_ptr);
+    }
+
     // build the inference graph
     ctx->debug_print_tensors.clear();
     ggml_backend_sched_reset(ctx->sched.get());
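
Net effect: with --no-warmup, nothing is reserved at model load time; the first clip_image_batch_encode call runs clip_model_loader::warmup against the actual input batch, which reserves the compute buffers and sets is_allocated so subsequent encodes skip the extra work. Before this change the buffers were only reserved by the load-time warmup, so skipping it left the encode path without allocated buffers.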