Просмотр исходного кода

ggml : fix llamafile sgemm wdata offsets (#6710)

ggml-ci
Georgi Gerganov 1 год назад
Родитель
Сommit
666867b799
3 измененных файлов с 12 добавлено и 7 удалено
  1. 6 0
      CMakeLists.txt
  2. 2 0
      Makefile
  3. 4 7
      ggml.c

+ 6 - 0
CMakeLists.txt

@@ -88,6 +88,7 @@ endif()
 # 3rd party libs
 option(LLAMA_ACCELERATE                      "llama: enable Accelerate framework"               ON)
 option(LLAMA_BLAS                            "llama: use BLAS"                                  OFF)
+option(LLAMA_LLAMAFILE                       "llama: use llamafile SGEMM"                       ON)
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 option(LLAMA_CUDA                            "llama: use CUDA"                                  OFF)
 option(LLAMA_CUBLAS                          "llama: use CUDA (deprecated, use LLAMA_CUDA)"     OFF)
@@ -286,6 +287,7 @@ if (LLAMA_METAL)
         ${METALKIT_FRAMEWORK}
         )
 endif()
+
 if (LLAMA_BLAS)
     if (LLAMA_STATIC)
         set(BLA_STATIC ON)
@@ -368,6 +370,10 @@ if (LLAMA_BLAS)
     endif()
 endif()
 
+if (LLAMA_LLAMAFILE)
+    add_compile_definitions(GGML_USE_LLAMAFILE)
+endif()
+
 if (LLAMA_QKK_64)
     add_compile_definitions(GGML_QKK_64)
 endif()

+ 2 - 0
Makefile

@@ -222,6 +222,8 @@ endif # LLAMA_DISABLE_LOGS
 # disable ggml.c's use of sgemm.cpp
 ifdef LLAMA_NO_LLAMAFILE
 	MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=0
+else
+	MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=1
 endif
 
 # warnings

+ 4 - 7
ggml.c

@@ -33,12 +33,8 @@
 #include <unistd.h>
 #endif
 
-#ifndef GGML_USE_LLAMAFILE
 #ifdef __ARM_FEATURE_MATMUL_INT8
-#define GGML_USE_LLAMAFILE 0
-#else
-#define GGML_USE_LLAMAFILE 1
-#endif
+#undef GGML_USE_LLAMAFILE
 #endif
 
 #if defined(_MSC_VER)
@@ -10879,8 +10875,9 @@ UseGgmlGemm1:;
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
-                                     (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 +
-                                                            nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13),
+                                     (const char *)wdata + ggml_row_size(vec_dot_type,
+                                         nb12/ggml_type_size(src1->type)*i12 +
+                                         nb13/ggml_type_size(src1->type)*i13),
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),