Просмотр исходного кода

llamafile : improve sgemm.cpp (#6796)

* llamafile : improve sgemm.cpp

- Re-enable by default
- Fix issue described in #6716
- Make code more abstract, elegant, and maintainable
- Faster handling of weirdly shaped `m` and `n` edge cases

* Address review comments

* Help clang produce fma instructions

* Address review comments
Justine Tunney 1 год назад
Родитель
Коммит
192090bae4
4 изменённых файла: 408 добавлений и 569 удалений
  1. 5 11
      CMakeLists.txt
  2. 0 4
      Makefile
  3. 3 5
      ggml.c
  4. 400 549
      sgemm.cpp

+ 5 - 11
CMakeLists.txt

@@ -43,17 +43,11 @@ else()
     set(LLAMA_METAL_DEFAULT OFF)
     set(LLAMA_METAL_DEFAULT OFF)
 endif()
 endif()
 
 
-# TODO: fix this for Android CI
-#       https://github.com/ggerganov/llama.cpp/pull/6716#issuecomment-2061509191
-#if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
-#    set(LLAMA_LLAMAFILE_DEFAULT OFF)
-#else()
-#    set(LLAMA_LLAMAFILE_DEFAULT ON)
-#endif()
-
-# TODO: temporary disable until MoE is fixed
-#       https://github.com/ggerganov/llama.cpp/pull/6716
-set(LLAMA_LLAMAFILE_DEFAULT OFF)
+if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
+    set(LLAMA_LLAMAFILE_DEFAULT OFF)
+else()
+    set(LLAMA_LLAMAFILE_DEFAULT ON)
+endif()
 
 
 # general
 # general
 option(BUILD_SHARED_LIBS                "build shared libraries"                                OFF)
 option(BUILD_SHARED_LIBS                "build shared libraries"                                OFF)

+ 0 - 4
Makefile

@@ -384,10 +384,6 @@ ifdef LLAMA_OPENBLAS
 	MK_LDFLAGS  += $(shell pkg-config --libs openblas)
 	MK_LDFLAGS  += $(shell pkg-config --libs openblas)
 endif # LLAMA_OPENBLAS
 endif # LLAMA_OPENBLAS
 
 
-# TODO: temporary disable until MoE is fixed
-#       https://github.com/ggerganov/llama.cpp/pull/6716
-LLAMA_NO_LLAMAFILE := 1
-
 ifndef LLAMA_NO_LLAMAFILE
 ifndef LLAMA_NO_LLAMAFILE
 	MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
 	MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
 	OBJS        += sgemm.o
 	OBJS        += sgemm.o

+ 3 - 5
ggml.c

@@ -10825,7 +10825,7 @@ static void ggml_compute_forward_mul_mat(
 #endif
 #endif
 
 
 #if GGML_USE_LLAMAFILE
 #if GGML_USE_LLAMAFILE
-    if (nb10 == ggml_type_size(src1->type)) {
+    if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -10878,15 +10878,13 @@ UseGgmlGemm1:;
     const size_t row_size = ggml_row_size(vec_dot_type, ne10);
     const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
 
 #if GGML_USE_LLAMAFILE
 #if GGML_USE_LLAMAFILE
-    if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
+    if (src1->type != vec_dot_type) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
                                      nb01/ggml_type_size(src0->type),
-                                     (const char *)wdata + ggml_row_size(vec_dot_type,
-                                         nb12/ggml_type_size(src1->type)*i12 +
-                                         nb13/ggml_type_size(src1->type)*i13),
+                                     (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      nb1/ggml_type_size(dst->type),

Разница между файлами не показана из-за своего большого размера
+ 400 - 549
sgemm.cpp


Некоторые файлы не были показаны из-за большого количества измененных файлов