Prechádzať zdrojové kódy

mtmd: add mtmd_log_set (#17268)

Xuan-Son Nguyen 2 mesiacov pred
rodič
commit
9b17d74ab7

+ 1 - 5
common/common.cpp

@@ -355,11 +355,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
 }
 
 void common_init() {
-    llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
-        if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
-            common_log_add(common_log_main(), level, "%s", text);
-        }
-    }, NULL);
+    llama_log_set(common_log_default_callback, NULL);
 
 #ifdef NDEBUG
     const char * build_type = "";

+ 6 - 0
common/log.cpp

@@ -442,3 +442,9 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
 void common_log_set_timestamps(struct common_log * log, bool timestamps) {
     log->set_timestamps(timestamps);
 }
+
+void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
+    if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
+        common_log_add(common_log_main(), level, "%s", text);
+    }
+}

+ 2 - 0
common/log.h

@@ -36,6 +36,8 @@ extern int common_log_verbosity_thold;
 
 void common_log_set_verbosity_thold(int verbosity); // not thread-safe
 
+void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
+
 // the common_log uses an internal worker thread to print/write log messages
 // when the worker thread is paused, incoming log messages are discarded
 struct common_log;

+ 5 - 12
tools/mtmd/clip-impl.h

@@ -224,7 +224,6 @@ static void clip_log_callback_default(enum ggml_log_level level, const char * te
 }
 
 struct clip_logger_state {
-    ggml_log_level verbosity_thold;
     ggml_log_callback log_callback;
     void * log_callback_user_data;
 };
@@ -258,17 +257,11 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
     va_end(args);
 }
 
-#define LOG_TMPL(level, ...) \
-    do { \
-        if ((level) >= g_logger_state.verbosity_thold) { \
-            clip_log_internal((level), __VA_ARGS__); \
-        } \
-    } while (0)
-#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO,  __VA_ARGS__)
-#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN,  __VA_ARGS__)
-#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
-#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
-#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  __VA_ARGS__)
+#define LOG_INF(...) clip_log_internal(GGML_LOG_LEVEL_INFO,  __VA_ARGS__)
+#define LOG_WRN(...) clip_log_internal(GGML_LOG_LEVEL_WARN,  __VA_ARGS__)
+#define LOG_ERR(...) clip_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define LOG_DBG(...) clip_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
+#define LOG_CNT(...) clip_log_internal(GGML_LOG_LEVEL_CONT,  __VA_ARGS__)
 
 //
 // cpp wrappers

+ 1 - 3
tools/mtmd/clip.cpp

@@ -24,8 +24,7 @@
 #include <array>
 #include <functional>
 
