فهرست منبع

ggml : add NVPL BLAS support (#8329) (#8425)

* ggml : add NVPL BLAS support

* ggml : replace `<BLASLIB>_ENABLE_CBLAS` with `GGML_BLAS_USE_<BLASLIB>`

---------

Co-authored-by: ntukanov <ntukanov@nvidia.com>
Nicholai Tukanov 1 سال پیش
والد
کامیت
368645698a
2فایلهای تغییر یافته به همراه16 افزوده شده و 5 حذف شده
  1. 7 1
      Makefile
  2. 9 4
      ggml/src/ggml-blas.cpp

+ 7 - 1
Makefile

@@ -547,11 +547,17 @@ ifdef GGML_OPENBLAS64
 endif # GGML_OPENBLAS64
 
 ifdef GGML_BLIS
-	MK_CPPFLAGS += -DGGML_USE_BLAS -I/usr/local/include/blis -I/usr/include/blis
+	MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis
 	MK_LDFLAGS  += -lblis -L/usr/local/lib
 	OBJ_GGML    += ggml/src/ggml-blas.o
 endif # GGML_BLIS
 
+ifdef GGML_NVPL
+	MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas
+	MK_LDFLAGS  += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp
+	OBJ_GGML    += ggml/src/ggml-blas.o
+endif # GGML_NVPL
+
 ifndef GGML_NO_LLAMAFILE
 	MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
 	OBJ_GGML    += ggml/src/llamafile/sgemm.o

+ 9 - 4
ggml/src/ggml-blas.cpp

@@ -8,11 +8,12 @@
 #   include <Accelerate/Accelerate.h>
 #elif defined(GGML_BLAS_USE_MKL)
 #   include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+#   include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+#   include <nvpl_blas.h>
 #else
 #   include <cblas.h>
-#   ifdef BLIS_ENABLE_CBLAS
-#       include <blis.h>
-#   endif
 #endif
 
 struct ggml_backend_blas_context {
@@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
     openblas_set_num_threads(ctx->n_threads);
 #endif
 
-#if defined(BLIS_ENABLE_CBLAS)
+#if defined(GGML_BLAS_USE_BLIS)
     bli_thread_set_num_threads(ctx->n_threads);
 #endif
 
+#if defined(GGML_BLAS_USE_NVPL)
+    nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
     for (int64_t i13 = 0; i13 < ne13; i13++) {
         for (int64_t i12 = 0; i12 < ne12; i12++) {
             const int64_t i03 = i13/r3;