@@ -0,0 +1,3196 @@
+#include "ggml.h"
+#include "ime_kernels.h"
+
+#include <algorithm>
+#include <cmath>
+
+// clang-format off
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#pragma GCC diagnostic ignored "-Wcast-qual"
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#endif
+// clang-format on
+namespace sqnbitgemm_spacemit_ime {
+
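+// QUANTIZEM4ROW_KERNEL: quantizes the block held in v0 (e32, m8) to int8. It reduces the
+// absolute maximum, derives the block scale (max / 127) in f10 and stores it at (a1),
+// then multiplies by the reciprocal scale and narrows 32 -> 16 -> 8 bits into v24..v31.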
+#define QUANTIZEM4ROW_KERNEL                       \
+    "vmv.s.x v16, zero                     \n\t"   \
+    "vfabs.v v8, v0                        \n\t"   \
+    "vfredmax.vs v16, v8, v16              \n\t"   \
+    "vfmv.f.s f10, v16                     \n\t"   \
+    "fmul.s f10, f10, %[RMAXREC]           \n\t"   \
+    "fsw f10, (a1)                         \n\t"   \
+    "fdiv.s f11, %[FONE], f10              \n\t"   \
+    "vfmul.vf v16, v0, f11                 \n\t"   \
+    "vfcvt.x.f.v v16, v16                  \n\t"   \
+    "vsetvli t0, zero, e16, mf2            \n\t"   \
+    "vnclip.wx v16, v16, zero              \n\t"   \
+    "vnclip.wx v17, v17, zero              \n\t"   \
+    "vnclip.wx v18, v18, zero              \n\t"   \
+    "vnclip.wx v19, v19, zero              \n\t"   \
+    "vnclip.wx v20, v20, zero              \n\t"   \
+    "vnclip.wx v21, v21, zero              \n\t"   \
+    "vnclip.wx v22, v22, zero              \n\t"   \
+    "vnclip.wx v23, v23, zero              \n\t"   \
+    "vsetvli t0, zero, e8, mf4             \n\t"   \
+    "vnclip.wx v24, v16, zero              \n\t"   \
+    "vnclip.wx v25, v17, zero              \n\t"   \
+    "vnclip.wx v26, v18, zero              \n\t"   \
+    "vnclip.wx v27, v19, zero              \n\t"   \
+    "vnclip.wx v28, v20, zero              \n\t"   \
+    "vnclip.wx v29, v21, zero              \n\t"   \
+    "vnclip.wx v30, v22, zero              \n\t"   \
+    "vnclip.wx v31, v23, zero              \n\t"
+
+#define QUANTIZEM4ROW_STORE                        \
+    "addi t1, %[BlkLen], 0                 \n\t"   \
+    "vsetvli t0, t1, e8, mf4               \n\t"   \
+    "vse8.v v24, (s1)                      \n\t"   \
+    "addi s1, s1, 32                       \n\t"   \
+    "sub t1, t1, t0                        \n\t"   \
+    "vsetvli t0, t1, e8, mf4               \n\t"   \
+    "vse8.v v25, (s1)                      \n\t"   \
+    "addi s1, s1, 32                       \n\t"   \
+    "sub t1, t1, t0                        \n\t"   \
+    "vsetvli t0, t1, e8, mf4               \n\t"   \
+    "vse8.v v26, (s1)                      \n\t"   \
+    "addi s1, s1, 32                       \n\t"   \
+    "sub t1, t1, t0                        \n\t"   \
+    "vsetvli t0, t1, e8, mf4               \n\t"   \
+    "vse8.v v27, (s1)                      \n\t"   \
+    "addi s1, s1, 32                       \n\t"   \
+    "sub t1, t1, t0                        \n\t"   \
+    "vsetvli t0, t1, e8, mf4               \n\t"   \
+    "vse8.v v28, (s1)                      \n\t"   \
+    "addi s1, s1, 32                       \n\t"   \
+    "sub t1, t1, t0                        \n\t"   \
+    "vsetvli t0, t1, e8, mf4               \n\t"   \
+    "vse8.v v29, (s1)                      \n\t"   \
+    "addi s1, s1, 32                       \n\t"   \
+    "sub t1, t1, t0                        \n\t"   \
+    "vsetvli t0, t1, e8, mf4               \n\t"   \
+    "vse8.v v30, (s1)                      \n\t"   \
+    "addi s1, s1, 32                       \n\t"   \
+    "sub t1, t1, t0                        \n\t"   \
+    "vsetvli t0, t1, e8, mf4               \n\t"   \
+    "vse8.v v31, (s1)                      \n\t"
+
|
|
|
+namespace ime1 {
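+// Quantizes four rows of A at once. Each block stores the four row scales first
+// (4 floats), followed by the rows' int8 data interleaved in 8-byte groups, so one
+// block occupies 4 * (sizeof(float) + BlkLen) bytes; `offset` and `stride` below
+// encode that layout.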
|
|
|
+void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
|
|
|
+ constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
|
|
|
+ const float fone = 1.0f;
|
|
|
+
|
|
|
+ if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {
|
|
|
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
|
|
|
+ const float * SRC = A + row_index * CountK;
|
|
|
+ std::byte * DST = QuantA + row_index * sizeof(float);
|
|
|
+
|
|
|
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
|
|
|
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "addi t2, %[CountK], 0 \n\t"
|
|
|
+ "addi a1, %[DST], 0 \n\t"
|
|
|
+ "blt t2, %[BlkLen], TAIL%= \n\t"
|
|
|
+
|
|
|
+ "LOOP%=: \n\t"
|
|
|
+ "vsetvli t0, %[BlkLen], e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "sub t2, t2, t0 \n\t"
|
|
|
+ "slli t1, t0, 2 \n\t"
|
|
|
+ "add %[SRC], %[SRC], t1 \n\t"
|
|
|
+ "add s1, a1, %[OFFSET] \n\t"
|
|
|
+
|
|
|
+ QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
|
|
|
+
|
|
|
+ "add a1, a1, %[STRIDE] \n\t"
|
|
|
+ "bge t2, %[BlkLen], LOOP%= \n\t"
|
|
|
+
|
|
|
+ "TAIL%=: \n\t"
|
|
|
+ "blez t2, QUIT%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v24, v24, v24 \n\t"
|
|
|
+ "vsetvli t0, t2, e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "add s1, a1, %[OFFSET] \n\t"
|
|
|
+
|
|
|
+ QUANTIZEM4ROW_KERNEL
|
|
|
+
|
|
|
+ "addi t3, %[BlkLen], 0 \n\t"
|
|
|
+ "addi s2, s1, 0 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t"
|
|
|
+ "vxor.vv v8, v8, v8 \n\t"
|
|
|
+ "SET_ZERO%=: \n\t"
|
|
|
+ "vse8.v v8, (s2) \n\t"
|
|
|
+ "addi s2, s2, 32 \n\t"
|
|
|
+ "addi t3, t3, -8 \n\t"
|
|
|
+ "bnez t3, SET_ZERO%= \n\t"
|
|
|
+
|
|
|
+ QUANTIZEM4ROW_STORE
|
|
|
+
|
|
|
+ "QUIT%=: \n\t"
|
|
|
+ : [SRC] "+r"(SRC)
|
|
|
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
|
|
|
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
|
|
|
+ : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
|
|
|
+ }
|
|
|
+ } else if (BlkLen == 128) {
|
|
|
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
|
|
|
+ const float * SRC = A + row_index * CountK;
|
|
|
+ std::byte * DST = QuantA + row_index * sizeof(float);
|
|
|
+
|
|
|
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
|
|
|
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "li t6, 32 \n\t"
|
|
|
+ "addi t2, %[CountK], 0 \n\t"
|
|
|
+ "addi a1, %[DST], 0 \n\t"
|
|
|
+ "add s1, a1, %[OFFSET] \n\t"
|
|
|
+ "blt t2, %[BlkLen], TAIL%= \n\t"
|
|
|
+
|
|
|
+ "LOOP%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v8, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "addi t2, t2, -128 \n\t"
|
|
|
+
|
|
|
+ "QUANTIZE%=: \n\t"
|
|
|
+ "add s1, a1, %[OFFSET] \n\t"
|
|
|
+ "vfabs.v v16, v0 \n\t"
|
|
|
+ "vfabs.v v24, v8 \n\t"
|
|
|
+ "vfmax.vv v16, v24, v16 \n\t"
|
|
|
+ "vfredmax.vs v24, v16, v24 \n\t"
|
|
|
+ "vfmv.f.s f10, v24 \n\t"
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (a1) \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f10 \n\t"
|
|
|
+ "vfmul.vf v16, v0, f11 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f11 \n\t"
|
|
|
+ "vfcvt.x.f.v v16, v16 \n\t"
|
|
|
+ "vfcvt.x.f.v v24, v24 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m4 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vnclip.wx v20, v24, zero \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m4 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vsetvli t0, zero, e64, m4 \n\t"
|
|
|
+ "vsse64.v v16, (s1), t6 \n\t"
|
|
|
+ "add a1, a1, %[STRIDE] \n\t"
|
|
|
+ "bge t2, %[BlkLen], LOOP%= \n\t"
|
|
|
+
|
|
|
+ "TAIL%=: \n\t"
|
|
|
+ "blez t2, QUIT%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v0, v0, v0 \n\t"
|
|
|
+ "vxor.vv v8, v8, v8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v24, v24, v24 \n\t"
|
|
|
+ "vsetvli t0, t2, e32, m8 \n\t"
|
|
|
+ "sub t2, t2, t0 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vsetvli t0, t2, e32, m8 \n\t"
|
|
|
+ "vle32.v v8, (%[SRC]) \n\t"
|
|
|
+ "sub t2, t2, t2 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "jal x0, QUANTIZE%= \n\t"
|
|
|
+
|
|
|
+ "QUIT%=: \n\t"
|
|
|
+ : [SRC] "+r"(SRC)
|
|
|
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
|
|
|
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
|
|
|
+ : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
|
|
|
+ }
|
|
|
+ } else if (BlkLen == 256) {
|
|
|
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
|
|
|
+ const float * SRC = A + row_index * CountK;
|
|
|
+ std::byte * DST = QuantA + row_index * sizeof(float);
|
|
|
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
|
|
|
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "li t6, 32 \n\t"
|
|
|
+ "addi t2, %[CountK], 0 \n\t"
|
|
|
+ "addi a1, %[DST], 0 \n\t"
|
|
|
+ "add s1, a1, %[OFFSET] \n\t"
|
|
|
+ "blt t2, %[BlkLen], TAIL%= \n\t"
|
|
|
+
|
|
|
+ "LOOP%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v8, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v16, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v24, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], -768 \n\t"
|
|
|
+ "addi t2, t2, -256 \n\t"
|
|
|
+ "vfabs.v v0, v0 \n\t"
|
|
|
+ "vfabs.v v8, v8 \n\t"
|
|
|
+ "vfabs.v v16, v16 \n\t"
|
|
|
+ "vfabs.v v24, v24 \n\t"
|
|
|
+ "vfmax.vv v8, v0, v8 \n\t"
|
|
|
+ "vfmax.vv v24, v24, v16 \n\t"
|
|
|
+ "vfmax.vv v8, v8, v24 \n\t"
|
|
|
+ "vfredmax.vs v24, v8, v24 \n\t"
|
|
|
+ "vfmv.f.s f10, v24 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v8, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v16, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v24, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+
|
|
|
+ "QUANTIZE%=: \n\t"
|
|
|
+ "add s1, a1, %[OFFSET] \n\t"
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (a1) \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f10 \n\t"
|
|
|
+ "vfmul.vf v0, v0, f11 \n\t"
|
|
|
+ "vfmul.vf v8, v8, f11 \n\t"
|
|
|
+ "vfmul.vf v16, v16, f11 \n\t"
|
|
|
+ "vfmul.vf v24, v24, f11 \n\t"
|
|
|
+ "vfcvt.x.f.v v0, v0 \n\t"
|
|
|
+ "vfcvt.x.f.v v8, v8 \n\t"
|
|
|
+ "vfcvt.x.f.v v16, v16 \n\t"
|
|
|
+ "vfcvt.x.f.v v24, v24 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m4 \n\t"
|
|
|
+ "vnclip.wx v0, v0, zero \n\t"
|
|
|
+ "vnclip.wx v4, v8, zero \n\t"
|
|
|
+ "vnclip.wx v8, v16, zero \n\t"
|
|
|
+ "vnclip.wx v12, v24, zero \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m4 \n\t"
|
|
|
+ "vnclip.wx v0, v0, zero \n\t"
|
|
|
+ "vnclip.wx v4, v8, zero \n\t"
|
|
|
+ "vsetvli t0, zero, e64, m8 \n\t"
|
|
|
+ "vsse64.v v0, (s1), t6 \n\t"
|
|
|
+ "add a1, a1, %[STRIDE] \n\t"
|
|
|
+ "bge t2, %[BlkLen], LOOP%= \n\t"
|
|
|
+
|
|
|
+ "TAIL%=: \n\t"
|
|
|
+ "blez t2, QUIT%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v0, v0, v0 \n\t"
|
|
|
+ "vxor.vv v8, v8, v8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v24, v24, v24 \n\t"
|
|
|
+ "addi t1, t2, 0 \n\t"
|
|
|
+ "vsetvli t0, t1, e32, m8 \n\t"
|
|
|
+ "sub t1, t1, t0 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vsetvli t0, t1, e32, m8 \n\t"
|
|
|
+ "sub t1, t1, t0 \n\t"
|
|
|
+ "vle32.v v8, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vsetvli t0, t1, e32, m8 \n\t"
|
|
|
+ "sub t1, t1, t0 \n\t"
|
|
|
+ "vle32.v v16, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vsetvli t0, t1, e32, m8 \n\t"
|
|
|
+ "vle32.v v24, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], -768 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfabs.v v0, v0 \n\t"
|
|
|
+ "vfabs.v v8, v8 \n\t"
|
|
|
+ "vfabs.v v16, v16 \n\t"
|
|
|
+ "vfabs.v v24, v24 \n\t"
|
|
|
+ "vfmax.vv v8, v0, v8 \n\t"
|
|
|
+ "vfmax.vv v24, v16, v24 \n\t"
|
|
|
+ "vfmax.vv v8, v8, v24 \n\t"
|
|
|
+ "vfredmax.vs v24, v8, v24 \n\t"
|
|
|
+ "vfmv.f.s f10, v24 \n\t"
|
|
|
+ "add s1, a1, %[OFFSET] \n\t"
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (a1) \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f10 \n\t"
|
|
|
+ "vsetvli t0, zero, e64, m8 \n\t"
|
|
|
+ "vxor.vv v0, v0, v0 \n\t"
|
|
|
+ "vsse64.v v0, (s1), t6 \n\t"
|
|
|
+
|
|
|
+ "TAIL_LOOP%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vxor.vv v0, v0, v0 \n\t"
|
|
|
+ "vsetvli t0, t2, e32, m1 \n\t"
|
|
|
+ "sub t2, t2, t0 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 32 \n\t"
|
|
|
+ "vfmul.vf v1, v0, f11 \n\t"
|
|
|
+ "vfcvt.x.f.v v2, v1 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, mf2 \n\t"
|
|
|
+ "vnclip.wx v3, v2, zero \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t"
|
|
|
+ "vnclip.wx v3, v3, zero \n\t"
|
|
|
+ "vse8.v v3, (s1) \n\t"
|
|
|
+ "addi s1, s1, 32 \n\t"
|
|
|
+ "bnez t2, TAIL_LOOP%= \n\t"
|
|
|
+
|
|
|
+ "QUIT%=: \n\t"
|
|
|
+ : [SRC] "+r"(SRC)
|
|
|
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
|
|
|
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
|
|
|
+ : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
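+// Single-row variant: each block is one float scale followed by BlkLen int8 values,
+// and the final partial block (plus its padding region) is zero-filled so downstream
+// kernels can assume whole blocks.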
|
|
|
+void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
|
|
|
+ const float * SRC = A;
|
|
|
+ std::byte * DST = QuantA;
|
|
|
+ constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
|
|
|
+ const float fone = 1.0f;
|
|
|
+ std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
|
|
|
+ size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
|
|
|
+
|
|
|
+ if (CountK <= BlkLen) {
|
|
|
+ float max_abs_A = 0.0f;
|
|
|
+ for (size_t k = 0; k < CountK; k++) {
|
|
|
+ max_abs_A = std::max(max_abs_A, fabsf(A[k]));
|
|
|
+ }
|
|
|
+ float scale_A = max_abs_A * range_max_reciprocal;
|
|
|
+
|
|
|
+ ((float *) QuantA)[0] = scale_A;
|
|
|
+
|
|
|
+ auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
|
|
|
+
|
|
|
+ for (size_t k = 0; k < CountK; k++) {
|
|
|
+ QuantAData_offset[k] =
|
|
|
+ (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
|
|
|
+ (float) std::numeric_limits<int8_t>::max());
|
|
|
+ }
|
|
|
+ for (size_t k = CountK; k < BlkLen; k++) {
|
|
|
+ QuantAData_offset[k] = 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+    // zero the padding tail up front for block sizes whose asm paths below do not pad it themselves
+    if (BlkLen != 32 && BlkLen != 64 && BlkLen != 128) {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e8, m8 \n\t"
|
|
|
+ "vxor.vv v24, v24, v24 \n\t"
|
|
|
+ "LOOP%=: \n\t"
|
|
|
+ "vsetvli t0, %[CNT], e8, m8 \n\t"
|
|
|
+ "vse8.v v24, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 128 \n\t"
|
|
|
+ "sub %[CNT], %[CNT], t0 \n\t"
|
|
|
+ "bnez %[CNT], LOOP%= \n\t"
|
|
|
+ : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
|
|
|
+ :
|
|
|
+ : "cc", "t0");
|
|
|
+ }
|
|
|
+ if (BlkLen == 16) {
|
|
|
+ float buffer[64] = { 0.0f };
|
|
|
+ __asm__ volatile(
|
|
|
+ "addi t3, zero, 16*8 \n\t"
|
|
|
+ "addi t2, zero, 16 \n\t"
|
|
|
+ "blt %[K], t3, LOOP_K%= \n\t"
|
|
|
+ "blt %[K], t2, TAIL%= \n\t"
|
|
|
+ "LOOP_MAIN%=: \n\t"
|
|
|
+ "vsetvli t1, zero, e32, m2 \n\t"
|
|
|
+ "addi %[K], %[K], -128 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 64 \n\t"
|
|
|
+ "vle32.v v2, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 64 \n\t"
|
|
|
+ "vle32.v v4, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 64 \n\t"
|
|
|
+ "vle32.v v6, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 64 \n\t"
|
|
|
+ "vle32.v v8, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 64 \n\t"
|
|
|
+ "vle32.v v10, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 64 \n\t"
|
|
|
+ "vle32.v v12, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 64 \n\t"
|
|
|
+ "vle32.v v14, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 64 \n\t"
|
|
|
+ "addi a1, %[BUFFER], 0 \n\t"
|
|
|
+ "vfabs.v v16, v0 \n\t"
|
|
|
+ "vfabs.v v18, v2 \n\t"
|
|
|
+ "vfabs.v v20, v4 \n\t"
|
|
|
+ "vfabs.v v22, v6 \n\t"
|
|
|
+ "vfabs.v v24, v8 \n\t"
|
|
|
+ "vfabs.v v26, v10 \n\t"
|
|
|
+ "vfabs.v v28, v12 \n\t"
|
|
|
+ "vfabs.v v30, v14 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v17 \n\t"
|
|
|
+ "vfmax.vv v18, v18, v19 \n\t"
|
|
|
+ "vfmax.vv v20, v20, v21 \n\t"
|
|
|
+ "vfmax.vv v22, v22, v23 \n\t"
|
|
|
+ "vfmax.vv v24, v24, v25 \n\t"
|
|
|
+ "vfmax.vv v26, v26, v27 \n\t"
|
|
|
+ "vfmax.vv v28, v28, v29 \n\t"
|
|
|
+ "vfmax.vv v30, v30, v31 \n\t"
|
|
|
+ "vse32.v v16, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vse32.v v18, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vse32.v v20, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vse32.v v22, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vse32.v v24, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vse32.v v26, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vse32.v v28, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vse32.v v30, (a1) \n\t"
|
|
|
+ "addi a1, %[BUFFER], 0 \n\t"
|
|
|
+ "flw f0, (a1) \n\t"
|
|
|
+ "flw f1, 4(a1) \n\t"
|
|
|
+ "flw f2, 8(a1) \n\t"
|
|
|
+ "flw f3, 12(a1) \n\t"
|
|
|
+ "flw f4, 16(a1) \n\t"
|
|
|
+ "flw f5, 20(a1) \n\t"
|
|
|
+ "flw f6, 24(a1) \n\t"
|
|
|
+ "flw f7, 28(a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f10, f3, f7 \n\t"
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "fdiv.s f10, %[FONE], f10 \n\t"
|
|
|
+ "flw f0, (a1) \n\t"
|
|
|
+ "flw f1, 4(a1) \n\t"
|
|
|
+ "flw f2, 8(a1) \n\t"
|
|
|
+ "flw f3, 12(a1) \n\t"
|
|
|
+ "flw f4, 16(a1) \n\t"
|
|
|
+ "flw f5, 20(a1) \n\t"
|
|
|
+ "flw f6, 24(a1) \n\t"
|
|
|
+ "flw f7, 28(a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f11, f3, f7 \n\t"
|
|
|
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
|
|
|
+ "fsw f11, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f11 \n\t"
|
|
|
+ "flw f0, (a1) \n\t"
|
|
|
+ "flw f1, 4(a1) \n\t"
|
|
|
+ "flw f2, 8(a1) \n\t"
|
|
|
+ "flw f3, 12(a1) \n\t"
|
|
|
+ "flw f4, 16(a1) \n\t"
|
|
|
+ "flw f5, 20(a1) \n\t"
|
|
|
+ "flw f6, 24(a1) \n\t"
|
|
|
+ "flw f7, 28(a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f12, f3, f7 \n\t"
|
|
|
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
|
|
|
+ "fsw f12, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "fdiv.s f12, %[FONE], f12 \n\t"
|
|
|
+ "flw f0, (a1) \n\t"
|
|
|
+ "flw f1, 4(a1) \n\t"
|
|
|
+ "flw f2, 8(a1) \n\t"
|
|
|
+ "flw f3, 12(a1) \n\t"
|
|
|
+ "flw f4, 16(a1) \n\t"
|
|
|
+ "flw f5, 20(a1) \n\t"
|
|
|
+ "flw f6, 24(a1) \n\t"
|
|
|
+ "flw f7, 28(a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f13, f3, f7 \n\t"
|
|
|
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
|
|
|
+ "fsw f13, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "fdiv.s f13, %[FONE], f13 \n\t"
|
|
|
+ "flw f0, (a1) \n\t"
|
|
|
+ "flw f1, 4(a1) \n\t"
|
|
|
+ "flw f2, 8(a1) \n\t"
|
|
|
+ "flw f3, 12(a1) \n\t"
|
|
|
+ "flw f4, 16(a1) \n\t"
|
|
|
+ "flw f5, 20(a1) \n\t"
|
|
|
+ "flw f6, 24(a1) \n\t"
|
|
|
+ "flw f7, 28(a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f14, f3, f7 \n\t"
|
|
|
+ "fmul.s f14, f14, %[RMAXREC] \n\t"
|
|
|
+ "fsw f14, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "fdiv.s f14, %[FONE], f14 \n\t"
|
|
|
+ "flw f0, (a1) \n\t"
|
|
|
+ "flw f1, 4(a1) \n\t"
|
|
|
+ "flw f2, 8(a1) \n\t"
|
|
|
+ "flw f3, 12(a1) \n\t"
|
|
|
+ "flw f4, 16(a1) \n\t"
|
|
|
+ "flw f5, 20(a1) \n\t"
|
|
|
+ "flw f6, 24(a1) \n\t"
|
|
|
+ "flw f7, 28(a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f15, f3, f7 \n\t"
|
|
|
+ "fmul.s f15, f15, %[RMAXREC] \n\t"
|
|
|
+ "fsw f15, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "fdiv.s f15, %[FONE], f15 \n\t"
|
|
|
+ "flw f0, (a1) \n\t"
|
|
|
+ "flw f1, 4(a1) \n\t"
|
|
|
+ "flw f2, 8(a1) \n\t"
|
|
|
+ "flw f3, 12(a1) \n\t"
|
|
|
+ "flw f4, 16(a1) \n\t"
|
|
|
+ "flw f5, 20(a1) \n\t"
|
|
|
+ "flw f6, 24(a1) \n\t"
|
|
|
+ "flw f7, 28(a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f16, f3, f7 \n\t"
|
|
|
+ "fmul.s f16, f16, %[RMAXREC] \n\t"
|
|
|
+ "fsw f16, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "fdiv.s f16, %[FONE], f16 \n\t"
|
|
|
+ "flw f0, (a1) \n\t"
|
|
|
+ "flw f1, 4(a1) \n\t"
|
|
|
+ "flw f2, 8(a1) \n\t"
|
|
|
+ "flw f3, 12(a1) \n\t"
|
|
|
+ "flw f4, 16(a1) \n\t"
|
|
|
+ "flw f5, 20(a1) \n\t"
|
|
|
+ "flw f6, 24(a1) \n\t"
|
|
|
+ "flw f7, 28(a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f17, f3, f7 \n\t"
|
|
|
+ "fmul.s f17, f17, %[RMAXREC] \n\t"
|
|
|
+ "fsw f17, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], -136 \n\t"
|
|
|
+ "fdiv.s f17, %[FONE], f17 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m2 \n\t"
|
|
|
+ "vfmul.vf v16, v0, f10 \n\t"
|
|
|
+ "vfmul.vf v18, v2, f11 \n\t"
|
|
|
+ "vfmul.vf v20, v4, f12 \n\t"
|
|
|
+ "vfmul.vf v22, v6, f13 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f14 \n\t"
|
|
|
+ "vfmul.vf v26, v10, f15 \n\t"
|
|
|
+ "vfmul.vf v28, v12, f16 \n\t"
|
|
|
+ "vfmul.vf v30, v14, f17 \n\t"
|
|
|
+ "vfcvt.x.f.v v16, v16 \n\t"
|
|
|
+ "vfcvt.x.f.v v18, v18 \n\t"
|
|
|
+ "vfcvt.x.f.v v20, v20 \n\t"
|
|
|
+ "vfcvt.x.f.v v22, v22 \n\t"
|
|
|
+ "vfcvt.x.f.v v24, v24 \n\t"
|
|
|
+ "vfcvt.x.f.v v26, v26 \n\t"
|
|
|
+ "vfcvt.x.f.v v28, v28 \n\t"
|
|
|
+ "vfcvt.x.f.v v30, v30 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m1 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vnclip.wx v18, v18, zero \n\t"
|
|
|
+ "vnclip.wx v20, v20, zero \n\t"
|
|
|
+ "vnclip.wx v22, v22, zero \n\t"
|
|
|
+ "vnclip.wx v24, v24, zero \n\t"
|
|
|
+ "vnclip.wx v26, v26, zero \n\t"
|
|
|
+ "vnclip.wx v28, v28, zero \n\t"
|
|
|
+ "vnclip.wx v30, v30, zero \n\t"
|
|
|
+ "vsetvli t0, t1, e8, mf2 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vnclip.wx v18, v18, zero \n\t"
|
|
|
+ "vnclip.wx v20, v20, zero \n\t"
|
|
|
+ "vnclip.wx v22, v22, zero \n\t"
|
|
|
+ "vnclip.wx v24, v24, zero \n\t"
|
|
|
+ "vnclip.wx v26, v26, zero \n\t"
|
|
|
+ "vnclip.wx v28, v28, zero \n\t"
|
|
|
+ "vnclip.wx v30, v30, zero \n\t"
|
|
|
+ "vse8.v v16, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "vse8.v v18, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "vse8.v v20, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "vse8.v v22, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "vse8.v v24, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "vse8.v v26, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "vse8.v v28, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 20 \n\t"
|
|
|
+ "vse8.v v30, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 16 \n\t"
|
|
|
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
|
|
|
+ "blt %[K], t2, TAIL%= \n\t"
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vsetvli t1, %[K], e32, m2 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 64 \n\t"
|
|
|
+ "sub %[K], %[K], t1 \n\t"
|
|
|
+ "vfabs.v v16, v0 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v17 \n\t"
|
|
|
+ "vse32.v v16, (%[BUFFER]) \n\t"
|
|
|
+ "flw f0, (%[BUFFER]) \n\t"
|
|
|
+ "flw f1, 4(%[BUFFER]) \n\t"
|
|
|
+ "flw f2, 8(%[BUFFER]) \n\t"
|
|
|
+ "flw f3, 12(%[BUFFER]) \n\t"
|
|
|
+ "flw f4, 16(%[BUFFER]) \n\t"
|
|
|
+ "flw f5, 20(%[BUFFER]) \n\t"
|
|
|
+ "flw f6, 24(%[BUFFER]) \n\t"
|
|
|
+ "flw f7, 28(%[BUFFER]) \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f10, f3, f7 \n\t"
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 4 \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f10 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m2 \n\t"
|
|
|
+ "vfmul.vf v16, v0, f11 \n\t"
|
|
|
+ "vfcvt.x.f.v v16, v16 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m1 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vsetvli t0, t1, e8, mf2 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vse8.v v16, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 16 \n\t"
|
|
|
+ "bge %[K], t2, LOOP_K%= \n\t"
|
|
|
+ "TAIL%=: \n\t"
|
|
|
+ "blez %[K], END%= \n\t"
|
|
|
+ "vsetvli t0, t3, e32, m2 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "jal x0, LOOP_K%= \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+ : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
|
|
|
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer)
|
|
|
+ : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12",
|
|
|
+ "f13", "f14", "f15", "f16", "f17");
|
|
|
+ } else if (BlkLen == 32) {
|
|
|
+ __asm__ volatile(
|
|
|
+ "addi t3, zero, 32*4 \n\t"
|
|
|
+ "addi t2, zero, 32 \n\t"
|
|
|
+
|
|
|
+ "addi a1, %[SRC], 0 \n\t"
|
|
|
+ "addi a2, %[SRC], 128 \n\t"
|
|
|
+ "addi a3, %[SRC], 256 \n\t"
|
|
|
+ "addi a4, %[SRC], 384 \n\t"
|
|
|
+
|
|
|
+ "addi s1, %[DST], 0 \n\t"
|
|
|
+ "addi s2, %[DST], 36 \n\t"
|
|
|
+ "addi s3, %[DST], 72 \n\t"
|
|
|
+ "addi s4, %[DST], 108 \n\t"
|
|
|
+ "blt %[K], t3, LOOP_K%= \n\t"
|
|
|
+ "blt %[K], t2, TAIL%= \n\t"
|
|
|
+
|
|
|
+ "LOOP_MAIN%=: \n\t"
|
|
|
+ "vsetvli t1, zero, e32, m4 \n\t"
|
|
|
+ "addi %[K], %[K], -128 \n\t"
|
|
|
+ "vle32.v v0, (a1) \n\t"
|
|
|
+ "addi a1, a1, 512 \n\t"
|
|
|
+ "vle32.v v4, (a2) \n\t"
|
|
|
+ "addi a2, a2, 512 \n\t"
|
|
|
+ "vle32.v v8, (a3) \n\t"
|
|
|
+ "addi a3, a3, 512 \n\t"
|
|
|
+ "vle32.v v12, (a4) \n\t"
|
|
|
+ "addi a4, a4, 512 \n\t"
|
|
|
+ "vfabs.v v16, v0 \n\t"
|
|
|
+ "vfabs.v v20, v4 \n\t"
|
|
|
+ "vfabs.v v24, v8 \n\t"
|
|
|
+ "vfabs.v v28, v12 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m2 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v18 \n\t"
|
|
|
+ "vfmax.vv v20, v20, v22 \n\t"
|
|
|
+ "vfmax.vv v24, v24, v26 \n\t"
|
|
|
+ "vfmax.vv v28, v28, v30 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v17 \n\t"
|
|
|
+ "vfmax.vv v20, v20, v21 \n\t"
|
|
|
+ "vfmax.vv v24, v24, v25 \n\t"
|
|
|
+ "vfmax.vv v28, v28, v29 \n\t"
|
|
|
+
|
|
|
+ "vfredmax.vs v17, v16, v17 \n\t"
|
|
|
+ "vfredmax.vs v21, v20, v21 \n\t"
|
|
|
+ "vfredmax.vs v25, v24, v25 \n\t"
|
|
|
+ "vfredmax.vs v29, v28, v29 \n\t"
|
|
|
+ "vfmv.f.s f10, v17 \n\t"
|
|
|
+ "vfmv.f.s f11, v21 \n\t"
|
|
|
+ "vfmv.f.s f12, v25 \n\t"
|
|
|
+ "vfmv.f.s f13, v29 \n\t"
|
|
|
+
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
|
|
|
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
|
|
|
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (s1) \n\t"
|
|
|
+ "addi s1, s1, 4 \n\t"
|
|
|
+
|
|
|
+ "fsw f11, (s2) \n\t"
|
|
|
+ "addi s2, s2, 4 \n\t"
|
|
|
+ "fsw f12, (s3) \n\t"
|
|
|
+ "addi s3, s3, 4 \n\t"
|
|
|
+ "fsw f13, (s4) \n\t"
|
|
|
+ "addi s4, s4, 4 \n\t"
|
|
|
+ "fdiv.s f10, %[FONE], f10 \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f11 \n\t"
|
|
|
+ "fdiv.s f12, %[FONE], f12 \n\t"
|
|
|
+ "fdiv.s f13, %[FONE], f13 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vfmul.vf v16, v0, f10 \n\t"
|
|
|
+ "vfmul.vf v20, v4, f11 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f12 \n\t"
|
|
|
+ "vfmul.vf v28, v12, f13 \n\t"
|
|
|
+ "vfcvt.x.f.v v16, v16 \n\t"
|
|
|
+ "vfcvt.x.f.v v20, v20 \n\t"
|
|
|
+ "vfcvt.x.f.v v24, v24 \n\t"
|
|
|
+ "vfcvt.x.f.v v28, v28 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m2 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vnclip.wx v20, v20, zero \n\t"
|
|
|
+ "vnclip.wx v24, v24, zero \n\t"
|
|
|
+ "vnclip.wx v28, v28, zero \n\t"
|
|
|
+ "vsetvli t0, t1, e8, m1 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vnclip.wx v20, v20, zero \n\t"
|
|
|
+ "vnclip.wx v24, v24, zero \n\t"
|
|
|
+ "vnclip.wx v28, v28, zero \n\t"
|
|
|
+ "vse8.v v16, (s1) \n\t"
|
|
|
+ "addi s1, s1, 140 \n\t"
|
|
|
+ "vse8.v v20, (s2) \n\t"
|
|
|
+ "addi s2, s2, 140 \n\t"
|
|
|
+ "vse8.v v24, (s3) \n\t"
|
|
|
+ "addi s3, s3, 140 \n\t"
|
|
|
+ "vse8.v v28, (s4) \n\t"
|
|
|
+ "addi s4, s4, 140 \n\t"
|
|
|
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
|
|
|
+ "blt %[K], t2, TAIL%= \n\t"
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vsetvli t1, %[K], e32, m4 \n\t"
|
|
|
+ "vle32.v v0, (a1) \n\t"
|
|
|
+ "addi a1, a1, 128 \n\t"
|
|
|
+ "sub %[K], %[K], t1 \n\t"
|
|
|
+ "vfabs.v v16, v0 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m2 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v18 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v17 \n\t"
|
|
|
+ "vfredmax.vs v17, v16, v17 \n\t"
|
|
|
+ "vfmv.f.s f10, v17 \n\t"
|
|
|
+
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (s1) \n\t"
|
|
|
+ "addi s1, s1, 4 \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f10 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vfmul.vf v16, v0, f11 \n\t"
|
|
|
+ "vfcvt.x.f.v v16, v16 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m2 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vse8.v v16, (s1) \n\t"
|
|
|
+ "addi s1, s1, 32 \n\t"
|
|
|
+ "bge %[K], t2, LOOP_K%= \n\t"
|
|
|
+ "TAIL%=: \n\t"
|
|
|
+ "blez %[K], END%= \n\t"
|
|
|
+ "vsetvli t0, t3, e32, m4 \n\t"
|
|
|
+ "vxor.vv v0, v0, v0 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "jal x0, LOOP_K%= \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+ : [K] "+r"(CountK)
|
|
|
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
|
|
|
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
|
|
|
+ } else if (BlkLen == 64) {
|
|
|
+ __asm__ volatile(
|
|
|
+ "addi t3, zero, 64*2 \n\t"
|
|
|
+ "addi t2, zero, 64 \n\t"
|
|
|
+ "addi a1, %[SRC], 0 \n\t"
|
|
|
+ "addi a2, %[SRC], 256 \n\t"
|
|
|
+ "addi s1, %[DST], 0 \n\t"
|
|
|
+ "addi s2, %[DST], 68 \n\t"
|
|
|
+ "blt %[K], t3, LOOP_K%= \n\t"
|
|
|
+ "blt %[K], t2, TAIL%= \n\t"
|
|
|
+ "LOOP_MAIN%=: \n\t"
|
|
|
+ "vsetvli t1, zero, e32, m8 \n\t"
|
|
|
+ "addi %[K], %[K], -128 \n\t"
|
|
|
+ "vle32.v v0, (a1) \n\t"
|
|
|
+ "addi a1, a1, 512 \n\t"
|
|
|
+ "vle32.v v8, (a2) \n\t"
|
|
|
+ "addi a2, a2, 512 \n\t"
|
|
|
+ "vfabs.v v16, v0 \n\t"
|
|
|
+ "vfabs.v v24, v8 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v20 \n\t"
|
|
|
+ "vfmax.vv v24, v24, v28 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m2 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v18 \n\t"
|
|
|
+ "vfmax.vv v24, v24, v26 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v17 \n\t"
|
|
|
+ "vfmax.vv v24, v24, v25 \n\t"
|
|
|
+ "vfredmax.vs v17, v16, v17 \n\t"
|
|
|
+ "vfredmax.vs v25, v24, v25 \n\t"
|
|
|
+ "vfmv.f.s f10, v17 \n\t"
|
|
|
+ "vfmv.f.s f11, v25 \n\t"
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (s1) \n\t"
|
|
|
+ "addi s1, s1, 4 \n\t"
|
|
|
+ "fsw f11, (s2) \n\t"
|
|
|
+ "addi s2, s2, 4 \n\t"
|
|
|
+ "fdiv.s f10, %[FONE], f10 \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f11 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfmul.vf v16, v0, f10 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f11 \n\t"
|
|
|
+ "vfcvt.x.f.v v16, v16 \n\t"
|
|
|
+ "vfcvt.x.f.v v24, v24 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m4 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vnclip.wx v24, v24, zero \n\t"
|
|
|
+ "vsetvli t0, t1, e8, m2 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vnclip.wx v24, v24, zero \n\t"
|
|
|
+ "vse8.v v16, (s1) \n\t"
|
|
|
+ "addi s1, s1, 132 \n\t"
|
|
|
+ "vse8.v v24, (s2) \n\t"
|
|
|
+ "addi s2, s2, 132 \n\t"
|
|
|
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
|
|
|
+ "blt %[K], t2, TAIL%= \n\t"
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vsetvli t1, %[K], e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (a1) \n\t"
|
|
|
+ "addi a1, a1, 256 \n\t"
|
|
|
+ "sub %[K], %[K], t1 \n\t"
|
|
|
+ "vfabs.v v16, v0 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v20 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m2 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v18 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v17 \n\t"
|
|
|
+ "vfredmax.vs v17, v16, v17 \n\t"
|
|
|
+ "vfmv.f.s f10, v17 \n\t"
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (s1) \n\t"
|
|
|
+ "addi s1, s1, 4 \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f10 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfmul.vf v16, v0, f11 \n\t"
|
|
|
+ "vfcvt.x.f.v v16, v16 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m4 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m2 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vse8.v v16, (s1) \n\t"
|
|
|
+ "addi s1, s1, 64 \n\t"
|
|
|
+ "bge %[K], t2, LOOP_K%= \n\t"
|
|
|
+ "TAIL%=: \n\t"
|
|
|
+ "blez %[K], END%= \n\t"
|
|
|
+ "vsetvli t0, t3, e32, m8 \n\t"
|
|
|
+ "vxor.vv v0, v0, v0 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "jal x0, LOOP_K%= \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+ : [K] "+r"(CountK)
|
|
|
+ : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
|
|
|
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11");
|
|
|
+ } else if (BlkLen == 128) {
|
|
|
+ __asm__ volatile(
|
|
|
+ "addi t2, zero, 128 \n\t"
|
|
|
+ "addi a1, %[SRC], 0 \n\t"
|
|
|
+ "addi a2, %[SRC], 256 \n\t"
|
|
|
+ "blt %[K], t2, TAIL%= \n\t"
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vsetvli t1, zero, e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (a1) \n\t"
|
|
|
+ "addi a1, a1, 512 \n\t"
|
|
|
+ "vle32.v v8, (a2) \n\t"
|
|
|
+ "addi a2, a2, 512 \n\t"
|
|
|
+ "sub %[K], %[K], t2 \n\t"
|
|
|
+ "QUANT%=: \n\t"
|
|
|
+ "vfabs.v v16, v0 \n\t"
|
|
|
+ "vfabs.v v24, v8 \n\t"
|
|
|
+ "vfmax.vv v24, v16, v24 \n\t"
|
|
|
+ "vsetvli t1, zero, e32, m4 \n\t"
|
|
|
+ "vfmax.vv v28, v24, v28 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m2 \n\t"
|
|
|
+ "vfmax.vv v30, v28, v30 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vfmax.vv v30, v30, v31 \n\t"
|
|
|
+ "vfredmax.vs v31, v30, v31 \n\t"
|
|
|
+ "vfmv.f.s f10, v31 \n\t"
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 4 \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f10 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfmul.vf v16, v0, f11 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f11 \n\t"
|
|
|
+ "vfcvt.x.f.v v16, v16 \n\t"
|
|
|
+ "vfcvt.x.f.v v24, v24 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m4 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vnclip.wx v20, v24, zero \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m4 \n\t"
|
|
|
+ "vnclip.wx v16, v16, zero \n\t"
|
|
|
+ "vse8.v v16, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 128 \n\t"
|
|
|
+ "bge %[K], t2, LOOP_K%= \n\t"
|
|
|
+ "TAIL%=: \n\t"
|
|
|
+ "blez %[K], END%= \n\t"
|
|
|
+ "vsetvli t1, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v0, v0, v0 \n\t"
|
|
|
+ "vxor.vv v8, v8, v8 \n\t"
|
|
|
+ "vsetvli t0, %[K], e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (a1) \n\t"
|
|
|
+ "sub %[K], %[K], t0 \n\t"
|
|
|
+ "vsetvli t0, %[K], e32, m8 \n\t"
|
|
|
+ "vle32.v v8, (a2) \n\t"
|
|
|
+ "sub %[K], %[K], t0 \n\t"
|
|
|
+ "vsetvli t1, zero, e32, m8 \n\t"
|
|
|
+ "jal x0, QUANT%= \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+
|
|
|
+ : [DST] "+r"(DST), [K] "+r"(CountK)
|
|
|
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC)
|
|
|
+ : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11");
|
|
|
+ } else {
|
|
|
+ float buffer[8] = { 0.0f };
|
|
|
+ size_t cnt = BlkLen / 256;
|
|
|
+
|
|
|
+ __asm__ volatile(
|
|
|
+ "slli t3, %[BLK], 2 \n\t"
|
|
|
+ "blt %[K], %[BLK], LOOP_TAIL%= \n\t"
|
|
|
+ "LOOP_MAIN%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vxor.vv v31, v31, v31 \n\t"
|
|
|
+ "vse32.v v31, (%[BUFFER]) \n\t"
|
|
|
+ "addi t6, %[CNT], 0 \n\t"
|
|
|
+ "LOOP_CMP%=: \n\t"
|
|
|
+ "addi t6, t6, -1 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v8, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v16, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v24, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vfabs.v v0, v0 \n\t"
|
|
|
+ "vfabs.v v8, v8 \n\t"
|
|
|
+ "vfabs.v v16, v16 \n\t"
|
|
|
+ "vfabs.v v24, v24 \n\t"
|
|
|
+ "vfmax.vv v8, v0, v8 \n\t"
|
|
|
+ "vfmax.vv v16, v16, v24 \n\t"
|
|
|
+ "vfmax.vv v0, v0, v16 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vfmax.vv v0, v0, v4 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m2 \n\t"
|
|
|
+ "vfmax.vv v0, v0, v2 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vfmax.vv v0, v0, v1 \n\t"
|
|
|
+ "vle32.v v30, (%[BUFFER]) \n\t"
|
|
|
+ "vfmax.vv v31, v30, v0 \n\t"
|
|
|
+ "vse32.v v31, (%[BUFFER]) \n\t"
|
|
|
+ "bnez t6, LOOP_CMP%= \n\t"
|
|
|
+ "sub %[SRC], %[SRC], t3 \n\t"
|
|
|
+ "addi t6, %[CNT], 0 \n\t"
|
|
|
+ "flw f0, (%[BUFFER]) \n\t"
|
|
|
+ "flw f1, 4(%[BUFFER]) \n\t"
|
|
|
+ "flw f2, 8(%[BUFFER]) \n\t"
|
|
|
+ "flw f3, 12(%[BUFFER]) \n\t"
|
|
|
+ "flw f4, 16(%[BUFFER]) \n\t"
|
|
|
+ "flw f5, 20(%[BUFFER]) \n\t"
|
|
|
+ "flw f6, 24(%[BUFFER]) \n\t"
|
|
|
+ "flw f7, 28(%[BUFFER]) \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f10, f3, f7 \n\t"
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 4 \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f10 \n\t"
|
|
|
+ "addi t6, %[CNT], 0 \n\t"
|
|
|
+ "LOOP_QUANT%=: \n\t"
|
|
|
+ "addi t6, t6, -1 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v8, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v16, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vle32.v v24, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfmul.vf v0, v0, f11 \n\t"
|
|
|
+ "vfmul.vf v8, v8, f11 \n\t"
|
|
|
+ "vfmul.vf v16, v16, f11 \n\t"
|
|
|
+ "vfmul.vf v24, v24, f11 \n\t"
|
|
|
+ "vfcvt.x.f.v v0, v0 \n\t"
|
|
|
+ "vfcvt.x.f.v v8, v8 \n\t"
|
|
|
+ "vfcvt.x.f.v v16, v16 \n\t"
|
|
|
+ "vfcvt.x.f.v v24, v24 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m4 \n\t"
|
|
|
+ "vnclip.wx v0, v0, zero \n\t"
|
|
|
+ "vnclip.wx v4, v8, zero \n\t"
|
|
|
+ "vnclip.wx v8, v16, zero \n\t"
|
|
|
+ "vnclip.wx v12, v24, zero \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m4 \n\t"
|
|
|
+ "vnclip.wx v0, v0, zero \n\t"
|
|
|
+ "vnclip.wx v4, v8, zero \n\t"
|
|
|
+ "vse8.v v0, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 128 \n\t"
|
|
|
+ "vse8.v v4, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 128 \n\t"
|
|
|
+ "bnez t6, LOOP_QUANT%= \n\t"
|
|
|
+ "sub %[K], %[K], %[BLK] \n\t"
|
|
|
+ "bge %[K], %[BLK], LOOP_MAIN%= \n\t"
|
|
|
+ "blez %[K], END%= \n\t"
|
|
|
+ "LOOP_TAIL%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vxor.vv v31, v31, v31 \n\t"
|
|
|
+ "vse32.v v31, (%[BUFFER]) \n\t"
|
|
|
+ "addi t6, %[K], 0 \n\t"
|
|
|
+ "addi s1, %[SRC], 0 \n\t"
|
|
|
+ "TAIL_CMP%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v0, v0, v0 \n\t"
|
|
|
+ "vsetvli t0, t6, e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi %[SRC], %[SRC], 256 \n\t"
|
|
|
+ "sub t6, t6, t0 \n\t"
|
|
|
+ "vfabs.v v0, v0 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vfmax.vv v0, v0, v4 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m2 \n\t"
|
|
|
+ "vfmax.vv v0, v0, v2 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t"
|
|
|
+ "vfmax.vv v0, v0, v1 \n\t"
|
|
|
+ "vle32.v v30, (%[BUFFER]) \n\t"
|
|
|
+ "vfmax.vv v31, v30, v0 \n\t"
|
|
|
+ "vse32.v v31, (%[BUFFER]) \n\t"
|
|
|
+ "bnez t6, TAIL_CMP%= \n\t"
|
|
|
+ "addi t6, %[K], 0 \n\t"
|
|
|
+ "flw f0, (%[BUFFER]) \n\t"
|
|
|
+ "flw f1, 4(%[BUFFER]) \n\t"
|
|
|
+ "flw f2, 8(%[BUFFER]) \n\t"
|
|
|
+ "flw f3, 12(%[BUFFER]) \n\t"
|
|
|
+ "flw f4, 16(%[BUFFER]) \n\t"
|
|
|
+ "flw f5, 20(%[BUFFER]) \n\t"
|
|
|
+ "flw f6, 24(%[BUFFER]) \n\t"
|
|
|
+ "flw f7, 28(%[BUFFER]) \n\t"
|
|
|
+ "fmax.s f1, f0, f1 \n\t"
|
|
|
+ "fmax.s f3, f2, f3 \n\t"
|
|
|
+ "fmax.s f5, f4, f5 \n\t"
|
|
|
+ "fmax.s f7, f6, f7 \n\t"
|
|
|
+ "fmax.s f3, f1, f3 \n\t"
|
|
|
+ "fmax.s f7, f5, f7 \n\t"
|
|
|
+ "fmax.s f10, f3, f7 \n\t"
|
|
|
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
|
|
|
+ "fsw f10, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 4 \n\t"
|
|
|
+ "fdiv.s f11, %[FONE], f10 \n\t"
|
|
|
+ "addi t6, %[K], 0 \n\t"
|
|
|
+ "TAIL_QUANT%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v0, v0, v0 \n\t"
|
|
|
+ "vsetvli t1, t6, e32, m8 \n\t"
|
|
|
+ "vle32.v v0, (s1) \n\t"
|
|
|
+ "addi s1, s1, 256 \n\t"
|
|
|
+ "sub t6, t6, t1 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfmul.vf v0, v0, f11 \n\t"
|
|
|
+ "vfcvt.x.f.v v0, v0 \n\t"
|
|
|
+ "vsetvli t0, zero, e16, m4 \n\t"
|
|
|
+ "vnclip.wx v0, v0, zero \n\t"
|
|
|
+ "vsetvli t0, t1, e8, m2 \n\t"
|
|
|
+ "vnclip.wx v0, v0, zero \n\t"
|
|
|
+ "vse8.v v0, (%[DST]) \n\t"
|
|
|
+ "addi %[DST], %[DST], 64 \n\t"
|
|
|
+ "bnez t6, TAIL_QUANT%= \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+ : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
|
|
|
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer),
|
|
|
+ [CNT] "r"(cnt)
|
|
|
+ : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6");
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+} // namespace ime1
|
|
|
+
|
|
|
+namespace {
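+// The macros below assume vmadot, SpacemiT's IME matrix multiply-accumulate, which
+// accumulates int8 dot products into int32 lanes; the COMP macros feed it packed A
+// bytes against B nibbles unpacked by the LOAD macros.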
|
|
|
+#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \
|
|
|
+ "vmadot v16, v14, v0 \n\t" \
|
|
|
+ "vmadot v18, v14, v1 \n\t" \
|
|
|
+ "vmadot v20, v14, v2 \n\t" \
|
|
|
+ "vmadot v22, v14, v3 \n\t" \
|
|
|
+ "vmadot v16, v15, v4 \n\t" \
|
|
|
+ "vmadot v18, v15, v5 \n\t" \
|
|
|
+ "vmadot v20, v15, v6 \n\t" \
|
|
|
+ "vmadot v22, v15, v7 \n\t"
|
|
|
+
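+// The ACC macros convert the int32 accumulators to float and fold in the per-block
+// scale products from v24..v27, accumulating the running result in v28..v31; the F16
+// variant only differs in the scale-pointer strides (s2..s4).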
|
|
|
+#define SQ4BIT_KERNEL_ACC_1X4X4 \
|
|
|
+ "vfcvt.f.x.v v16, v16 \n\t" \
|
|
|
+ "vfcvt.f.x.v v18, v18 \n\t" \
|
|
|
+ "vfcvt.f.x.v v20, v20 \n\t" \
|
|
|
+ "vfcvt.f.x.v v22, v22 \n\t" \
|
|
|
+ "addi s2, s1, 16 \n\t" \
|
|
|
+ "addi s3, s1, 32 \n\t" \
|
|
|
+ "addi s4, s1, 48 \n\t" \
|
|
|
+ "addi s6, s5, 12 \n\t" \
|
|
|
+ "vfmacc.vv v28, v16, v24 \n\t" \
|
|
|
+ "vfmacc.vv v29, v18, v25 \n\t" \
|
|
|
+ "vfmacc.vv v30, v20, v26 \n\t" \
|
|
|
+ "vfmacc.vv v31, v22, v27 \n\t"
|
|
|
+
|
|
|
+#define SQ4BIT_KERNEL_ACC_F16_1X4X4 \
|
|
|
+ "vfcvt.f.x.v v16, v16 \n\t" \
|
|
|
+ "vfcvt.f.x.v v18, v18 \n\t" \
|
|
|
+ "vfcvt.f.x.v v20, v20 \n\t" \
|
|
|
+ "vfcvt.f.x.v v22, v22 \n\t" \
|
|
|
+ "addi s2, s1, 8 \n\t" \
|
|
|
+ "addi s3, s1, 16 \n\t" \
|
|
|
+ "addi s4, s1, 24 \n\t" \
|
|
|
+ "addi s6, s5, 12 \n\t" \
|
|
|
+ "vfmacc.vv v28, v16, v24 \n\t" \
|
|
|
+ "vfmacc.vv v29, v18, v25 \n\t" \
|
|
|
+ "vfmacc.vv v30, v20, v26 \n\t" \
|
|
|
+ "vfmacc.vv v31, v22, v27 \n\t"
|
|
|
+
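+// Loads 128 bytes of packed 4-bit B data, splits every byte into its low and high
+// nibbles (vand.vi 15 / vsrl.vi 4), and loads the matching int8 A fragments into
+// v14/v15; t5 counts the remaining inner iterations.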
|
|
|
+#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \
|
|
|
+ "vle8.v v4, (s1) \n\t" \
|
|
|
+ "addi s1, s1, 128 \n\t" \
|
|
|
+ "vle8.v v5, (s2) \n\t" \
|
|
|
+ "addi s2, s2, 128 \n\t" \
|
|
|
+ "vle8.v v6, (s3) \n\t" \
|
|
|
+ "addi s3, s3, 128 \n\t" \
|
|
|
+ "vle8.v v7, (s4) \n\t" \
|
|
|
+ "addi s4, s4, 128 \n\t" \
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t" \
|
|
|
+ "vle8.v v14, (s5) \n\t" \
|
|
|
+ "addi s5, s5, 16 \n\t" \
|
|
|
+ "vle8.v v15, (s6) \n\t" \
|
|
|
+ "addi s6, s6, 16 \n\t" \
|
|
|
+ "addi t5, t5, -1 \n\t" \
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t" \
|
|
|
+ "vand.vi v0, v4, 15 \n\t" \
|
|
|
+ "vand.vi v1, v5, 15 \n\t" \
|
|
|
+ "vand.vi v2, v6, 15 \n\t" \
|
|
|
+ "vand.vi v3, v7, 15 \n\t" \
|
|
|
+ "vsrl.vi v4, v4, 4 \n\t" \
|
|
|
+ "vsrl.vi v5, v5, 4 \n\t" \
|
|
|
+ "vsrl.vi v6, v6, 4 \n\t" \
|
|
|
+ "vsrl.vi v7, v7, 4 \n\t"
|
|
|
+
|
|
|
+#define SQ4BIT_KERNEL_LOAD_ZP_16X1 \
|
|
|
+ "vsetvli t0, zero, e8, mf2 \n\t" \
|
|
|
+ "vle8.v v1, (s7) \n\t" \
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t" \
|
|
|
+ "vrgather.vv v8, v1, v13 \n\t" \
|
|
|
+ "vadd.vi v13, v13, 4 \n\t" \
|
|
|
+ "vrgather.vv v9, v1, v13 \n\t" \
|
|
|
+ "vadd.vi v13, v13, 4 \n\t" \
|
|
|
+ "vrgather.vv v10, v1, v13 \n\t" \
|
|
|
+ "vadd.vi v13, v13, 4 \n\t" \
|
|
|
+ "vrgather.vv v11, v1, v13 \n\t" \
|
|
|
+ "vadd.vi v13, v13, -12 \n\t"
|
|
|
+
|
|
|
+// used by the M=4 kernel
|
|
|
+#define LOAD_B_16x8x2 \
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t" \
|
|
|
+ "vle8.v v6, (s1) \n\t" \
|
|
|
+ "addi s1, s1, 32*4 \n\t" \
|
|
|
+ "vle8.v v7, (s2) \n\t" \
|
|
|
+ "addi s2, s2, 32*4 \n\t" \
|
|
|
+ "vle8.v v8, (s3) \n\t" \
|
|
|
+ "addi s3, s3, 32*4 \n\t" \
|
|
|
+ "vle8.v v9, (s4) \n\t" \
|
|
|
+ "addi s4, s4, 32*4 \n\t" \
|
|
|
+ \
|
|
|
+ "vand.vi v2, v6, 15 \n\t" \
|
|
|
+ "vand.vi v3, v7, 15 \n\t" \
|
|
|
+ "vand.vi v4, v8, 15 \n\t" \
|
|
|
+ "vand.vi v5, v9, 15 \n\t" \
|
|
|
+ \
|
|
|
+ "vsrl.vi v6, v6, 4 \n\t" \
|
|
|
+ "vsrl.vi v7, v7, 4 \n\t" \
|
|
|
+ "vsrl.vi v8, v8, 4 \n\t" \
|
|
|
+ "vsrl.vi v9, v9, 4 \n\t"
|
|
|
+
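+// The LOAD_SCALE_4x16* macros combine the 16 per-column B scales of the current block
+// (fp16 or fp32) with the four per-row A scales held in f1..f4, leaving the products
+// in v8..v15.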
|
|
|
+// [s2|s5, s3, s4, s6]
|
|
|
+#define LOAD_SCALE_4x16_FP16 \
|
|
|
+ "addi s2, s5, -8 \n\t" \
|
|
|
+ "addi s3, s5, 8 \n\t" \
|
|
|
+ "addi s4, s5, 16 \n\t" \
|
|
|
+ "addi s6, s5, 24 \n\t" \
|
|
|
+ "li t1, 0xf0 \n\t" \
|
|
|
+ "vmv.s.x v0, t1 \n\t" \
|
|
|
+ "vsetvli t0, zero, e16, mf4 \n\t" \
|
|
|
+ "vle16.v v9, (s5) \n\t" \
|
|
|
+ "vle16.v v11, (s3) \n\t" \
|
|
|
+ "vle16.v v13, (s4) \n\t" \
|
|
|
+ "vle16.v v15, (s6) \n\t" \
|
|
|
+ "vsetvli t0, zero, e16, mf2 \n\t" \
|
|
|
+ "vle16.v v9, (s2), v0.t \n\t" \
|
|
|
+ "vle16.v v11, (s5), v0.t \n\t" \
|
|
|
+ "vle16.v v13, (s3), v0.t \n\t" \
|
|
|
+ "vle16.v v15, (s4), v0.t \n\t" \
|
|
|
+ "vfwcvt.f.f.v v8, v9 \n\t" \
|
|
|
+ "vfwcvt.f.f.v v10, v11 \n\t" \
|
|
|
+ "vfwcvt.f.f.v v12, v13 \n\t" \
|
|
|
+ "vfwcvt.f.f.v v14, v15 \n\t" \
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t" \
|
|
|
+ "vmv.v.v v9, v8 \n\t" \
|
|
|
+ "vmv.v.v v11, v10 \n\t" \
|
|
|
+ "vmv.v.v v13, v12 \n\t" \
|
|
|
+ "vmv.v.v v15, v14 \n\t" \
|
|
|
+ "li t1, 0xf0 \n\t" \
|
|
|
+ "vmv.s.x v0, t1 \n\t" \
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t" \
|
|
|
+ "vfmul.vf v8, v8, f1 \n\t" \
|
|
|
+ "vfmul.vf v10, v10, f1 \n\t" \
|
|
|
+ "vfmul.vf v12, v12, f1 \n\t" \
|
|
|
+ "vfmul.vf v14, v14, f1 \n\t" \
|
|
|
+ "vfmul.vf v9, v9, f3 \n\t" \
|
|
|
+ "vfmul.vf v11, v11, f3 \n\t" \
|
|
|
+ "vfmul.vf v13, v13, f3 \n\t" \
|
|
|
+ "vfmul.vf v15, v15, f3 \n\t" \
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t" \
|
|
|
+ "vfmul.vf v8, v8, f2, v0.t \n\t" \
|
|
|
+ "vfmul.vf v10, v10, f2, v0.t \n\t" \
|
|
|
+ "vfmul.vf v12, v12, f2, v0.t \n\t" \
|
|
|
+ "vfmul.vf v14, v14, f2, v0.t \n\t" \
|
|
|
+ "vfmul.vf v9, v9, f4, v0.t \n\t" \
|
|
|
+ "vfmul.vf v11, v11, f4, v0.t \n\t" \
|
|
|
+ "vfmul.vf v13, v13, f4, v0.t \n\t" \
|
|
|
+ "vfmul.vf v15, v15, f4, v0.t \n\t"
|
|
|
+
|
|
|
+// [s2|s5, s3, s4, s6]
|
|
|
+#define LOAD_SCALE_4x16 \
|
|
|
+ "addi s2, s5, -16 \n\t" \
|
|
|
+ "addi s3, s5, 16 \n\t" \
|
|
|
+ "addi s4, s5, 32 \n\t" \
|
|
|
+ "addi s6, s5, 48 \n\t" \
|
|
|
+ "li t1, 0xf0 \n\t" \
|
|
|
+ "vmv.s.x v0, t1 \n\t" \
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t" \
|
|
|
+ "vle32.v v8, (s5) \n\t" \
|
|
|
+ "vle32.v v10, (s3) \n\t" \
|
|
|
+ "vle32.v v12, (s4) \n\t" \
|
|
|
+ "vle32.v v14, (s6) \n\t" \
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t" \
|
|
|
+ "vle32.v v8, (s2), v0.t \n\t" \
|
|
|
+ "vle32.v v10, (s5), v0.t \n\t" \
|
|
|
+ "vle32.v v12, (s3), v0.t \n\t" \
|
|
|
+ "vle32.v v14, (s4), v0.t \n\t" \
|
|
|
+ "vmv.v.v v9, v8 \n\t" \
|
|
|
+ "vmv.v.v v11, v10 \n\t" \
|
|
|
+ "vmv.v.v v13, v12 \n\t" \
|
|
|
+ "vmv.v.v v15, v14 \n\t" \
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t" \
|
|
|
+ "vfmul.vf v8, v8, f1 \n\t" \
|
|
|
+ "vfmul.vf v10, v10, f1 \n\t" \
|
|
|
+ "vfmul.vf v12, v12, f1 \n\t" \
|
|
|
+ "vfmul.vf v14, v14, f1 \n\t" \
|
|
|
+ "vfmul.vf v9, v9, f3 \n\t" \
|
|
|
+ "vfmul.vf v11, v11, f3 \n\t" \
|
|
|
+ "vfmul.vf v13, v13, f3 \n\t" \
|
|
|
+ "vfmul.vf v15, v15, f3 \n\t" \
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t" \
|
|
|
+ "vfmul.vf v8, v8, f2, v0.t \n\t" \
|
|
|
+ "vfmul.vf v10, v10, f2, v0.t \n\t" \
|
|
|
+ "vfmul.vf v12, v12, f2, v0.t \n\t" \
|
|
|
+ "vfmul.vf v14, v14, f2, v0.t \n\t" \
|
|
|
+ "vfmul.vf v9, v9, f4, v0.t \n\t" \
|
|
|
+ "vfmul.vf v11, v11, f4, v0.t \n\t" \
|
|
|
+ "vfmul.vf v13, v13, f4, v0.t \n\t" \
|
|
|
+ "vfmul.vf v15, v15, f4, v0.t \n\t"
|
|
|
+
|
|
|
+// [s1 | BIAS, s2, s3, s4]
|
|
|
+#define LOAD_BIAS \
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t" \
|
|
|
+ "li t1, 0xf0 \n\t" \
|
|
|
+ "vmv.s.x v0, t1 \n\t" \
|
|
|
+ "addi s1, %[BIAS], -16 \n\t" \
|
|
|
+ "addi s2, %[BIAS], 16 \n\t" \
|
|
|
+ "addi s3, %[BIAS], 32 \n\t" \
|
|
|
+ "addi s4, %[BIAS], 48 \n\t" \
|
|
|
+ \
|
|
|
+ "vle32.v v24, (%[BIAS]) \n\t" \
|
|
|
+ "vle32.v v26, (s2) \n\t" \
|
|
|
+ "vle32.v v28, (s3) \n\t" \
|
|
|
+ "vle32.v v30, (s4) \n\t" \
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t" \
|
|
|
+ "vle32.v v24, (s1), v0.t \n\t" \
|
|
|
+ "vle32.v v26, (%[BIAS]), v0.t \n\t" \
|
|
|
+ "vle32.v v28, (s2), v0.t \n\t" \
|
|
|
+ "vle32.v v30, (s3), v0.t \n\t" \
|
|
|
+ "vmv.v.v v25, v24 \n\t" \
|
|
|
+ "vmv.v.v v27, v26 \n\t" \
|
|
|
+ "vmv.v.v v29, v28 \n\t" \
|
|
|
+ "vmv.v.v v31, v30 \n\t"
|
|
|
+
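+// SQ4BIT_KERNEL_COMP_4x16x16: one vmadot pass over the current 16 columns, accumulating
+// the 4x16 int32 tile in v16/v18/v20/v22 from the A bytes in v10/v11 and the unpacked
+// B nibbles in v2..v9.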
|
|
|
+#define SQ4BIT_KERNEL_COMP_4x16x16 \
|
|
|
+ "vmadot v16, v10, v2 \n\t" \
|
|
|
+ "vmadot v18, v10, v3 \n\t" \
|
|
|
+ "vmadot v20, v10, v4 \n\t" \
|
|
|
+ "vmadot v22, v10, v5 \n\t" \
|
|
|
+ "vmadot v16, v11, v6 \n\t" \
|
|
|
+ "vmadot v18, v11, v7 \n\t" \
|
|
|
+ "vmadot v20, v11, v8 \n\t" \
|
|
|
+ "vmadot v22, v11, v9 \n\t"
|
|
|
+
|
|
|
+#define SAVE_RESULT_4x16 \
|
|
|
+ "addi a1, %[C], 0 \n\t" \
|
|
|
+ "add a2, %[C], %[LDC] \n\t" \
|
|
|
+ "add a3, a2, %[LDC] \n\t" \
|
|
|
+ "add a4, a3, %[LDC] \n\t" \
|
|
|
+ "addi a2, a2, -16 \n\t" \
|
|
|
+ "addi a4, a4, -16 \n\t" \
|
|
|
+ "li t1, 0xf0 \n\t" \
|
|
|
+ "vmv.s.x v0, t1 \n\t" \
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t" \
|
|
|
+ \
|
|
|
+ "vse32.v v24, (a1) \n\t" \
|
|
|
+ "addi a1, a1, 16 \n\t" \
|
|
|
+ "vse32.v v25, (a3) \n\t" \
|
|
|
+ "addi a3, a3, 16 \n\t" \
|
|
|
+ \
|
|
|
+ "vse32.v v26, (a1) \n\t" \
|
|
|
+ "addi a1, a1, 16 \n\t" \
|
|
|
+ "vse32.v v27, (a3) \n\t" \
|
|
|
+ "addi a3, a3, 16 \n\t" \
|
|
|
+ \
|
|
|
+ "vse32.v v28, (a1) \n\t" \
|
|
|
+ "addi a1, a1, 16 \n\t" \
|
|
|
+ "vse32.v v29, (a3) \n\t" \
|
|
|
+ "addi a3, a3, 16 \n\t" \
|
|
|
+ \
|
|
|
+ "vse32.v v30, (a1) \n\t" \
|
|
|
+ "vse32.v v31, (a3) \n\t" \
|
|
|
+ "vsetvli t0, zero, e32, m1 \n\t" \
|
|
|
+ \
|
|
|
+ "vse32.v v24, (a2), v0.t \n\t" \
|
|
|
+ "addi a2, a2, 16 \n\t" \
|
|
|
+ "vse32.v v25, (a4), v0.t \n\t" \
|
|
|
+ "addi a4, a4, 16 \n\t" \
|
|
|
+ \
|
|
|
+ "vse32.v v26, (a2), v0.t \n\t" \
|
|
|
+ "addi a2, a2, 16 \n\t" \
|
|
|
+ "vse32.v v27, (a4), v0.t \n\t" \
|
|
|
+ "addi a4, a4, 16 \n\t" \
|
|
|
+ \
|
|
|
+ "vse32.v v28, (a2), v0.t \n\t" \
|
|
|
+ "addi a2, a2, 16 \n\t" \
|
|
|
+ "vse32.v v29, (a4), v0.t \n\t" \
|
|
|
+ "addi a4, a4, 16 \n\t" \
|
|
|
+ \
|
|
|
+ "vse32.v v30, (a2), v0.t \n\t" \
|
|
|
+ "vse32.v v31, (a4), v0.t \n\t"
|
|
|
+
|
|
|
+#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \
|
|
|
+ "vsetvli t0, zero, e8, mf2 \n\t" \
|
|
|
+ "vle8.v v11, (s6) \n\t" \
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t" \
|
|
|
+ "vrgather.vv v12, v11, v1 \n\t" \
|
|
|
+ "vadd.vi v1, v1, 4 \n\t" \
|
|
|
+ "vrgather.vv v13, v11, v1 \n\t" \
|
|
|
+ "vadd.vi v1, v1, 4 \n\t" \
|
|
|
+ "vrgather.vv v14, v11, v1 \n\t" \
|
|
|
+ "vadd.vi v1, v1, 4 \n\t" \
|
|
|
+ "vrgather.vv v15, v11, v1 \n\t" \
|
|
|
+ "vadd.vi v1, v1, -12 \n\t"
|
|
|
+
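+// Computes a 4 x 16 tile of C for 4-bit B with fp16 block scales (and optional zero
+// points). Per 16-column block, B is packed as 16 fp16 scales, then 16 zero-point
+// bytes when HasZeroPoint, then the 4-bit data; A supplies 4 row scales per block
+// followed by the int8 data produced by quantize_a_4row_i8.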
|
|
|
+template <bool HasZeroPoint>
|
|
|
+void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
|
|
|
+ const std::byte * QuantA,
|
|
|
+ const std::byte * QuantBData,
|
|
|
+ const float * QuantBScale,
|
|
|
+ const std::byte * QuantBZeroPoint,
|
|
|
+ float * C,
|
|
|
+ size_t CountN,
|
|
|
+ size_t BlockCountK,
|
|
|
+ const float * Bias,
|
|
|
+ const size_t ldc) {
|
|
|
+ GGML_UNUSED(QuantBScale);
|
|
|
+ GGML_UNUSED(QuantBZeroPoint);
|
|
|
+ size_t LDC = ldc * sizeof(float);
|
|
|
+ const size_t INNER = BlkLen / 16;
|
|
|
+ float tmp[4 * 16];
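+    // When fewer than 16 columns remain, the tile is computed into the local tmp
+    // buffer with a full 16-column stride instead of writing past the end of C.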
|
|
|
+
|
|
|
+ if constexpr (HasZeroPoint) {
|
|
|
+ for (size_t n = 0; n < CountN; n += 16) {
|
|
|
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
|
|
|
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
|
|
|
+ n * BlockCountK * BlkLen / 2 + // b data
|
|
|
+ n * BlockCountK * sizeof(uint8_t) + // zp
|
|
|
+ n * BlockCountK * sizeof(_Float16); // scale
|
|
|
+ float * CPtr = C + n;
|
|
|
+ if (NBLKS < 16) {
|
|
|
+ CPtr = tmp;
|
|
|
+ LDC = 16 * sizeof(float);
|
|
|
+ }
|
|
|
+ if (Bias != nullptr) {
|
|
|
+ const float * bias = Bias + n;
|
|
|
+ if (NBLKS < 16) {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, %[N], e32, m2 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "vse32.v v0, (%[DST]) \n\t"
|
|
|
+ :
|
|
|
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
|
|
|
+ : "cc", "t0");
|
|
|
+ bias = tmp;
|
|
|
+ }
|
|
|
+ __asm__ volatile(LOAD_BIAS
|
|
|
+
|
|
|
+ "addi t3, %[BlockCountK], 0 \n\t"
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "li s1, 24 \n\t"
|
|
|
+ "vmv.v.i v1, 3 \n\t"
|
|
|
+ "vsetvli t0, s1, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v1, 2 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf2 \n\t"
|
|
|
+ "vmv.v.i v1, 1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t"
|
|
|
+ "vmv.v.i v1, 0 \n\t"
|
|
|
+
|
|
|
+ "addi a1, %[A], 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+
|
|
|
+ "BLOCK_COUNTK_LOOP%=: \n\t"
|
|
|
+ // scale offset
|
|
|
+ "addi s5, s1, 0 \n\t"
|
|
|
+ // zp offset
|
|
|
+ "addi s6, s1, 32 \n\t"
|
|
|
+ "addi s1, s6, 16 \n\t"
|
|
|
+ "addi s2, s1, 32 \n\t"
|
|
|
+ "addi s3, s1, 32*2 \n\t"
|
|
|
+ "addi s4, s1, 32*3 \n\t"
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ // load a scale
|
|
|
+ "flw f1, (a1) \n\t"
|
|
|
+ "flw f2, 4(a1) \n\t"
|
|
|
+ "flw f3, 8(a1) \n\t"
|
|
|
+ "flw f4, 12(a1) \n\t"
|
|
|
+ "addi a1, a1, 16 \n\t"
|
|
|
+ "addi t2, %[INNER], 0 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
|
|
|
+
|
|
|
+ "BLOCK_INNER_LOOP%=: \n\t"
|
|
|
+
|
|
|
+ LOAD_B_16x8x2
|
|
|
+
|
|
|
+ "vle8.v v10, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vle8.v v11, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vsub.vv v2, v2, v12 \n\t"
|
|
|
+ "vsub.vv v6, v6, v12 \n\t"
|
|
|
+ "vsub.vv v3, v3, v13 \n\t"
|
|
|
+ "vsub.vv v7, v7, v13 \n\t"
|
|
|
+ "vsub.vv v4, v4, v14 \n\t"
|
|
|
+ "vsub.vv v8, v8, v14 \n\t"
|
|
|
+ "vsub.vv v5, v5, v15 \n\t"
|
|
|
+ "vsub.vv v9, v9, v15 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_4x16x16
|
|
|
+
|
|
|
+ "addi t2, t2, -1 \n\t"
|
|
|
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
|
|
|
+
|
|
|
+ LOAD_SCALE_4x16_FP16
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfcvt.f.x.v v16, v16 \n\t"
|
|
|
+ "vfmacc.vv v24, v16, v8 \n\t"
|
|
|
+ "addi t3, t3, -1 \n\t"
|
|
|
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
|
|
|
+
|
|
|
+ "RESULT_SAVE%=: \n\t"
|
|
|
+
|
|
|
+ SAVE_RESULT_4x16
|
|
|
+
|
|
|
+ :
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
|
|
|
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
|
|
|
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
|
|
|
+ "s2", "s3", "s4", "s5", "s6");
|
|
|
+
|
|
|
+ } else {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v24, v24, v24 \n\t"
|
|
|
+ "addi t3, %[BlockCountK], 0 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "li s1, 24 \n\t"
|
|
|
+ "vmv.v.i v1, 3 \n\t"
|
|
|
+ "vsetvli t0, s1, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v1, 2 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf2 \n\t"
|
|
|
+ "vmv.v.i v1, 1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t"
|
|
|
+ "vmv.v.i v1, 0 \n\t"
|
|
|
+ "addi a1, %[A], 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "BLOCK_COUNTK_LOOP%=: \n\t"
|
|
|
+ // scale offset
|
|
|
+ "addi s5, s1, 0 \n\t"
|
|
|
+ // zp offset
|
|
|
+ "addi s6, s1, 32 \n\t"
|
|
|
+ "addi s1, s6, 16 \n\t"
|
|
|
+ "addi s2, s1, 32 \n\t"
|
|
|
+ "addi s3, s1, 32*2 \n\t"
|
|
|
+ "addi s4, s1, 32*3 \n\t"
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ // load a scale
|
|
|
+ "flw f1, (a1) \n\t"
|
|
|
+ "flw f2, 4(a1) \n\t"
|
|
|
+ "flw f3, 8(a1) \n\t"
|
|
|
+ "flw f4, 12(a1) \n\t"
|
|
|
+ "addi a1, a1, 16 \n\t"
|
|
|
+ "addi t2, %[INNER], 0 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
|
|
|
+
|
|
|
+ "BLOCK_INNER_LOOP%=: \n\t"
|
|
|
+
|
|
|
+ LOAD_B_16x8x2
|
|
|
+
|
|
|
+ "vle8.v v10, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vle8.v v11, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vsub.vv v2, v2, v12 \n\t"
|
|
|
+ "vsub.vv v6, v6, v12 \n\t"
|
|
|
+ "vsub.vv v3, v3, v13 \n\t"
|
|
|
+ "vsub.vv v7, v7, v13 \n\t"
|
|
|
+ "vsub.vv v4, v4, v14 \n\t"
|
|
|
+ "vsub.vv v8, v8, v14 \n\t"
|
|
|
+ "vsub.vv v5, v5, v15 \n\t"
|
|
|
+ "vsub.vv v9, v9, v15 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_4x16x16
|
|
|
+
|
|
|
+ "addi t2, t2, -1 \n\t"
|
|
|
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
|
|
|
+
|
|
|
+ LOAD_SCALE_4x16_FP16
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfcvt.f.x.v v16, v16 \n\t"
|
|
|
+ "vfmacc.vv v24, v16, v8 \n\t"
|
|
|
+ "addi t3, t3, -1 \n\t"
|
|
|
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
|
|
|
+
|
|
|
+ "RESULT_SAVE%=: \n\t"
|
|
|
+
|
|
|
+ SAVE_RESULT_4x16
|
|
|
+
|
|
|
+ :
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
|
|
|
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
|
|
|
+ "s4", "s5", "s6");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for (size_t n = 0; n < CountN; n += 16) {
|
|
|
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
|
|
|
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
|
|
|
+ n * BlockCountK * BlkLen / 2 + // b data
|
|
|
+ n * BlockCountK * sizeof(_Float16); // scale
|
|
|
+ float * CPtr = C + n;
|
|
|
+ if (NBLKS < 16) {
|
|
|
+ CPtr = tmp;
|
|
|
+ LDC = 16 * sizeof(float);
|
|
|
+ }
|
|
|
+ if (Bias != nullptr) {
|
|
|
+ const float * bias = Bias + n;
|
|
|
+ if (NBLKS < 16) {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, %[N], e32, m2 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "vse32.v v0, (%[DST]) \n\t"
|
|
|
+ :
|
|
|
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
|
|
|
+ : "cc", "t0");
|
|
|
+ bias = tmp;
|
|
|
+ }
|
|
|
+ __asm__ volatile(LOAD_BIAS
|
|
|
+
|
|
|
+ "addi t3, %[BlockCountK], 0 \n\t"
|
|
|
+ "addi a1, %[A], 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "BLOCK_COUNTK_LOOP%=: \n\t"
|
|
|
+ "addi s5, s1, 0 \n\t"
|
|
|
+ "addi s1, s5, 32 \n\t"
|
|
|
+ "addi s2, s1, 32 \n\t"
|
|
|
+ "addi s3, s1, 32*2 \n\t"
|
|
|
+ "addi s4, s1, 32*3 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ // load a scale
|
|
|
+ "flw f1, (a1) \n\t"
|
|
|
+ "flw f2, 4(a1) \n\t"
|
|
|
+ "flw f3, 8(a1) \n\t"
|
|
|
+ "flw f4, 12(a1) \n\t"
|
|
|
+ "addi a1, a1, 16 \n\t"
|
|
|
+ "addi t2, %[INNER], 0 \n\t"
|
|
|
+ "BLOCK_INNER_LOOP%=: \n\t"
|
|
|
+
|
|
|
+ LOAD_B_16x8x2
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "vle8.v v10, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vle8.v v11, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vadd.vi v2, v2, -8 \n\t"
|
|
|
+ "vadd.vi v3, v3, -8 \n\t"
|
|
|
+ "vadd.vi v4, v4, -8 \n\t"
|
|
|
+ "vadd.vi v5, v5, -8 \n\t"
|
|
|
+ "vadd.vi v6, v6, -8 \n\t"
|
|
|
+ "vadd.vi v7, v7, -8 \n\t"
|
|
|
+ "vadd.vi v8, v8, -8 \n\t"
|
|
|
+ "vadd.vi v9, v9, -8 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_4x16x16
|
|
|
+
|
|
|
+ "addi t2, t2, -1 \n\t"
|
|
|
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
|
|
|
+
|
|
|
+ LOAD_SCALE_4x16_FP16
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfcvt.f.x.v v16, v16 \n\t"
|
|
|
+ "vfmacc.vv v24, v16, v8 \n\t"
|
|
|
+ "addi t3, t3, -1 \n\t"
|
|
|
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
|
|
|
+ "RESULT_SAVE%=: \n\t"
|
|
|
+
|
|
|
+ SAVE_RESULT_4x16
|
|
|
+
|
|
|
+ :
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
|
|
|
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
|
|
|
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
|
|
|
+ "s2", "s3", "s4", "s5", "s6");
|
|
|
+
|
|
|
+ } else {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v24, v24, v24 \n\t"
|
|
|
+ "addi t3, %[BlockCountK], 0 \n\t"
|
|
|
+ "addi a1, %[A], 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "BLOCK_COUNTK_LOOP%=: \n\t"
|
|
|
+ "addi s5, s1, 0 \n\t"
|
|
|
+ "addi s1, s5, 32 \n\t"
|
|
|
+ "addi s2, s1, 32 \n\t"
|
|
|
+ "addi s3, s1, 32*2 \n\t"
|
|
|
+ "addi s4, s1, 32*3 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ // load a scale
|
|
|
+ "flw f1, (a1) \n\t"
|
|
|
+ "flw f2, 4(a1) \n\t"
|
|
|
+ "flw f3, 8(a1) \n\t"
|
|
|
+ "flw f4, 12(a1) \n\t"
|
|
|
+ "addi a1, a1, 16 \n\t"
|
|
|
+ "addi t2, %[INNER], 0 \n\t"
|
|
|
+ "BLOCK_INNER_LOOP%=: \n\t"
|
|
|
+
|
|
|
+ LOAD_B_16x8x2
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "vle8.v v10, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vle8.v v11, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vadd.vi v2, v2, -8 \n\t"
|
|
|
+ "vadd.vi v3, v3, -8 \n\t"
|
|
|
+ "vadd.vi v4, v4, -8 \n\t"
|
|
|
+ "vadd.vi v5, v5, -8 \n\t"
|
|
|
+ "vadd.vi v6, v6, -8 \n\t"
|
|
|
+ "vadd.vi v7, v7, -8 \n\t"
|
|
|
+ "vadd.vi v8, v8, -8 \n\t"
|
|
|
+ "vadd.vi v9, v9, -8 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_4x16x16
|
|
|
+
|
|
|
+ "addi t2, t2, -1 \n\t"
|
|
|
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
|
|
|
+
|
|
|
+ LOAD_SCALE_4x16_FP16
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfcvt.f.x.v v16, v16 \n\t"
|
|
|
+ "vfmacc.vv v24, v16, v8 \n\t"
|
|
|
+ "addi t3, t3, -1 \n\t"
|
|
|
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
|
|
|
+ "RESULT_SAVE%=: \n\t"
|
|
|
+
|
|
|
+ SAVE_RESULT_4x16
|
|
|
+
|
|
|
+ :
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
|
|
|
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
|
|
|
+ "s4", "s5", "s6");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (CountN % 16 != 0) {
+ // store output from tmp to C when NBLKS is less than 16.
+ float * CPtr = C + CountN / 16 * 16;
|
|
|
+ const size_t N = CountN % 16;
|
|
|
+ LDC = ldc * sizeof(float);
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, %[N], e32, m2 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi s2, %[SRC], 64 \n\t"
|
|
|
+ "addi s3, %[SRC], 64*2 \n\t"
|
|
|
+ "addi s4, %[SRC], 64*3 \n\t"
|
|
|
+ "vle32.v v2, (s2) \n\t"
|
|
|
+ "vle32.v v4, (s3) \n\t"
|
|
|
+ "vle32.v v6, (s4) \n\t"
|
|
|
+ "add t2, %[DST], %[LDC] \n\t"
|
|
|
+ "add t3, t2, %[LDC] \n\t"
|
|
|
+ "add t4, t3, %[LDC] \n\t"
|
|
|
+ "vse32.v v0, (%[DST]) \n\t"
|
|
|
+ "vse32.v v2, (t2) \n\t"
|
|
|
+ "vse32.v v4, (t3) \n\t"
|
|
|
+ "vse32.v v6, (t4) \n\t"
|
|
|
+ :
|
|
|
+ : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
|
|
|
+ : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
|
|
|
+ }
|
|
|
+}
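+
+// Scalar reference for the arithmetic the 4-row kernels above vectorize: one
+// K-block's contribution to a single C[m][n]. This is an illustrative sketch,
+// not code the kernels call; the helper name, the nibble order and the flat
+// argument layout are assumptions, since the real packed-B layout is consumed
+// directly by the assembly macros.
+[[maybe_unused]] static float ref_q4_block_dot(const int8_t * a, float a_scale,
+                                               const uint8_t * b_nibbles, float b_scale,
+                                               int b_zero_point, size_t blk_len) {
+    int32_t acc = 0;
+    for (size_t k = 0; k < blk_len; ++k) {
+        // in this sketch the low nibble holds even k and the high nibble odd k
+        const uint8_t packed = b_nibbles[k / 2];
+        const int     w      = (k % 2 == 0) ? (packed & 0x0F) : (packed >> 4);
+        // when no zero point is stored, the kernels subtract the fixed value 8
+        acc += (int32_t) a[k] * (w - b_zero_point);
+    }
+    // per-block scales: activation scale times (fp16 or fp32) weight scale
+    return a_scale * b_scale * (float) acc;
+}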
|
|
|
+
|
|
|
+template <bool HasZeroPoint>
|
|
|
+void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen,
|
|
|
+ const std::byte * QuantA,
|
|
|
+ const std::byte * QuantBData,
|
|
|
+ const float * QuantBScale,
|
|
|
+ const std::byte * QuantBZeroPoint,
|
|
|
+ float * C,
|
|
|
+ size_t CountN,
|
|
|
+ size_t BlockCountK,
|
|
|
+ const float * Bias,
|
|
|
+ const size_t ldc) {
|
|
|
+ GGML_UNUSED(QuantBScale);
|
|
|
+ GGML_UNUSED(QuantBZeroPoint);
|
|
|
+ size_t LDC = ldc * sizeof(float);
|
|
|
+ const size_t INNER = BlkLen / 16;
|
|
|
+ float tmp[4 * 16];
|
|
|
+
|
|
|
+ if constexpr (HasZeroPoint) {
|
|
|
+ for (size_t n = 0; n < CountN; n += 16) {
|
|
|
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
|
|
|
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
|
|
|
+ n * BlockCountK * BlkLen / 2 + // b data
|
|
|
+ n * BlockCountK * sizeof(uint8_t) + // zp
|
|
|
+ n * BlockCountK * sizeof(float); // scale
|
|
|
+ float * CPtr = C + n;
|
|
|
+ if (NBLKS < 16) {
|
|
|
+ CPtr = tmp;
|
|
|
+ LDC = 16 * sizeof(float);
|
|
|
+ }
|
|
|
+ if (Bias != nullptr) {
|
|
|
+ const float * bias = Bias + n;
|
|
|
+ if (NBLKS < 16) {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, %[N], e32, m2 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "vse32.v v0, (%[DST]) \n\t"
|
|
|
+ :
|
|
|
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
|
|
|
+ : "cc", "t0");
|
|
|
+ bias = tmp;
|
|
|
+ }
|
|
|
+
|
|
|
+ __asm__ volatile(LOAD_BIAS
|
|
|
+ "addi t3, %[BlockCountK], 0 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "li s1, 24 \n\t"
|
|
|
+ "vmv.v.i v1, 3 \n\t"
|
|
|
+ "vsetvli t0, s1, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v1, 2 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf2 \n\t"
|
|
|
+ "vmv.v.i v1, 1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t"
|
|
|
+ "vmv.v.i v1, 0 \n\t"
|
|
|
+ "addi a1, %[A], 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "BLOCK_COUNTK_LOOP%=: \n\t"
|
|
|
+ // scale offset
|
|
|
+ "addi s5, s1, 0 \n\t"
|
|
|
+ // zp offset
|
|
|
+ "addi s6, s1, 64 \n\t"
|
|
|
+ "addi s1, s6, 16 \n\t"
|
|
|
+ "addi s2, s1, 32 \n\t"
|
|
|
+ "addi s3, s1, 32*2 \n\t"
|
|
|
+ "addi s4, s1, 32*3 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ // load a scale
|
|
|
+ "flw f1, (a1) \n\t"
|
|
|
+ "flw f2, 4(a1) \n\t"
|
|
|
+ "flw f3, 8(a1) \n\t"
|
|
|
+ "flw f4, 12(a1) \n\t"
|
|
|
+ "addi a1, a1, 16 \n\t"
|
|
|
+ "addi t2, %[INNER], 0 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
|
|
|
+
|
|
|
+ "BLOCK_INNER_LOOP%=: \n\t"
|
|
|
+
|
|
|
+ LOAD_B_16x8x2
|
|
|
+
|
|
|
+ "vle8.v v10, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vle8.v v11, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vsub.vv v2, v2, v12 \n\t"
|
|
|
+ "vsub.vv v6, v6, v12 \n\t"
|
|
|
+ "vsub.vv v3, v3, v13 \n\t"
|
|
|
+ "vsub.vv v7, v7, v13 \n\t"
|
|
|
+ "vsub.vv v4, v4, v14 \n\t"
|
|
|
+ "vsub.vv v8, v8, v14 \n\t"
|
|
|
+ "vsub.vv v5, v5, v15 \n\t"
|
|
|
+ "vsub.vv v9, v9, v15 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_4x16x16
|
|
|
+
|
|
|
+ "addi t2, t2, -1 \n\t"
|
|
|
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
|
|
|
+
|
|
|
+ LOAD_SCALE_4x16
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfcvt.f.x.v v16, v16 \n\t"
|
|
|
+ "vfmacc.vv v24, v16, v8 \n\t"
|
|
|
+ "addi t3, t3, -1 \n\t"
|
|
|
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
|
|
|
+
|
|
|
+ "RESULT_SAVE%=: \n\t"
|
|
|
+
|
|
|
+ SAVE_RESULT_4x16
|
|
|
+
|
|
|
+ :
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
|
|
|
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
|
|
|
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
|
|
|
+ "s2", "s3", "s4", "s5", "s6");
|
|
|
+
|
|
|
+ } else {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v24, v24, v24 \n\t"
|
|
|
+ "addi t3, %[BlockCountK], 0 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "li s1, 24 \n\t"
|
|
|
+ "vmv.v.i v1, 3 \n\t"
|
|
|
+ "vsetvli t0, s1, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v1, 2 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf2 \n\t"
|
|
|
+ "vmv.v.i v1, 1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t"
|
|
|
+ "vmv.v.i v1, 0 \n\t"
|
|
|
+ "addi a1, %[A], 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "BLOCK_COUNTK_LOOP%=: \n\t"
|
|
|
+ // scale offset
|
|
|
+ "addi s5, s1, 0 \n\t"
|
|
|
+ // zp offset
|
|
|
+ "addi s6, s1, 64 \n\t"
|
|
|
+ "addi s1, s6, 16 \n\t"
|
|
|
+ "addi s2, s1, 32 \n\t"
|
|
|
+ "addi s3, s1, 32*2 \n\t"
|
|
|
+ "addi s4, s1, 32*3 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ // load a scale
+ "flw f1, (a1) \n\t"
|
|
|
+ "flw f2, 4(a1) \n\t"
|
|
|
+ "flw f3, 8(a1) \n\t"
|
|
|
+ "flw f4, 12(a1) \n\t"
|
|
|
+ "addi a1, a1, 16 \n\t"
|
|
|
+ "addi t2, %[INNER], 0 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
|
|
|
+
|
|
|
+ "BLOCK_INNER_LOOP%=: \n\t"
|
|
|
+
|
|
|
+ LOAD_B_16x8x2
|
|
|
+
|
|
|
+ "vle8.v v10, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vle8.v v11, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vsub.vv v2, v2, v12 \n\t"
|
|
|
+ "vsub.vv v6, v6, v12 \n\t"
|
|
|
+ "vsub.vv v3, v3, v13 \n\t"
|
|
|
+ "vsub.vv v7, v7, v13 \n\t"
|
|
|
+ "vsub.vv v4, v4, v14 \n\t"
|
|
|
+ "vsub.vv v8, v8, v14 \n\t"
|
|
|
+ "vsub.vv v5, v5, v15 \n\t"
|
|
|
+ "vsub.vv v9, v9, v15 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_4x16x16
|
|
|
+
|
|
|
+ "addi t2, t2, -1 \n\t"
|
|
|
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
|
|
|
+
|
|
|
+ LOAD_SCALE_4x16
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfcvt.f.x.v v16, v16 \n\t"
|
|
|
+ "vfmacc.vv v24, v16, v8 \n\t"
|
|
|
+ "addi t3, t3, -1 \n\t"
|
|
|
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
|
|
|
+
|
|
|
+ "RESULT_SAVE%=: \n\t"
|
|
|
+
|
|
|
+ SAVE_RESULT_4x16
|
|
|
+
|
|
|
+ :
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
|
|
|
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
|
|
|
+ "s4", "s5", "s6");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for (size_t n = 0; n < CountN; n += 16) {
|
|
|
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
|
|
|
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
|
|
|
+ n * BlockCountK * BlkLen / 2 + // b data
|
|
|
+ n * BlockCountK * sizeof(float); // scale
|
|
|
+ float * CPtr = C + n;
|
|
|
+ if (NBLKS < 16) {
|
|
|
+ CPtr = tmp;
|
|
|
+ LDC = 16 * sizeof(float);
|
|
|
+ }
|
|
|
+ if (Bias != nullptr) {
|
|
|
+ const float * bias = Bias + n;
|
|
|
+ if (NBLKS < 16) {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, %[N], e32, m2 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "vse32.v v0, (%[DST]) \n\t"
|
|
|
+ :
|
|
|
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
|
|
|
+ : "cc", "t0");
|
|
|
+ bias = tmp;
|
|
|
+ }
|
|
|
+ __asm__ volatile(LOAD_BIAS
|
|
|
+ "addi t3, %[BlockCountK], 0 \n\t"
|
|
|
+ "addi a1, %[A], 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "BLOCK_COUNTK_LOOP%=: \n\t"
|
|
|
+ "addi s5, s1, 0 \n\t"
|
|
|
+ "addi s1, s5, 64 \n\t"
|
|
|
+ "addi s2, s1, 32 \n\t"
|
|
|
+ "addi s3, s1, 32*2 \n\t"
|
|
|
+ "addi s4, s1, 32*3 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ // load a scale
|
|
|
+ "flw f1, (a1) \n\t"
|
|
|
+ "flw f2, 4(a1) \n\t"
|
|
|
+ "flw f3, 8(a1) \n\t"
|
|
|
+ "flw f4, 12(a1) \n\t"
|
|
|
+ "addi a1, a1, 16 \n\t"
|
|
|
+ "addi t2, %[INNER], 0 \n\t"
|
|
|
+ "BLOCK_INNER_LOOP%=: \n\t"
|
|
|
+
|
|
|
+ LOAD_B_16x8x2
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "vle8.v v10, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vle8.v v11, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vadd.vi v2, v2, -8 \n\t"
|
|
|
+ "vadd.vi v3, v3, -8 \n\t"
|
|
|
+ "vadd.vi v4, v4, -8 \n\t"
|
|
|
+ "vadd.vi v5, v5, -8 \n\t"
|
|
|
+ "vadd.vi v6, v6, -8 \n\t"
|
|
|
+ "vadd.vi v7, v7, -8 \n\t"
|
|
|
+ "vadd.vi v8, v8, -8 \n\t"
|
|
|
+ "vadd.vi v9, v9, -8 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_4x16x16
|
|
|
+
|
|
|
+ "addi t2, t2, -1 \n\t"
|
|
|
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
|
|
|
+
|
|
|
+ LOAD_SCALE_4x16
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfcvt.f.x.v v16, v16 \n\t"
|
|
|
+ "vfmacc.vv v24, v16, v8 \n\t"
|
|
|
+ "addi t3, t3, -1 \n\t"
|
|
|
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
|
|
|
+
|
|
|
+ "RESULT_SAVE%=: \n\t"
|
|
|
+
|
|
|
+ SAVE_RESULT_4x16
|
|
|
+
|
|
|
+ :
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
|
|
|
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
|
|
|
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
|
|
|
+ "s2", "s3", "s4", "s5", "s6");
|
|
|
+
|
|
|
+ } else {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v24, v24, v24 \n\t"
|
|
|
+ "addi t3, %[BlockCountK], 0 \n\t"
|
|
|
+ "addi a1, %[A], 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "BLOCK_COUNTK_LOOP%=: \n\t"
|
|
|
+ "addi s5, s1, 0 \n\t"
|
|
|
+ "addi s1, s5, 64 \n\t"
|
|
|
+ "addi s2, s1, 32 \n\t"
|
|
|
+ "addi s3, s1, 32*2 \n\t"
|
|
|
+ "addi s4, s1, 32*3 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ // load a scale
|
|
|
+ "flw f1, (a1) \n\t"
|
|
|
+ "flw f2, 4(a1) \n\t"
|
|
|
+ "flw f3, 8(a1) \n\t"
|
|
|
+ "flw f4, 12(a1) \n\t"
|
|
|
+ "addi a1, a1, 16 \n\t"
|
|
|
+ "addi t2, %[INNER], 0 \n\t"
|
|
|
+ "BLOCK_INNER_LOOP%=: \n\t"
|
|
|
+
|
|
|
+ LOAD_B_16x8x2
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vle8.v v11, (a1) \n\t"
|
|
|
+ "addi a1, a1, 32 \n\t"
|
|
|
+ "vadd.vi v2, v2, -8 \n\t"
|
|
|
+ "vadd.vi v3, v3, -8 \n\t"
|
|
|
+ "vadd.vi v4, v4, -8 \n\t"
|
|
|
+ "vadd.vi v5, v5, -8 \n\t"
|
|
|
+ "vadd.vi v6, v6, -8 \n\t"
|
|
|
+ "vadd.vi v7, v7, -8 \n\t"
|
|
|
+ "vadd.vi v8, v8, -8 \n\t"
|
|
|
+ "vadd.vi v9, v9, -8 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_4x16x16
|
|
|
+
|
|
|
+ "addi t2, t2, -1 \n\t"
|
|
|
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
|
|
|
+
|
|
|
+ LOAD_SCALE_4x16
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, m8 \n\t"
|
|
|
+ "vfcvt.f.x.v v16, v16 \n\t"
|
|
|
+ "vfmacc.vv v24, v16, v8 \n\t"
|
|
|
+ "addi t3, t3, -1 \n\t"
|
|
|
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
|
|
|
+
|
|
|
+ "RESULT_SAVE%=: \n\t"
|
|
|
+
|
|
|
+ SAVE_RESULT_4x16
|
|
|
+
|
|
|
+ :
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
|
|
|
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
|
|
|
+ "s4", "s5", "s6");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (CountN % 16 != 0) {
+ // store output from tmp to C when NBLKS is less than 16.
+ float * CPtr = C + CountN / 16 * 16;
|
|
|
+ const size_t N = CountN % 16;
|
|
|
+ LDC = ldc * sizeof(float);
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, %[N], e32, m2 \n\t"
|
|
|
+ "vle32.v v0, (%[SRC]) \n\t"
|
|
|
+ "addi s2, %[SRC], 64 \n\t"
|
|
|
+ "addi s3, %[SRC], 64*2 \n\t"
|
|
|
+ "addi s4, %[SRC], 64*3 \n\t"
|
|
|
+ "vle32.v v2, (s2) \n\t"
|
|
|
+ "vle32.v v4, (s3) \n\t"
|
|
|
+ "vle32.v v6, (s4) \n\t"
|
|
|
+ "add t2, %[DST], %[LDC] \n\t"
|
|
|
+ "add t3, t2, %[LDC] \n\t"
|
|
|
+ "add t4, t3, %[LDC] \n\t"
|
|
|
+ "vse32.v v0, (%[DST]) \n\t"
|
|
|
+ "vse32.v v2, (t2) \n\t"
|
|
|
+ "vse32.v v4, (t3) \n\t"
|
|
|
+ "vse32.v v6, (t4) \n\t"
|
|
|
+ :
|
|
|
+ : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
|
|
|
+ : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
|
|
|
+ }
|
|
|
+}
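+
+// Minimal sketch of the CountN tail handling used by both 4-row kernels above:
+// a partial tile (fewer than 16 columns) is computed into the full 4x16 tmp
+// buffer with a row stride of 16 floats, and only the valid columns are copied
+// back into C afterwards. The helper is hypothetical; ldc is in elements here.
+[[maybe_unused]] static void ref_copy_tail_4x16(const float * tmp, float * c, size_t valid_n, size_t ldc) {
+    for (size_t m = 0; m < 4; ++m) {
+        for (size_t n = 0; n < valid_n; ++n) {
+            c[m * ldc + n] = tmp[m * 16 + n];
+        }
+    }
+}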
|
|
|
+
|
|
|
+template <bool HasZeroPoint>
|
|
|
+void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
|
|
|
+ const std::byte * QuantA,
|
|
|
+ const std::byte * QuantBData,
|
|
|
+ const float * QuantBScale,
|
|
|
+ const std::byte * QuantBZeroPoint,
|
|
|
+ float * C,
|
|
|
+ size_t CountN,
|
|
|
+ size_t BlockCountK,
|
|
|
+ const float * Bias) {
|
|
|
+ GGML_UNUSED(QuantBScale);
|
|
|
+ GGML_UNUSED(QuantBZeroPoint);
|
|
|
+ size_t INNER = BlkLen / 16;
|
|
|
+
|
|
|
+ if constexpr (HasZeroPoint) {
|
|
|
+ for (size_t n = 0; n < CountN; n += 16) {
|
|
|
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
|
|
|
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
|
|
|
+ n * BlockCountK * BlkLen / 2 + // b data
|
|
|
+ n * BlockCountK * sizeof(uint8_t) + // zp
|
|
|
+ n * BlockCountK * sizeof(_Float16); // scale
|
|
|
+ float * CPtr = C + n;
|
|
|
+ size_t cnt = BlockCountK;
|
|
|
+ if (Bias != nullptr) {
|
|
|
+ const float * bias = Bias + n;
|
|
|
+ __asm__ volatile(
|
|
|
+ "addi t3, %[NBLKS], 0 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+
|
|
|
+ "vmv.v.i v13, 3 \n\t"
|
|
|
+ "li s1, 24 \n\t"
|
|
|
+ "vsetvli t0, s1, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v13, 2 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf2 \n\t"
|
|
|
+ "vmv.v.i v13, 1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t"
|
|
|
+ "vmv.v.i v13, 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "addi s2, %[B], 8 \n\t"
|
|
|
+ "addi s3, %[B], 16 \n\t"
|
|
|
+ "addi s4, %[B], 24 \n\t"
|
|
|
+ // zp offset
|
|
|
+ "addi s7, %[B], 32 \n\t"
|
|
|
+ // a offset
|
|
|
+ "addi s5, %[A], 0 \n\t"
|
|
|
+ "addi s6, %[A], 12 \n\t"
|
|
|
+
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v28, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v29, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v30, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v31, (%[BIAS]) \n\t"
|
|
|
+
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e16, mf4 \n\t"
|
|
|
+
|
|
|
+ "vle16.v v4, (s1) \n\t"
|
|
|
+ "addi s1, s1, 48 \n\t"
|
|
|
+ "vle16.v v5, (s2) \n\t"
|
|
|
+ "addi s2, s2, 72 \n\t"
|
|
|
+ "vle16.v v6, (s3) \n\t"
|
|
|
+ "addi s3, s3, 96 \n\t"
|
|
|
+ "vle16.v v7, (s4) \n\t"
|
|
|
+ "addi s4, s4, 120 \n\t"
|
|
|
+ "flw f1, (s5) \n\t"
|
|
|
+ "addi s5, s5, 4 \n\t"
|
|
|
+ "vfwcvt.f.f.v v8, v4 \n\t"
|
|
|
+ "vfwcvt.f.f.v v9, v5 \n\t"
|
|
|
+ "vfwcvt.f.f.v v10, v6 \n\t"
|
|
|
+ "vfwcvt.f.f.v v11, v7 \n\t"
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+ "addi t5, %[INNER], 0 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v18, v18, v18 \n\t"
|
|
|
+ "vxor.vv v20, v20, v20 \n\t"
|
|
|
+ "vxor.vv v22, v22, v22 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f1 \n\t"
|
|
|
+ "vfmul.vf v25, v9, f1 \n\t"
|
|
|
+ "vfmul.vf v26, v10, f1 \n\t"
|
|
|
+ "vfmul.vf v27, v11, f1 \n\t"
|
|
|
+ "addi %[CNT], %[CNT], -1 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
|
|
|
+
|
|
|
+ "LOOP_INNER%=: \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "vsub.vv v0, v0, v8 \n\t"
|
|
|
+ "vsub.vv v4, v4, v8 \n\t"
|
|
|
+ "vsub.vv v1, v1, v9 \n\t"
|
|
|
+ "vsub.vv v5, v5, v9 \n\t"
|
|
|
+ "vsub.vv v2, v2, v10 \n\t"
|
|
|
+ "vsub.vv v6, v6, v10 \n\t"
|
|
|
+ "vsub.vv v3, v3, v11 \n\t"
|
|
|
+ "vsub.vv v7, v7, v11 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "bnez t5, LOOP_INNER%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
|
|
|
+ "addi s7, s1, 32 \n\t"
|
|
|
+
|
|
|
+ "bnez %[CNT], LOOP_K%= \n\t"
|
|
|
+ "addi t3, zero, 16 \n\t"
|
|
|
+ "addi s1, %[C], 16 \n\t"
|
|
|
+ "addi s2, %[C], 32 \n\t"
|
|
|
+ "addi s3, %[C], 48 \n\t"
|
|
|
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "jal x0, END%= \n\t"
|
|
|
+
|
|
|
+ "ST_TAIL%=: \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+
|
|
|
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
|
|
|
+ } else {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vxor.vv v28, v28, v28 \n\t"
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v13, 3 \n\t"
|
|
|
+ "li s1, 24 \n\t"
|
|
|
+ "vsetvli t0, s1, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v13, 2 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf2 \n\t"
|
|
|
+ "vmv.v.i v13, 1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t"
|
|
|
+ "vmv.v.i v13, 0 \n\t"
|
|
|
+
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "addi s2, %[B], 8 \n\t"
|
|
|
+ "addi s3, %[B], 16 \n\t"
|
|
|
+ "addi s4, %[B], 24 \n\t"
|
|
|
+
|
|
|
+ "addi s7, %[B], 32 \n\t"
|
|
|
+
|
|
|
+ "addi s5, %[A], 0 \n\t"
|
|
|
+ "addi s6, %[A], 12 \n\t"
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e16, mf4 \n\t"
|
|
|
+ "vle16.v v4, (s1) \n\t"
|
|
|
+ "addi s1, s1, 48 \n\t"
|
|
|
+ "vle16.v v5, (s2) \n\t"
|
|
|
+ "addi s2, s2, 72 \n\t"
|
|
|
+ "vle16.v v6, (s3) \n\t"
|
|
|
+ "addi s3, s3, 96 \n\t"
|
|
|
+ "vle16.v v7, (s4) \n\t"
|
|
|
+ "addi s4, s4, 120 \n\t"
|
|
|
+ "flw f1, (s5) \n\t"
|
|
|
+ "addi s5, s5, 4 \n\t"
|
|
|
+
|
|
|
+ "vfwcvt.f.f.v v8, v4 \n\t"
|
|
|
+ "vfwcvt.f.f.v v9, v5 \n\t"
|
|
|
+ "vfwcvt.f.f.v v10, v6 \n\t"
|
|
|
+ "vfwcvt.f.f.v v11, v7 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ "addi t5, %[INNER], 0 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v18, v18, v18 \n\t"
|
|
|
+ "vxor.vv v20, v20, v20 \n\t"
|
|
|
+ "vxor.vv v22, v22, v22 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f1 \n\t"
|
|
|
+ "vfmul.vf v25, v9, f1 \n\t"
|
|
|
+ "vfmul.vf v26, v10, f1 \n\t"
|
|
|
+ "vfmul.vf v27, v11, f1 \n\t"
|
|
|
+ "addi %[CNT], %[CNT], -1 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
|
|
|
+
|
|
|
+ "LOOP_INNER%=: \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "vsub.vv v0, v0, v8 \n\t"
|
|
|
+ "vsub.vv v4, v4, v8 \n\t"
|
|
|
+ "vsub.vv v1, v1, v9 \n\t"
|
|
|
+ "vsub.vv v5, v5, v9 \n\t"
|
|
|
+ "vsub.vv v2, v2, v10 \n\t"
|
|
|
+ "vsub.vv v6, v6, v10 \n\t"
|
|
|
+ "vsub.vv v3, v3, v11 \n\t"
|
|
|
+ "vsub.vv v7, v7, v11 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "bnez t5, LOOP_INNER%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
|
|
|
+ "addi s7, s1, 32 \n\t"
|
|
|
+
|
|
|
+ "bnez %[CNT], LOOP_K%= \n\t"
|
|
|
+ "addi t3, zero, 16 \n\t"
|
|
|
+ "addi s1, %[C], 16 \n\t"
|
|
|
+ "addi s2, %[C], 32 \n\t"
|
|
|
+ "addi s3, %[C], 48 \n\t"
|
|
|
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "jal x0, END%= \n\t"
|
|
|
+
|
|
|
+ "ST_TAIL%=: \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+
|
|
|
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for (size_t n = 0; n < CountN; n += 16) {
|
|
|
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
|
|
|
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
|
|
|
+ n * BlockCountK * BlkLen / 2 + // b data
|
|
|
+ n * BlockCountK * sizeof(_Float16); // scale
|
|
|
+ float * CPtr = C + n;
|
|
|
+ size_t cnt = BlockCountK;
|
|
|
+ if (Bias != nullptr) {
|
|
|
+ const float * bias = Bias + n;
|
|
|
+ __asm__ volatile(
|
|
|
+ "addi t3, %[NBLKS], 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "addi s2, %[B], 8 \n\t"
|
|
|
+ "addi s3, %[B], 16 \n\t"
|
|
|
+ "addi s4, %[B], 24 \n\t"
|
|
|
+ "addi s5, %[A], 0 \n\t"
|
|
|
+ "addi s6, %[A], 12 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v28, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v29, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v30, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v31, (%[BIAS]) \n\t"
|
|
|
+
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e16, mf4 \n\t"
|
|
|
+
|
|
|
+ "vle16.v v4, (s1) \n\t"
|
|
|
+ "addi s1, s1, 32 \n\t"
|
|
|
+ "vle16.v v5, (s2) \n\t"
|
|
|
+ "addi s2, s2, 56 \n\t"
|
|
|
+ "vle16.v v6, (s3) \n\t"
|
|
|
+ "addi s3, s3, 80 \n\t"
|
|
|
+ "vle16.v v7, (s4) \n\t"
|
|
|
+ "addi s4, s4, 104 \n\t"
|
|
|
+ "flw f1, (s5) \n\t"
|
|
|
+ "addi s5, s5, 4 \n\t"
|
|
|
+ "vfwcvt.f.f.v v8, v4 \n\t"
|
|
|
+ "vfwcvt.f.f.v v9, v5 \n\t"
|
|
|
+ "vfwcvt.f.f.v v10, v6 \n\t"
|
|
|
+ "vfwcvt.f.f.v v11, v7 \n\t"
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+ "addi t5, %[INNER], 0 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v18, v18, v18 \n\t"
|
|
|
+ "vxor.vv v20, v20, v20 \n\t"
|
|
|
+ "vxor.vv v22, v22, v22 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f1 \n\t"
|
|
|
+ "vfmul.vf v25, v9, f1 \n\t"
|
|
|
+ "vfmul.vf v26, v10, f1 \n\t"
|
|
|
+ "vfmul.vf v27, v11, f1 \n\t"
|
|
|
+ "addi %[CNT], %[CNT], -1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "LOOP_INNER%=: \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "vadd.vi v0, v0, -8 \n\t"
|
|
|
+ "vadd.vi v1, v1, -8 \n\t"
|
|
|
+ "vadd.vi v2, v2, -8 \n\t"
|
|
|
+ "vadd.vi v3, v3, -8 \n\t"
|
|
|
+ "vadd.vi v4, v4, -8 \n\t"
|
|
|
+ "vadd.vi v5, v5, -8 \n\t"
|
|
|
+ "vadd.vi v6, v6, -8 \n\t"
|
|
|
+ "vadd.vi v7, v7, -8 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "bnez t5, LOOP_INNER%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
|
|
|
+
|
|
|
+ "bnez %[CNT], LOOP_K%= \n\t"
|
|
|
+ "addi t3, zero, 16 \n\t"
|
|
|
+ "addi s1, %[C], 16 \n\t"
|
|
|
+ "addi s2, %[C], 32 \n\t"
|
|
|
+ "addi s3, %[C], 48 \n\t"
|
|
|
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "jal x0, END%= \n\t"
|
|
|
+
|
|
|
+ "ST_TAIL%=: \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+
|
|
|
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
|
|
|
+ } else {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vxor.vv v28, v28, v28 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "addi s2, %[B], 8 \n\t"
|
|
|
+ "addi s3, %[B], 16 \n\t"
|
|
|
+ "addi s4, %[B], 24 \n\t"
|
|
|
+
|
|
|
+ "addi s5, %[A], 0 \n\t"
|
|
|
+ "addi s6, %[A], 12 \n\t"
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vsetvli t0, zero, e16, mf4 \n\t"
|
|
|
+ "vle16.v v4, (s1) \n\t"
|
|
|
+ "addi s1, s1, 32 \n\t"
|
|
|
+ "vle16.v v5, (s2) \n\t"
|
|
|
+ "addi s2, s2, 56 \n\t"
|
|
|
+ "vle16.v v6, (s3) \n\t"
|
|
|
+ "addi s3, s3, 80 \n\t"
|
|
|
+ "vle16.v v7, (s4) \n\t"
|
|
|
+ "addi s4, s4, 104 \n\t"
|
|
|
+ "flw f1, (s5) \n\t"
|
|
|
+ "addi s5, s5, 4 \n\t"
|
|
|
+
|
|
|
+ "vfwcvt.f.f.v v8, v4 \n\t"
|
|
|
+ "vfwcvt.f.f.v v9, v5 \n\t"
|
|
|
+ "vfwcvt.f.f.v v10, v6 \n\t"
|
|
|
+ "vfwcvt.f.f.v v11, v7 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ "addi t5, %[INNER], 0 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v18, v18, v18 \n\t"
|
|
|
+ "vxor.vv v20, v20, v20 \n\t"
|
|
|
+ "vxor.vv v22, v22, v22 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f1 \n\t"
|
|
|
+ "vfmul.vf v25, v9, f1 \n\t"
|
|
|
+ "vfmul.vf v26, v10, f1 \n\t"
|
|
|
+ "vfmul.vf v27, v11, f1 \n\t"
|
|
|
+ "addi %[CNT], %[CNT], -1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "LOOP_INNER%=: \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "vadd.vi v0, v0, -8 \n\t"
|
|
|
+ "vadd.vi v1, v1, -8 \n\t"
|
|
|
+ "vadd.vi v2, v2, -8 \n\t"
|
|
|
+ "vadd.vi v3, v3, -8 \n\t"
|
|
|
+ "vadd.vi v4, v4, -8 \n\t"
|
|
|
+ "vadd.vi v5, v5, -8 \n\t"
|
|
|
+ "vadd.vi v6, v6, -8 \n\t"
|
|
|
+ "vadd.vi v7, v7, -8 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "bnez t5, LOOP_INNER%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
|
|
|
+
|
|
|
+ "bnez %[CNT], LOOP_K%= \n\t"
|
|
|
+ "addi t3, zero, 16 \n\t"
|
|
|
+ "addi s1, %[C], 16 \n\t"
|
|
|
+ "addi s2, %[C], 32 \n\t"
|
|
|
+ "addi s3, %[C], 48 \n\t"
|
|
|
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "jal x0, END%= \n\t"
|
|
|
+
|
|
|
+ "ST_TAIL%=: \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+
|
|
|
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
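+
+// Sketch of the packed-B pointer arithmetic shared by these kernels: for the
+// column indices the loops actually use (multiples of 16), the byte offset is
+// the per-column footprint times n, where each column carries BlkLen/2 bytes of
+// 4-bit data, optionally one zero-point byte, and one scale of ScaleStride bytes
+// per K-block. Illustrative only; the interleaving inside a 16-column group is
+// an implementation detail of the assembly macros.
+[[maybe_unused]] static size_t ref_packed_b_offset(size_t n, size_t block_count_k, size_t blk_len,
+                                                   bool has_zero_point, size_t scale_stride) {
+    return n * block_count_k * (blk_len / 2 + (has_zero_point ? 1 : 0) + scale_stride);
+}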
|
|
|
+
|
|
|
+template <bool HasZeroPoint>
|
|
|
+void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen,
|
|
|
+ const std::byte * QuantA,
|
|
|
+ const std::byte * QuantBData,
|
|
|
+ const float * QuantBScale,
|
|
|
+ const std::byte * QuantBZeroPoint,
|
|
|
+ float * C,
|
|
|
+ size_t CountN,
|
|
|
+ size_t BlockCountK,
|
|
|
+ const float * Bias) {
|
|
|
+ GGML_UNUSED(QuantBScale);
|
|
|
+ GGML_UNUSED(QuantBZeroPoint);
|
|
|
+ const size_t INNER = BlkLen / 16;
|
|
|
+ if constexpr (HasZeroPoint) {
|
|
|
+ for (size_t n = 0; n < CountN; n += 16) {
|
|
|
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
|
|
|
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
|
|
|
+ n * BlockCountK * BlkLen / 2 + // b data
|
|
|
+ n * BlockCountK * sizeof(uint8_t) + // zp
|
|
|
+ n * BlockCountK * sizeof(float); // scale
|
|
|
+ float * CPtr = C + n;
|
|
|
+ size_t cnt = BlockCountK;
|
|
|
+ if (Bias != nullptr) {
|
|
|
+ const float * bias = Bias + n;
|
|
|
+ __asm__ volatile(
|
|
|
+ "addi t3, %[NBLKS], 0 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v13, 3 \n\t"
|
|
|
+ "li s1, 24 \n\t"
|
|
|
+ "vsetvli t0, s1, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v13, 2 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf2 \n\t"
|
|
|
+ "vmv.v.i v13, 1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t"
|
|
|
+ "vmv.v.i v13, 0 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vxor.vv v28, v28, v28 \n\t"
|
|
|
+
|
|
|
+ // scale offset: scale0.0, scale1.0, scale2.0, scale3.0, ..., scale15.0
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "addi s2, %[B], 16 \n\t"
|
|
|
+ "addi s3, %[B], 32 \n\t"
|
|
|
+ "addi s4, %[B], 48 \n\t"
|
|
|
+ // zp offset
|
|
|
+ "addi s7, %[B], 64 \n\t"
|
|
|
+ // a offset
|
|
|
+ "addi s5, %[A], 0 \n\t"
|
|
|
+ "addi s6, %[A], 12 \n\t"
|
|
|
+
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v28, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v29, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v30, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v31, (%[BIAS]) \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+
|
|
|
+ // load scale
|
|
|
+ "vle32.v v8, (s1) \n\t"
|
|
|
+ "addi s1, s1, 80 \n\t"
|
|
|
+ "vle32.v v9, (s2) \n\t"
|
|
|
+ "addi s2, s2, 96 \n\t"
|
|
|
+ "vle32.v v10, (s3) \n\t"
|
|
|
+ "addi s3, s3, 112 \n\t"
|
|
|
+ "vle32.v v11, (s4) \n\t"
|
|
|
+ "addi s4, s4, 128 \n\t"
|
|
|
+
|
|
|
+ // load a scale
|
|
|
+ "flw f1, (s5) \n\t"
|
|
|
+ "addi s5, s5, 4 \n\t"
|
|
|
+
|
|
|
+ "addi t5, %[INNER], 0 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v18, v18, v18 \n\t"
|
|
|
+ "vxor.vv v20, v20, v20 \n\t"
|
|
|
+ "vxor.vv v22, v22, v22 \n\t"
|
|
|
+
|
|
|
+ // a scale * b scale
|
|
|
+ "vfmul.vf v24, v8, f1 \n\t"
|
|
|
+ "vfmul.vf v25, v9, f1 \n\t"
|
|
|
+ "vfmul.vf v26, v10, f1 \n\t"
|
|
|
+ "vfmul.vf v27, v11, f1 \n\t"
|
|
|
+ "addi %[CNT], %[CNT], -1 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
|
|
|
+
|
|
|
+ "LOOP_INNER%=: \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "vsub.vv v0, v0, v8 \n\t"
|
|
|
+ "vsub.vv v4, v4, v8 \n\t"
|
|
|
+ "vsub.vv v1, v1, v9 \n\t"
|
|
|
+ "vsub.vv v5, v5, v9 \n\t"
|
|
|
+ "vsub.vv v2, v2, v10 \n\t"
|
|
|
+ "vsub.vv v6, v6, v10 \n\t"
|
|
|
+ "vsub.vv v3, v3, v11 \n\t"
|
|
|
+ "vsub.vv v7, v7, v11 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "bnez t5, LOOP_INNER%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_ACC_1X4X4
|
|
|
+ "addi s7, s1, 64 \n\t"
|
|
|
+
|
|
|
+ "bnez %[CNT], LOOP_K%= \n\t"
|
|
|
+
|
|
|
+ "addi t3, zero, 16 \n\t"
|
|
|
+ "addi s1, %[C], 16 \n\t"
|
|
|
+ "addi s2, %[C], 32 \n\t"
|
|
|
+ "addi s3, %[C], 48 \n\t"
|
|
|
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "jal x0, END%= \n\t"
|
|
|
+
|
|
|
+ "ST_TAIL%=: \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+
|
|
|
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
|
|
|
+ } else {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vxor.vv v28, v28, v28 \n\t"
|
|
|
+
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v13, 3 \n\t"
|
|
|
+ "li s1, 24 \n\t"
|
|
|
+ "vsetvli t0, s1, e8, m1 \n\t"
|
|
|
+ "vmv.v.i v13, 2 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf2 \n\t"
|
|
|
+ "vmv.v.i v13, 1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, mf4 \n\t"
|
|
|
+ "vmv.v.i v13, 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "addi s2, %[B], 16 \n\t"
|
|
|
+ "addi s3, %[B], 32 \n\t"
|
|
|
+ "addi s4, %[B], 48 \n\t"
|
|
|
+
|
|
|
+ "addi s7, %[B], 64 \n\t"
|
|
|
+
|
|
|
+ "addi s5, %[A], 0 \n\t"
|
|
|
+ "addi s6, %[A], 12 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vle32.v v8, (s1) \n\t"
|
|
|
+ "addi s1, s1, 80 \n\t"
|
|
|
+ "vle32.v v9, (s2) \n\t"
|
|
|
+ "addi s2, s2, 96 \n\t"
|
|
|
+ "vle32.v v10, (s3) \n\t"
|
|
|
+ "addi s3, s3, 112 \n\t"
|
|
|
+ "vle32.v v11, (s4) \n\t"
|
|
|
+ "addi s4, s4, 128 \n\t"
|
|
|
+
|
|
|
+ "flw f1, (s5) \n\t"
|
|
|
+ "addi s5, s5, 4 \n\t"
|
|
|
+
|
|
|
+ "addi t5, %[INNER], 0 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v18, v18, v18 \n\t"
|
|
|
+ "vxor.vv v20, v20, v20 \n\t"
|
|
|
+ "vxor.vv v22, v22, v22 \n\t"
|
|
|
+
|
|
|
+ "vfmul.vf v24, v8, f1 \n\t"
|
|
|
+ "vfmul.vf v25, v9, f1 \n\t"
|
|
|
+ "vfmul.vf v26, v10, f1 \n\t"
|
|
|
+ "vfmul.vf v27, v11, f1 \n\t"
|
|
|
+ "addi %[CNT], %[CNT], -1 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
|
|
|
+
|
|
|
+ "LOOP_INNER%=: \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "vsub.vv v0, v0, v8 \n\t"
|
|
|
+ "vsub.vv v4, v4, v8 \n\t"
|
|
|
+ "vsub.vv v1, v1, v9 \n\t"
|
|
|
+ "vsub.vv v5, v5, v9 \n\t"
|
|
|
+ "vsub.vv v2, v2, v10 \n\t"
|
|
|
+ "vsub.vv v6, v6, v10 \n\t"
|
|
|
+ "vsub.vv v3, v3, v11 \n\t"
|
|
|
+ "vsub.vv v7, v7, v11 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "bnez t5, LOOP_INNER%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_ACC_1X4X4
|
|
|
+ "addi s7, s1, 64 \n\t"
|
|
|
+
|
|
|
+ "bnez %[CNT], LOOP_K%= \n\t"
|
|
|
+
|
|
|
+ "addi t3, zero, 16 \n\t"
|
|
|
+ "addi s1, %[C], 16 \n\t"
|
|
|
+ "addi s2, %[C], 32 \n\t"
|
|
|
+ "addi s3, %[C], 48 \n\t"
|
|
|
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "jal x0, END%= \n\t"
|
|
|
+
|
|
|
+ "ST_TAIL%=: \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+
|
|
|
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for (size_t n = 0; n < CountN; n += 16) {
|
|
|
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
|
|
|
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
|
|
|
+ n * BlockCountK * BlkLen / 2 + // b data
|
|
|
+ n * BlockCountK * sizeof(float); // scale
|
|
|
+ float * CPtr = C + n;
|
|
|
+ size_t cnt = BlockCountK;
|
|
|
+ if (Bias != nullptr) {
|
|
|
+ const float * bias = Bias + n;
|
|
|
+ __asm__ volatile(
|
|
|
+ "addi t3, %[NBLKS], 0 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "addi s2, %[B], 16 \n\t"
|
|
|
+ "addi s3, %[B], 32 \n\t"
|
|
|
+ "addi s4, %[B], 48 \n\t"
|
|
|
+ "addi s5, %[A], 0 \n\t"
|
|
|
+ "addi s6, %[A], 12 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v28, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v29, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v30, (%[BIAS]) \n\t"
|
|
|
+ "sub t3, t3, t0 \n\t"
|
|
|
+ "addi %[BIAS], %[BIAS], 16 \n\t"
|
|
|
+ "vsetvli t0, t3, e32, mf2 \n\t"
|
|
|
+ "vle32.v v31, (%[BIAS]) \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vle32.v v8, (s1) \n\t"
|
|
|
+ "addi s1, s1, 64 \n\t"
|
|
|
+ "vle32.v v9, (s2) \n\t"
|
|
|
+ "addi s2, s2, 80 \n\t"
|
|
|
+ "vle32.v v10, (s3) \n\t"
|
|
|
+ "addi s3, s3, 96 \n\t"
|
|
|
+ "vle32.v v11, (s4) \n\t"
|
|
|
+ "addi s4, s4, 112 \n\t"
|
|
|
+ "flw f1, (s5) \n\t"
|
|
|
+ "addi s5, s5, 4 \n\t"
|
|
|
+
|
|
|
+ "addi t5, %[INNER], 0 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v18, v18, v18 \n\t"
|
|
|
+ "vxor.vv v20, v20, v20 \n\t"
|
|
|
+ "vxor.vv v22, v22, v22 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f1 \n\t"
|
|
|
+ "vfmul.vf v25, v9, f1 \n\t"
|
|
|
+ "vfmul.vf v26, v10, f1 \n\t"
|
|
|
+ "vfmul.vf v27, v11, f1 \n\t"
|
|
|
+ "addi %[CNT], %[CNT], -1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "LOOP_INNER%=: \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "vadd.vi v0, v0, -8 \n\t"
|
|
|
+ "vadd.vi v1, v1, -8 \n\t"
|
|
|
+ "vadd.vi v2, v2, -8 \n\t"
|
|
|
+ "vadd.vi v3, v3, -8 \n\t"
|
|
|
+ "vadd.vi v4, v4, -8 \n\t"
|
|
|
+ "vadd.vi v5, v5, -8 \n\t"
|
|
|
+ "vadd.vi v6, v6, -8 \n\t"
|
|
|
+ "vadd.vi v7, v7, -8 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "bnez t5, LOOP_INNER%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_ACC_1X4X4
|
|
|
+
|
|
|
+ "bnez %[CNT], LOOP_K%= \n\t"
|
|
|
+ "addi t3, zero, 16 \n\t"
|
|
|
+ "addi s1, %[C], 16 \n\t"
|
|
|
+ "addi s2, %[C], 32 \n\t"
|
|
|
+ "addi s3, %[C], 48 \n\t"
|
|
|
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "jal x0, END%= \n\t"
|
|
|
+
|
|
|
+ "ST_TAIL%=: \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+
|
|
|
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
|
|
|
+ } else {
|
|
|
+ __asm__ volatile(
|
|
|
+ "vsetvli t0, zero, e32, m4 \n\t"
|
|
|
+ "vxor.vv v28, v28, v28 \n\t"
|
|
|
+ "addi s1, %[B], 0 \n\t"
|
|
|
+ "addi s2, %[B], 16 \n\t"
|
|
|
+ "addi s3, %[B], 32 \n\t"
|
|
|
+ "addi s4, %[B], 48 \n\t"
|
|
|
+
|
|
|
+ "addi s5, %[A], 0 \n\t"
|
|
|
+ "addi s6, %[A], 12 \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+ "LOOP_K%=: \n\t"
|
|
|
+ "vle32.v v8, (s1) \n\t"
|
|
|
+ "addi s1, s1, 64 \n\t"
|
|
|
+ "vle32.v v9, (s2) \n\t"
|
|
|
+ "addi s2, s2, 80 \n\t"
|
|
|
+ "vle32.v v10, (s3) \n\t"
|
|
|
+ "addi s3, s3, 96 \n\t"
|
|
|
+ "vle32.v v11, (s4) \n\t"
|
|
|
+ "addi s4, s4, 112 \n\t"
|
|
|
+ "flw f1, (s5) \n\t"
|
|
|
+ "addi s5, s5, 4 \n\t"
|
|
|
+
|
|
|
+ "addi t5, %[INNER], 0 \n\t"
|
|
|
+ "vxor.vv v16, v16, v16 \n\t"
|
|
|
+ "vxor.vv v18, v18, v18 \n\t"
|
|
|
+ "vxor.vv v20, v20, v20 \n\t"
|
|
|
+ "vxor.vv v22, v22, v22 \n\t"
|
|
|
+ "vfmul.vf v24, v8, f1 \n\t"
|
|
|
+ "vfmul.vf v25, v9, f1 \n\t"
|
|
|
+ "vfmul.vf v26, v10, f1 \n\t"
|
|
|
+ "vfmul.vf v27, v11, f1 \n\t"
|
|
|
+ "addi %[CNT], %[CNT], -1 \n\t"
|
|
|
+ "vsetvli t0, zero, e8, m1 \n\t"
|
|
|
+ "LOOP_INNER%=: \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "vadd.vi v0, v0, -8 \n\t"
|
|
|
+ "vadd.vi v1, v1, -8 \n\t"
|
|
|
+ "vadd.vi v2, v2, -8 \n\t"
|
|
|
+ "vadd.vi v3, v3, -8 \n\t"
|
|
|
+ "vadd.vi v4, v4, -8 \n\t"
|
|
|
+ "vadd.vi v5, v5, -8 \n\t"
|
|
|
+ "vadd.vi v6, v6, -8 \n\t"
|
|
|
+ "vadd.vi v7, v7, -8 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
|
|
|
+
|
|
|
+ "bnez t5, LOOP_INNER%= \n\t"
|
|
|
+ "vsetvli t0, zero, e32, mf2 \n\t"
|
|
|
+
|
|
|
+ SQ4BIT_KERNEL_ACC_1X4X4
|
|
|
+
|
|
|
+ "bnez %[CNT], LOOP_K%= \n\t"
|
|
|
+ "addi t3, zero, 16 \n\t"
|
|
|
+ "addi s1, %[C], 16 \n\t"
|
|
|
+ "addi s2, %[C], 32 \n\t"
|
|
|
+ "addi s3, %[C], 48 \n\t"
|
|
|
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "jal x0, END%= \n\t"
|
|
|
+
|
|
|
+ "ST_TAIL%=: \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v28, (%[C]) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v29, (s1) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v30, (s2) \n\t"
|
|
|
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
|
|
|
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
|
|
|
+ "vse32.v v31, (s3) \n\t"
|
|
|
+ "END%=: \n\t"
|
|
|
+
|
|
|
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
|
|
|
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
|
|
|
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
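+
+// Scalar picture of the ST_TAIL path in the single-row kernels above, assuming
+// the 256-bit vector registers this code targets (e32/mf2 yields 4 elements per
+// group): the 16 column accumulators live in four groups of four and are stored
+// group by group, clamping the last group to the remaining column count. The
+// helper is hypothetical and exists only to document the store pattern.
+[[maybe_unused]] static void ref_store_m1_tail(const float acc[16], float * c, size_t nblks) {
+    for (size_t g = 0; g < 4 && nblks > 0; ++g) {
+        const size_t chunk = nblks < 4 ? nblks : 4;
+        for (size_t i = 0; i < chunk; ++i) {
+            c[g * 4 + i] = acc[g * 4 + i];
+        }
+        nblks -= chunk;
+    }
+}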
|
|
|
+
|
|
|
+template <bool HasZeroPoint>
|
|
|
+inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
|
|
|
+ const std::byte * QuantA,
|
|
|
+ const std::byte * QuantBData,
|
|
|
+ const float * QuantBScale,
|
|
|
+ const std::byte * QuantBZeroPoint,
|
|
|
+ float * C,
|
|
|
+ size_t CountM,
|
|
|
+ size_t CountN,
|
|
|
+ size_t BlockStrideQuantB,
|
|
|
+ const float * Bias,
|
|
|
+ const size_t ldc,
|
|
|
+ const size_t scalestride) {
|
|
|
+ if (scalestride == 4) {
|
|
|
+ SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
|
|
|
+ CountN, BlockStrideQuantB, Bias, ldc);
|
|
|
+
|
|
|
+ } else if (scalestride == 2) {
|
|
|
+ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(
|
|
|
+ BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+template <bool HasZeroPoint>
|
|
|
+inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
|
|
|
+ const std::byte * QuantA,
|
|
|
+ const std::byte * QuantBData,
|
|
|
+ const float * QuantBScale,
|
|
|
+ const std::byte * QuantBZeroPoint,
|
|
|
+ float * C,
|
|
|
+ size_t CountM,
|
|
|
+ size_t CountN,
|
|
|
+ size_t BlockStrideQuantB,
|
|
|
+ const float * Bias,
|
|
|
+ const size_t ldc,
|
|
|
+ const size_t scalestride) {
|
|
|
+ if (scalestride == 4) {
|
|
|
+ SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
|
|
|
+ CountN, BlockStrideQuantB, Bias);
|
|
|
+ } else if (scalestride == 2) {
|
|
|
+ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,
|
|
|
+ QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);
|
|
|
+ }
|
|
|
+}
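+
+// The two dispatchers above pick the kernel variant from the width of the packed
+// per-block B scales: a 4-byte stride selects the fp32-scale kernels and a 2-byte
+// stride the fp16-scale ones. A minimal sketch of that decision, assuming the
+// stride is expressed in bytes as the callers above pass it:
+[[maybe_unused]] static bool ref_b_scales_are_fp16(size_t scale_stride) {
+    return scale_stride == sizeof(_Float16);  // 2 bytes; sizeof(float) selects the fp32 kernels
+}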
|
|
|
+
|
|
|
+} // namespace
|
|
|
+
|
|
|
+namespace ime1 {
|
|
|
+size_t gemm_kernel_i8i4(size_t BlkLen,
|
|
|
+ const std::byte * QuantA,
|
|
|
+ const std::byte * QuantBData,
|
|
|
+ const float * QuantBScale,
|
|
|
+ const std::byte * QuantBZeroPoint,
|
|
|
+ float * C,
|
|
|
+ size_t CountM,
|
|
|
+ size_t CountN,
|
|
|
+ size_t CountK,
|
|
|
+ size_t BlockCountK,
|
|
|
+ size_t ldc,
|
|
|
+ const float * Bias,
|
|
|
+ const size_t ScaleStride) {
|
|
|
+ GGML_UNUSED(CountM);
|
|
|
+ GGML_UNUSED(CountK);
|
|
|
+ GGML_UNUSED(ldc);
|
|
|
+ if (CountM >= 4) {
|
|
|
+ if (QuantBZeroPoint != nullptr) {
|
|
|
+ SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
|
|
|
+ C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
|
|
|
+ } else {
|
|
|
+ SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
|
|
|
+ QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
|
|
|
+ ldc, ScaleStride);
|
|
|
+ }
|
|
|
+ return 4;
|
|
|
+ } else {
|
|
|
+ if (QuantBZeroPoint != nullptr) {
|
|
|
+ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
|
|
|
+ C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
|
|
|
+ } else {
|
|
|
+ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
|
|
|
+ QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
|
|
|
+ ldc, ScaleStride);
|
|
|
+ }
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+}
|
|
|
+} // namespace ime1
|
|
|
+} // namespace sqnbitgemm_spacemit_ime