Browse Source

ggml: aarch64: SVE kernels for q8_0_q8_0, q4_0_q8_0 vector dot (#7433)

* Add SVE support for q4_0_q8_0 q8_0_q8_0

* remove ifdef
Masaya, Kato 1 year ago
parent
commit
faa0e6979a
7 changed files with 85 additions and 2 deletions
  1. 4 0
      CMakeLists.txt
  2. 1 0
      common/common.cpp
  3. 4 0
      ggml-impl.h
  4. 64 2
      ggml-quants.c
  5. 10 0
      ggml.c
  6. 1 0
      ggml.h
  7. 1 0
      llama.cpp

+ 4 - 0
CMakeLists.txt

@@ -72,6 +72,7 @@ else()
     set(INS_ENB ON)
 endif()
 
+option(LLAMA_SVE                             "llama: enable SVE"                                OFF)
 option(LLAMA_AVX                             "llama: enable AVX"                                ${INS_ENB})
 option(LLAMA_AVX2                            "llama: enable AVX2"                               ${INS_ENB})
 option(LLAMA_AVX512                          "llama: enable AVX512"                             OFF)
@@ -1040,6 +1041,9 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
             # Raspberry Pi 3, 4, Zero 2 (32-bit)
             list(APPEND ARCH_FLAGS -mno-unaligned-access)
         endif()
+        if (LLAMA_SVE)
+            list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
+        endif()
     endif()
 elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
         (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND

+ 1 - 0
common/common.cpp

@@ -2844,6 +2844,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "cpu_has_fma: %s\n",         ggml_cpu_has_fma()         ? "true" : "false");
     fprintf(stream, "cpu_has_gpublas: %s\n",     ggml_cpu_has_gpublas()     ? "true" : "false");
     fprintf(stream, "cpu_has_neon: %s\n",        ggml_cpu_has_neon()        ? "true" : "false");
+    fprintf(stream, "cpu_has_sve: %s\n",         ggml_cpu_has_sve()         ? "true" : "false");
     fprintf(stream, "cpu_has_f16c: %s\n",        ggml_cpu_has_f16c()        ? "true" : "false");
     fprintf(stream, "cpu_has_fp16_va: %s\n",     ggml_cpu_has_fp16_va()     ? "true" : "false");
     fprintf(stream, "cpu_has_wasm_simd: %s\n",   ggml_cpu_has_wasm_simd()   ? "true" : "false");

+ 4 - 0
ggml-impl.h

@@ -144,6 +144,10 @@ extern "C" {
 #endif
 #endif
 
+#if defined(__ARM_FEATURE_SVE)
+#include <arm_sve.h>
+#endif
+
 // 16-bit float
 // on Arm, we use __fp16
 // on x86, we use uint16_t

+ 64 - 2
ggml-quants.c

@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+    const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+        const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+        // 4-bit -> 8-bit
+        const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+        const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+
+        // sub 8
+        const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+        const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        // dot product
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+        const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 

+ 10 - 0
ggml.c

@@ -22742,6 +22742,16 @@ int ggml_cpu_has_neon(void) {
 #endif
 }
 
+int ggml_cpu_has_sve(void) {
+#if defined(__ARM_FEATURE_SVE)
+    // TODO: Currently, SVE 256 bit is only supported.
+    GGML_ASSERT(svcntb() == QK8_0);
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_arm_fma(void) {
 #if defined(__ARM_FEATURE_FMA)
     return 1;

+ 1 - 0
ggml.h

@@ -2404,6 +2404,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_avx512_bf16(void);
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
+    GGML_API int ggml_cpu_has_sve        (void);
     GGML_API int ggml_cpu_has_arm_fma    (void);
     GGML_API int ggml_cpu_has_metal      (void);
     GGML_API int ggml_cpu_has_f16c       (void);

+ 1 - 0
llama.cpp

@@ -18337,6 +18337,7 @@ const char * llama_print_system_info(void) {
     s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
     s += "FMA = "         + std::to_string(ggml_cpu_has_fma())         + " | ";
     s += "NEON = "        + std::to_string(ggml_cpu_has_neon())        + " | ";
+    s += "SVE = "         + std::to_string(ggml_cpu_has_sve())         + " | ";
     s += "ARM_FMA = "     + std::to_string(ggml_cpu_has_arm_fma())     + " | ";
     s += "F16C = "        + std::to_string(ggml_cpu_has_f16c())        + " | ";
     s += "FP16_VA = "     + std::to_string(ggml_cpu_has_fp16_va())     + " | ";