
ggml-cpu : add RISC-V Zvfh impl for ggml_vec_mad_f16 (#17448)

* ggml-cpu : add RISC-V Zvfh impl for ggml_vec_mad_f16

* ggml-cpu : dedup scalar impl

* Update ggml/src/ggml-cpu/vec.h

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
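
For context, ggml_vec_mad_f16 computes y[i] += x[i] * v over half-precision data; this commit adds a dedicated RVV (Zvfh) branch and folds the scalar tail handling into a single "leftovers" loop shared by all paths (np defaults to 0 when no SIMD path applies). Below is a minimal standalone sketch of the strip-mined Zvfh pattern used by the new #elif branch, assuming a toolchain that defines __riscv_zvfh and ships <riscv_vector.h>; the function name vec_mad_f16_zvfh and the direct use of _Float16 for the element type are illustrative simplifications, not part of the ggml API.

    #include <riscv_vector.h>

    // Illustrative standalone version of the new branch: y[i] += x[i] * v
    static void vec_mad_f16_zvfh(int n, _Float16 * y, const _Float16 * x, float v) {
        const _Float16 hv = (_Float16)v;
        for (int i = 0, avl; i < n; i += avl) {
            avl = __riscv_vsetvl_e16m8(n - i);                   // elements processed this pass (LMUL=8)
            vfloat16m8_t ax = __riscv_vle16_v_f16m8(&x[i], avl); // load a chunk of x
            vfloat16m8_t ay = __riscv_vle16_v_f16m8(&y[i], avl); // load the matching chunk of y
            vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl); // ny = ax * hv + ay
            __riscv_vse16_v_f16m8(&y[i], ny, avl);               // store back into y
        }
    }

Because vsetvl returns the number of elements the hardware will handle per iteration, the loop needs no separate scalar tail, which is why this branch simply sets np = n in the committed code.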
xctan 1 month ago
Parent
Commit
6ab4e50d9c
1 changed file with 84 additions and 85 deletions
  1. +84 −85
      ggml/src/ggml-cpu/vec.h

+ 84 - 85
ggml/src/ggml-cpu/vec.h

@@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
 }
 
 inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
-#if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-        const int sve_register_length = svcntb() * 8;
-        const int ggml_f16_epr = sve_register_length / 16;
-        const int ggml_f16_step = 8 * ggml_f16_epr;
+#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
+    const int sve_register_length = svcntb() * 8;
+    const int ggml_f16_epr = sve_register_length / 16;
+    const int ggml_f16_step = 8 * ggml_f16_epr;
 
-        GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
+    GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
 
-        const int np= (n & ~(ggml_f16_step - 1));
+    int np = (n & ~(ggml_f16_step - 1));
 
-        svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
-        svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
-        for (int i = 0; i < np; i += ggml_f16_step) {
-            ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
-            ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
-            ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
+    svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
+    svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
+    for (int i = 0; i < np; i += ggml_f16_step) {
+        ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
+        ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
+        ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
+        GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
 
-            ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
-            ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
-            ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
+        ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
+        ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
+        ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
+        GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
 
-            ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
-            ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
-            ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
+        ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
+        ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
+        ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
+        GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
 
-            ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
-            ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
-            ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
+        ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
+        ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
+        ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
+        GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
 
-            ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
-            ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
-            ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
+        ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
+        ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
+        ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
+        GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
 
-            ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
-            ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
-            ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
+        ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
+        ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
+        ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
+        GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
 
-            ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
-            ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
-            ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
+        ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
+        ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
+        ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
+        GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
 
-            ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
-            ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
-            ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
+        ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
+        ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
+        ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
-        }
-        const int np2 = (n & ~(ggml_f16_epr - 1));
-        for (int k = np; k < np2; k += ggml_f16_epr) {
-            svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
-            svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
-            ry = GGML_F16x_VEC_FMA(ry, rx, vx);
-
-            GGML_F16x_VEC_STORE(y + k, ry, 0);
-        }
-
-        if (np2 < n) {
-            svbool_t pg = svwhilelt_b16(np2, n);
-            svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
-            svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
-            hy = svmad_f16_x(pg, hx, vx, hy);
-            svst1_f16(pg, (__fp16 *)(y + np2), hy);
-        }
+        GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
+    }
+    const int np2 = (n & ~(ggml_f16_epr - 1));
+    for (int k = np; k < np2; k += ggml_f16_epr) {
+        svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
+        svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
+        ry = GGML_F16x_VEC_FMA(ry, rx, vx);
 
-    #elif defined(__riscv_v_intrinsic)
-        // todo: RVV impl
-        // scalar
-        for (int i = 0; i < n; ++i) {
-            y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
-        }
-    #else
-        const int np = (n & ~(GGML_F16_STEP - 1));
+        GGML_F16x_VEC_STORE(y + k, ry, 0);
+    }
 
-        GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+    if (np2 < n) {
+        svbool_t pg = svwhilelt_b16(np2, n);
+        svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
+        svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
+        hy = svmad_f16_x(pg, hx, vx, hy);
+        svst1_f16(pg, (__fp16 *)(y + np2), hy);
+    }
+    np = n;
+#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
+    const int np = n;
+    _Float16 hv = (_Float16)v;
+    for (int i = 0, avl; i < n; i += avl) {
+        avl = __riscv_vsetvl_e16m8(n - i);
+        vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
+        vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
+        vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
+        __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
+    }
+#elif defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
 
-        GGML_F16_VEC ax[GGML_F16_ARR];
-        GGML_F16_VEC ay[GGML_F16_ARR];
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
 
-        for (int i = 0; i < np; i += GGML_F16_STEP) {
-            for (int j = 0; j < GGML_F16_ARR; j++) {
-                ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
-                ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-                ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
 
-                GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
-            }
-        }
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
 
-        // leftovers
-        for (int i = np; i < n; ++i) {
-            y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
         }
-    #endif
+    }
 #else
-    // scalar
-    for (int i = 0; i < n; ++i) {
+    const int np = 0;
+#endif
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
         y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
     }
-#endif
 }
 
 // xs and vs are byte strides of x and v