浏览代码

Variable scopes are fun

Piotr Wilkin 3 月之前
父节点
当前提交
554593d60d
共有 1 个文件被更改,包括 13 次插入8 次删除
  1. 13 8
      tools/main/main.cpp

+ 13 - 8
tools/main/main.cpp

@@ -14,6 +14,7 @@
 #include <sstream>
 #include <string>
 #include <vector>
+#include <mutex>
 
 // Forward declarations for internal cache access
 struct llama_memory_hybrid;
@@ -92,6 +93,7 @@ static void sigint_handler(int signo) {
 struct callback_data {
     std::vector<uint8_t> data;
     std::map<std::string, int32_t> tensors;
+    std::mutex mutex;
 };
 
 
@@ -210,6 +212,7 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
 
 static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
     auto * cb_data = (callback_data *) user_data;
+    std::lock_guard<std::mutex> lock(cb_data->mutex);
 
     const struct ggml_tensor * src0 = t->src[0];
     const struct ggml_tensor * src1 = t->src[1];
@@ -241,16 +244,18 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
 
     if (!ggml_is_quantized(t->type)) {
         uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
-        ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
-        if (std::string(t->name).substr(0, std::string("post_moe-").size()) == "post_moe-" || 
-            std::string(t->name).substr(0, std::string("state_1d-").size()) == "state_1d-") {
-            if (cb_data->tensors.count(t->name) == 0) {
-                cb_data->tensors[t->name] = 1;
+        std::string tensor_name(t->name);
+        if (std::string(tensor_name).substr(0, std::string("post_moe-").size()) == "post_moe-" || 
+            std::string(tensor_name).substr(0, std::string("state_1d-").size()) == "state_1d-") {
+                
+            if (cb_data->tensors.count(tensor_name) == 0) {
+                cb_data->tensors[tensor_name] = 1;
             } else {
-                cb_data->tensors[t->name]++;
+                cb_data->tensors[tensor_name]++;
             }
-            save_tensor(t, data, (std::string(t->name) + "_" + std::to_string(cb_data->tensors[t->name]) + ".bin").c_str());
+            save_tensor(t, data, (tensor_name + "_" + std::to_string(cb_data->tensors[t->name]) + ".bin").c_str());
         }
+        ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
     }
 
     return true;
@@ -312,9 +317,9 @@ int main(int argc, char ** argv) {
     std::vector<common_chat_msg> chat_msgs;
 
     // load the model and apply lora adapter, if any
+    callback_data cb_data;
     if (params.n_predict > 0 && params.n_predict < 50) {
         // enable debug prints if we print small number of tokens
-        callback_data cb_data;
         params.cb_eval = ggml_debug;
         params.cb_eval_user_data = &cb_data;
     }