-// TODO: allow to pass callback from user code
-struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
+struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};
 
 enum ffn_op_type {
     FFN_GELU,
@@ -3507,7 +3506,6 @@ struct clip_model_loader {
 };
 
 struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
-    g_logger_state.verbosity_thold = ctx_params.verbosity;
     clip_ctx * ctx_vision = nullptr;
     clip_ctx * ctx_audio = nullptr;
 

+ 0 - 1
tools/mtmd/clip.h

@@ -31,7 +31,6 @@ enum clip_flash_attn_type {
 
 struct clip_context_params {
     bool use_gpu;
-    enum ggml_log_level verbosity;
     enum clip_flash_attn_type flash_attn_type;
     int image_min_tokens;
     int image_max_tokens;

+ 1 - 1
tools/mtmd/mtmd-cli.cpp

@@ -135,7 +135,6 @@ struct mtmd_cli_context {
         mparams.use_gpu          = params.mmproj_use_gpu;
         mparams.print_timings    = true;
         mparams.n_threads        = params.cpuparams.n_threads;
-        mparams.verbosity        = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
         mparams.flash_attn_type  = params.flash_attn_type;
         mparams.image_min_tokens = params.image_min_tokens;
         mparams.image_max_tokens = params.image_max_tokens;
@@ -277,6 +276,7 @@ int main(int argc, char ** argv) {
     }
 
     common_init();
+    mtmd_helper_log_set(common_log_default_callback, nullptr);
 
     if (params.mmproj.path.empty()) {
         show_additional_info(argc, argv);

+ 60 - 3
tools/mtmd/mtmd-helper.cpp

@@ -32,8 +32,65 @@
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb/stb_image.h"
 
-#define LOG_INF(...) fprintf(stdout, __VA_ARGS__)
-#define LOG_ERR(...) fprintf(stderr, __VA_ARGS__)
+//
+// internal logging functions
+//
+
+struct mtmd_helper_logger {
+    ggml_log_callback default_callback = [](ggml_log_level level, const char * text, void * user_data) {
+        (void) level;
+        (void) user_data;
+        fputs(text, stderr);
+        fflush(stderr);
+    };
+
+    ggml_log_callback log_callback = default_callback;
+    void * log_callback_user_data;
+
+    void log_v(enum ggml_log_level level, const char * format, va_list args) {
+        if (format == NULL) {
+            return;
+        }
+        va_list args_copy;
+        va_copy(args_copy, args);
+        char buffer[128];
+        int len = vsnprintf(buffer, 128, format, args);
+        if (len < 128) {
+            log_callback(level, buffer, log_callback_user_data);
+        } else {
+            char * buffer2 = (char *) calloc(len + 1, sizeof(char));
+            vsnprintf(buffer2, len + 1, format, args_copy);
+            buffer2[len] = 0;
+            log_callback(level, buffer2, log_callback_user_data);
+            free(buffer2);
+        }
+        va_end(args_copy);
+    }
+
+    void log(enum ggml_log_level level, const char * format, ...) {
+        va_list args;
+        va_start(args, format);
+        log_v(level, format, args);
+        va_end(args);
+    }
+} g_logger;
+
+#define LOG_INF(...) g_logger.log(GGML_LOG_LEVEL_INFO,  __VA_ARGS__)
+#define LOG_WRN(...) g_logger.log(GGML_LOG_LEVEL_WARN,  __VA_ARGS__)
+#define LOG_ERR(...) g_logger.log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data) {
+    if (log_callback == nullptr) {
+        log_callback = g_logger.default_callback;
+    }
+    g_logger.log_callback = log_callback;
+    g_logger.log_callback_user_data = user_data;
+    mtmd_log_set(log_callback, user_data);
+}
+
+//
+// helper functions
+//
 
 size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
     size_t n_tokens = 0;
@@ -325,7 +382,7 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
                                 llama_pos * new_n_past) {
     size_t n_chunks = mtmd_input_chunks_size(chunks);
     if (n_chunks == 0) {
-        LOG_ERR("no chunks to eval\n");
+        LOG_WRN("no chunks to eval\n");
         return 0;
     }
 

+ 5 - 0
tools/mtmd/mtmd-helper.h

@@ -20,6 +20,11 @@ extern "C" {
 // BREAKING CHANGES are expected.
 //
 
+// Set callback for all future logging events.
+// If this is not called, or NULL is supplied, everything is output on stderr.
+// Note: this also call mtmd_log_set() internally
+MTMD_API void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data);
+
 // helper function to construct a mtmd_bitmap from a file
 // it calls mtmd_helper_bitmap_init_from_buf() internally
 // returns nullptr on failure

+ 5 - 2
tools/mtmd/mtmd.cpp

@@ -105,7 +105,6 @@ mtmd_context_params mtmd_context_params_default() {
         /* use_gpu           */ true,
         /* print_timings     */ true,
         /* n_threads         */ 4,
-        /* verbosity         */ GGML_LOG_LEVEL_INFO,
         /* image_marker      */ MTMD_DEFAULT_IMAGE_MARKER,
         /* media_marker      */ mtmd_default_marker(),
         /* flash_attn_type   */ LLAMA_FLASH_ATTN_TYPE_AUTO,
@@ -175,7 +174,6 @@ struct mtmd_context {
 
         clip_context_params ctx_clip_params {
             /* use_gpu           */ ctx_params.use_gpu,
-            /* verbosity         */ ctx_params.verbosity,
             /* flash_attn_type   */ CLIP_FLASH_ATTN_TYPE_AUTO,
             /* image_min_tokens  */ ctx_params.image_min_tokens,
             /* image_max_tokens  */ ctx_params.image_max_tokens,
@@ -1096,3 +1094,8 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
 
     return chunks;
 }
+
+void mtmd_log_set(ggml_log_callback log_callback, void * user_data) {
+    g_logger_state.log_callback = log_callback ? log_callback : clip_log_callback_default;
+    g_logger_state.log_callback_user_data = user_data;
+}

+ 4 - 1
tools/mtmd/mtmd.h

@@ -79,7 +79,6 @@ struct mtmd_context_params {
     bool use_gpu;
     bool print_timings;
     int n_threads;
-    enum ggml_log_level verbosity;
     const char * image_marker; // deprecated, use media_marker instead
     const char * media_marker;
     enum llama_flash_attn_type flash_attn_type;
@@ -215,6 +214,10 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
 // llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
 MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
 
+// Set callback for all future logging events.
+// If this is not called, or NULL is supplied, everything is output on stderr.
+MTMD_API void mtmd_log_set(ggml_log_callback log_callback, void * user_data);
+
 /////////////////////////////////////////
 
 // test function, to be used in test-mtmd-c-api.c

+ 2 - 1
tools/server/server.cpp

@@ -2454,11 +2454,12 @@ struct server_context {
 
         std::string & mmproj_path = params_base.mmproj.path;
         if (!mmproj_path.empty()) {
+            mtmd_helper_log_set(common_log_default_callback, nullptr);
+
             mtmd_context_params mparams = mtmd_context_params_default();
             mparams.use_gpu          = params_base.mmproj_use_gpu;
             mparams.print_timings    = false;
             mparams.n_threads        = params_base.cpuparams.n_threads;
-            mparams.verbosity        = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
             mparams.flash_attn_type  = params_base.flash_attn_type;
             mparams.image_min_tokens = params_base.image_min_tokens;
             mparams.image_max_tokens = params_base.image_max_tokens;