Просмотр исходного кода

imatrix : fix wname for mul_mat_id ops (#6271)

* imatrix : fix wname for mul_mat_id ops

* also filter tensor names in mul_mat_id ops

---------

Co-authored-by: slaren <slarengh@gmail.com>
Georgi Gerganov 1 год назад
Родитель
Сommit
a0e584defd
1 измененных файлов с 21 добавлено и 18 удалено
  1. 21 18
      examples/imatrix/imatrix.cpp

+ 21 - 18
examples/imatrix/imatrix.cpp

@@ -50,29 +50,31 @@ private:
     void keep_imatrix(int ncall) const;
 };
 
+// remove any prefix and suffixes from the name
+// CUDA0#blk.0.attn_k.weight#0 => blk.0.attn_k.weight
+static std::string filter_tensor_name(const char * name) {
+    std::string wname;
+    const char * p = strchr(name, '#');
+    if (p != NULL) {
+        p = p + 1;
+        const char * q = strchr(p, '#');
+        if (q != NULL) {
+            wname = std::string(p, q - p);
+        } else {
+            wname = p;
+        }
+    } else {
+        wname = name;
+    }
+    return wname;
+}
+
 bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
     GGML_UNUSED(user_data);
 
     const struct ggml_tensor * src0 = t->src[0];
     const struct ggml_tensor * src1 = t->src[1];
-
-    std::string wname;
-    {
-        // remove any prefix and suffixes from the name
-        // CUDA0#blk.0.attn_k.weight#0 => blk.0.attn_k.weight
-        const char * p = strchr(src0->name, '#');
-        if (p != NULL) {
-            p = p + 1;
-            const char * q = strchr(p, '#');
-            if (q != NULL) {
-                wname = std::string(p, q - p);
-            } else {
-                wname = p;
-            }
-        } else {
-            wname = src0->name;
-        }
-    }
+    std::string wname = filter_tensor_name(src0->name);
 
     // when ask is true, the scheduler wants to know if we are interested in data from this tensor
     // if we return true, a follow-up call will be made with ask=false in which we can do the actual collection
@@ -112,6 +114,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         // this is necessary to guarantee equal number of "ncall" for each tensor
         for (int ex = 0; ex < n_as; ++ex) {
             src0 = t->src[2 + ex];
+            wname = filter_tensor_name(src0->name);
             auto& e = m_stats[wname];
             if (e.values.empty()) {
                 e.values.resize(src1->ne[0], 0);