|
|
@@ -69,6 +69,10 @@
|
|
|
#define VECTOR_REGISTERS 16
|
|
|
#endif
|
|
|
|
|
|
+#if defined(__riscv_v_intrinsic)
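+// Register-group multiplier (LMUL) used to size the RVV kernels below: larger groups (LMUL 4) leave fewer usable register groups, so the micro-kernel tiles shrink accordingly.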
|
|
|
+#define LMUL 4
|
|
|
+#endif
|
|
|
+
|
|
|
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
|
|
|
|
|
namespace {
|
|
|
@@ -175,6 +179,46 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
|
|
|
}
|
|
|
#endif
|
|
|
|
|
|
+#if defined(__riscv_zvfh)
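+// Widening fused multiply-add: f16 inputs are multiplied and accumulated directly into f32 register groups of twice the LMUL (vfwmacc), so no separate conversion pass is needed.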
|
|
|
+template <>
|
|
|
+inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
|
|
|
+ return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
|
+}
|
|
|
+inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
|
|
|
+ return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
|
+}
|
|
|
+inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
|
|
|
+ return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
|
+}
|
|
|
+inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
|
|
|
+ return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
|
|
+}
|
|
|
+inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
|
|
|
+ return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
|
+}
|
|
|
+inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
|
|
|
+ return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
|
+}
|
|
|
+inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
|
|
|
+ return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
|
+}
|
|
|
+inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
|
|
|
+ return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
|
|
+}
|
|
|
+#endif
|
|
|
+
|
|
|
+#if defined(__riscv_zvfbfwma)
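+// Same widening pattern for bf16 inputs, using the Zvfbfwma vfwmaccbf16 instruction.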
|
|
|
+inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
|
|
|
+ return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
|
+}
|
|
|
+inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
|
|
|
+ return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
|
+}
|
|
|
+inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
|
|
|
+ return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
|
+}
|
|
|
+#endif
|
|
|
+
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
// VECTORIZED HORIZONTAL SUM
|
|
|
|
|
|
@@ -227,6 +271,25 @@ inline float hsum(__m512 x) {
|
|
|
}
|
|
|
#endif // __AVX512F__
|
|
|
|
|
|
+#if defined(__riscv_zvfh)
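+// Horizontal sum of an f32 register group: vfredusum reduces into a single-element f32m1 accumulator seeded with 0, then lane 0 is extracted as a scalar.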
|
|
|
+inline float hsum(vfloat32m1_t x) {
|
|
|
+ return __riscv_vfmv_f_s_f32m1_f32(
|
|
|
+ __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
|
|
|
+}
|
|
|
+inline float hsum(vfloat32m2_t x) {
|
|
|
+ return __riscv_vfmv_f_s_f32m1_f32(
|
|
|
+ __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2()));
|
|
|
+}
|
|
|
+inline float hsum(vfloat32m4_t x) {
|
|
|
+ return __riscv_vfmv_f_s_f32m1_f32(
|
|
|
+ __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4()));
|
|
|
+}
|
|
|
+inline float hsum(vfloat32m8_t x) {
|
|
|
+ return __riscv_vfmv_f_s_f32m1_f32(
|
|
|
+ __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8()));
|
|
|
+}
|
|
|
+#endif
|
|
|
+
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
// VECTORIZED MEMORY LOADING
|
|
|
|
|
|
@@ -315,6 +378,88 @@ template <> inline __m256bh load(const float *p) {
|
|
|
}
|
|
|
#endif
|
|
|
|
|
|
+#if defined(__riscv_zvfh)
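+// Each load fills a whole register group (VLMAX elements); ggml_fp16_t storage is reinterpreted as the native _Float16 element type.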
|
|
|
+template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
|
|
|
+ return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
|
|
|
+}
|
|
|
+template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
|
|
|
+ return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
|
|
|
+}
|
|
|
+template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
|
|
|
+ return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
|
|
|
+}
|
|
|
+template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
|
|
|
+ return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
|
|
|
+}
|
|
|
+template <> inline vfloat32m1_t load(const float *p) {
|
|
|
+ return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
|
|
|
+}
|
|
|
+template <> inline vfloat32m2_t load(const float *p) {
|
|
|
+ return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
|
|
|
+}
|
|
|
+template <> inline vfloat32m4_t load(const float *p) {
|
|
|
+ return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
|
|
|
+}
|
|
|
+template <> inline vfloat32m8_t load(const float *p) {
|
|
|
+ return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
|
|
|
+}
|
|
|
+#endif
|
|
|
+
|
|
|
+#if defined(__riscv_zvfbfwma)
|
|
|
+template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
|
|
|
+ return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
|
|
|
+}
|
|
|
+template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
|
|
|
+ return __riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m1());
|
|
|
+}
|
|
|
+template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
|
|
|
+ return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
|
|
|
+}
|
|
|
+#endif
|
|
|
+
|
|
|
+#if defined(__riscv_zvfh)
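+// Splat 0 across a full register group; used to initialize the C accumulators of the kernels below.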
|
|
|
+template <typename T> T set_zero();
|
|
|
+
|
|
|
+template <> inline vfloat16mf2_t set_zero() {
|
|
|
+ return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
|
|
|
+}
|
|
|
+template <> inline vfloat16m1_t set_zero() {
|
|
|
+ return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
|
|
|
+}
|
|
|
+template <> inline vfloat16m2_t set_zero() {
|
|
|
+ return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
|
|
|
+}
|
|
|
+template <> inline vfloat16m4_t set_zero() {
|
|
|
+ return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
|
|
|
+}
|
|
|
+template <> inline vfloat32m1_t set_zero() {
|
|
|
+ return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
|
|
|
+}
|
|
|
+template <> inline vfloat32m2_t set_zero() {
|
|
|
+ return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2());
|
|
|
+}
|
|
|
+template <> inline vfloat32m4_t set_zero() {
|
|
|
+ return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4());
|
|
|
+}
|
|
|
+template <> inline vfloat32m8_t set_zero() {
|
|
|
+ return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8());
|
|
|
+}
|
|
|
+#endif
|
|
|
+
|
|
|
+#if defined(__riscv_v_intrinsic)
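+// Number of elements one register group of type T holds (depends on VLEN and LMUL); used as the k-loop stride of the kernels.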
|
|
|
+template <typename T> size_t vlmax() {
|
|
|
+ if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
|
|
|
+ else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
|
|
|
+ else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
|
|
|
+ else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
|
|
|
+ else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
|
|
|
+ else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
|
|
|
+ else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
|
|
|
+ else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+#endif
|
|
|
+
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
// FLOATING POINT MATRIX MULTIPLICATION
|
|
|
|
|
|
@@ -488,6 +633,573 @@ class tinyBLAS {
|
|
|
const int64_t ldc;
|
|
|
};
|
|
|
|
|
|
+#if defined(__riscv_v_intrinsic)
|
|
|
+template <typename D, typename V, typename TA, typename TB, typename TC>
|
|
|
+class tinyBLAS_RVV {
|
|
|
+ public:
|
|
|
+ tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
|
|
|
+ const TA *A, int64_t lda,
|
|
|
+ const TB *B, int64_t ldb,
|
|
|
+ TC *C, int64_t ldc)
|
|
|
+ : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
|
|
|
+ }
|
|
|
+
|
|
|
+ bool matmul(int64_t m, int64_t n) {
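+ // The kernels stride over k in whole register groups with no tail handling, so reject shapes where k is not a multiple of VLMAX.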
|
|
|
+ if (k % vlmax<V>() != 0) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
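+ // Each LMUL branch picks an RM x RN micro-kernel and a BM row blocking sized so that the C accumulators plus the staged A/B vectors stay within the 32 vector registers.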
|
|
|
+#if LMUL == 1
|
|
|
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
|
|
|
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
|
|
+ mnpack<4, 6, 4>(m, n, SIZE_N, 12);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (m % 8 == 0) {
|
|
|
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
|
|
+ mnpack<4, 6, 2>(m, n, SIZE_N, 12);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (m % 4 == 0) {
|
|
|
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
|
|
+ mnpack<4, 6, 1>(m, n, SIZE_N, 12);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+#elif LMUL == 2
|
|
|
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
|
|
|
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
|
|
+ mnpack<4, 3, 4>(m, n, SIZE_N, 24);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (m % 8 == 0) {
|
|
|
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
|
|
+ mnpack<4, 3, 2>(m, n, SIZE_N, 24);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (m % 4 == 0) {
|
|
|
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
|
|
+ mnpack<4, 3, 1>(m, n, SIZE_N, 24);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+#else // LMUL = 4
|
|
|
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
|
|
|
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
|
|
|
+ mnpack<2, 2, 8>(m, n, SIZE_N, 36);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (m % 8 == 0) {
|
|
|
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
|
|
|
+ mnpack<2, 2, 4>(m, n, SIZE_N, 36);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (m % 4 == 0) {
|
|
|
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
|
|
|
+ mnpack<2, 2, 2>(m, n, SIZE_N, 36);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+#endif
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ private:
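+ // mnpack recursively peels RN down to the requested SIZE_N and then instantiates the matching gemm kernel.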
|
|
|
+ template<int RM, int RN, int BM>
|
|
|
+ inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
|
|
|
+ if (SIZE_N == RN) {
|
|
|
+ return gemm<RM, RN, BM>(m, n, BN);
|
|
|
+ }
|
|
|
+ if constexpr (RN > 1) {
|
|
|
+ return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
|
|
|
+ } else {
|
|
|
+ GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
|
|
|
+ GGML_ASSERT(false); // we have missed something.
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
|
|
|
+ size_t vl = vlmax<V>();
|
|
|
+ D Cv00 = set_zero<D>();
|
|
|
+ D Cv01 = set_zero<D>();
|
|
|
+ D Cv02 = set_zero<D>();
|
|
|
+ D Cv03 = set_zero<D>();
|
|
|
+ D Cv10 = set_zero<D>();
|
|
|
+ D Cv11 = set_zero<D>();
|
|
|
+ D Cv12 = set_zero<D>();
|
|
|
+ D Cv13 = set_zero<D>();
|
|
|
+ D Cv20 = set_zero<D>();
|
|
|
+ D Cv21 = set_zero<D>();
|
|
|
+ D Cv22 = set_zero<D>();
|
|
|
+ D Cv23 = set_zero<D>();
|
|
|
+ D Cv30 = set_zero<D>();
|
|
|
+ D Cv31 = set_zero<D>();
|
|
|
+ D Cv32 = set_zero<D>();
|
|
|
+ D Cv33 = set_zero<D>();
|
|
|
+ D Cv40 = set_zero<D>();
|
|
|
+ D Cv41 = set_zero<D>();
|
|
|
+ D Cv42 = set_zero<D>();
|
|
|
+ D Cv43 = set_zero<D>();
|
|
|
+ D Cv50 = set_zero<D>();
|
|
|
+ D Cv51 = set_zero<D>();
|
|
|
+ D Cv52 = set_zero<D>();
|
|
|
+ D Cv53 = set_zero<D>();
|
|
|
+
|
|
|
+ for (int64_t l = 0; l < k; l += vl) {
|
|
|
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
|
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
|
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
|
|
|
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
|
|
|
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
|
|
|
+ V Bv5 = load<V>(B + ldb * (jj + 5) + l);
|
|
|
+
|
|
|
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
|
+ Cv00 = madd(Av0, Bv0, Cv00);
|
|
|
+ Cv10 = madd(Av0, Bv1, Cv10);
|
|
|
+ Cv20 = madd(Av0, Bv2, Cv20);
|
|
|
+ Cv30 = madd(Av0, Bv3, Cv30);
|
|
|
+ Cv40 = madd(Av0, Bv4, Cv40);
|
|
|
+ Cv50 = madd(Av0, Bv5, Cv50);
|
|
|
+
|
|
|
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
|
+ Cv01 = madd(Av1, Bv0, Cv01);
|
|
|
+ Cv11 = madd(Av1, Bv1, Cv11);
|
|
|
+ Cv21 = madd(Av1, Bv2, Cv21);
|
|
|
+ Cv31 = madd(Av1, Bv3, Cv31);
|
|
|
+ Cv41 = madd(Av1, Bv4, Cv41);
|
|
|
+ Cv51 = madd(Av1, Bv5, Cv51);
|
|
|
+
|
|
|
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
|
+ Cv02 = madd(Av2, Bv0, Cv02);
|
|
|
+ Cv12 = madd(Av2, Bv1, Cv12);
|
|
|
+ Cv22 = madd(Av2, Bv2, Cv22);
|
|
|
+ Cv32 = madd(Av2, Bv3, Cv32);
|
|
|
+ Cv42 = madd(Av2, Bv4, Cv42);
|
|
|
+ Cv52 = madd(Av2, Bv5, Cv52);
|
|
|
+
|
|
|
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
|
+ Cv03 = madd(Av3, Bv0, Cv03);
|
|
|
+ Cv13 = madd(Av3, Bv1, Cv13);
|
|
|
+ Cv23 = madd(Av3, Bv2, Cv23);
|
|
|
+ Cv33 = madd(Av3, Bv3, Cv33);
|
|
|
+ Cv43 = madd(Av3, Bv4, Cv43);
|
|
|
+ Cv53 = madd(Av3, Bv5, Cv53);
|
|
|
+ }
|
|
|
+
|
|
|
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
|
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
|
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
|
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
|
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
|
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
|
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
|
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
|
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
|
|
|
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
|
|
|
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
|
|
|
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
|
|
|
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
|
|
|
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
|
|
|
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
|
|
|
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
|
|
|
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
|
|
|
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
|
|
|
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
|
|
|
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
|
|
|
+ C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
|
|
|
+ C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
|
|
|
+ C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
|
|
|
+ C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
|
|
|
+ }
|
|
|
+
|
|
|
+ inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
|
|
|
+ size_t vl = vlmax<V>();
|
|
|
+ D Cv00 = set_zero<D>();
|
|
|
+ D Cv01 = set_zero<D>();
|
|
|
+ D Cv02 = set_zero<D>();
|
|
|
+ D Cv03 = set_zero<D>();
|
|
|
+ D Cv10 = set_zero<D>();
|
|
|
+ D Cv11 = set_zero<D>();
|
|
|
+ D Cv12 = set_zero<D>();
|
|
|
+ D Cv13 = set_zero<D>();
|
|
|
+ D Cv20 = set_zero<D>();
|
|
|
+ D Cv21 = set_zero<D>();
|
|
|
+ D Cv22 = set_zero<D>();
|
|
|
+ D Cv23 = set_zero<D>();
|
|
|
+ D Cv30 = set_zero<D>();
|
|
|
+ D Cv31 = set_zero<D>();
|
|
|
+ D Cv32 = set_zero<D>();
|
|
|
+ D Cv33 = set_zero<D>();
|
|
|
+ D Cv40 = set_zero<D>();
|
|
|
+ D Cv41 = set_zero<D>();
|
|
|
+ D Cv42 = set_zero<D>();
|
|
|
+ D Cv43 = set_zero<D>();
|
|
|
+
|
|
|
+ for (int64_t l = 0; l < k; l += vl) {
|
|
|
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
|
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
|
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
|
|
|
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
|
|
|
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
|
|
|
+
|
|
|
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
|
+ Cv00 = madd(Av0, Bv0, Cv00);
|
|
|
+ Cv10 = madd(Av0, Bv1, Cv10);
|
|
|
+ Cv20 = madd(Av0, Bv2, Cv20);
|
|
|
+ Cv30 = madd(Av0, Bv3, Cv30);
|
|
|
+ Cv40 = madd(Av0, Bv4, Cv40);
|
|
|
+
|
|
|
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
|
+ Cv01 = madd(Av1, Bv0, Cv01);
|
|
|
+ Cv11 = madd(Av1, Bv1, Cv11);
|
|
|
+ Cv21 = madd(Av1, Bv2, Cv21);
|
|
|
+ Cv31 = madd(Av1, Bv3, Cv31);
|
|
|
+ Cv41 = madd(Av1, Bv4, Cv41);
|
|
|
+
|
|
|
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
|
+ Cv02 = madd(Av2, Bv0, Cv02);
|
|
|
+ Cv12 = madd(Av2, Bv1, Cv12);
|
|
|
+ Cv22 = madd(Av2, Bv2, Cv22);
|
|
|
+ Cv32 = madd(Av2, Bv3, Cv32);
|
|
|
+ Cv42 = madd(Av2, Bv4, Cv42);
|
|
|
+
|
|
|
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
|
+ Cv03 = madd(Av3, Bv0, Cv03);
|
|
|
+ Cv13 = madd(Av3, Bv1, Cv13);
|
|
|
+ Cv23 = madd(Av3, Bv2, Cv23);
|
|
|
+ Cv33 = madd(Av3, Bv3, Cv33);
|
|
|
+ Cv43 = madd(Av3, Bv4, Cv43);
|
|
|
+ }
|
|
|
+
|
|
|
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
|
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
|
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
|
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
|
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
|
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
|
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
|
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
|
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
|
|
|
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
|
|
|
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
|
|
|
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
|
|
|
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
|
|
|
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
|
|
|
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
|
|
|
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
|
|
|
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
|
|
|
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
|
|
|
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
|
|
|
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
|
|
|
+ }
|
|
|
+
|
|
|
+ inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
|
|
|
+ size_t vl = vlmax<V>();
|
|
|
+ D Cv00 = set_zero<D>();
|
|
|
+ D Cv01 = set_zero<D>();
|
|
|
+ D Cv02 = set_zero<D>();
|
|
|
+ D Cv03 = set_zero<D>();
|
|
|
+ D Cv10 = set_zero<D>();
|
|
|
+ D Cv11 = set_zero<D>();
|
|
|
+ D Cv12 = set_zero<D>();
|
|
|
+ D Cv13 = set_zero<D>();
|
|
|
+ D Cv20 = set_zero<D>();
|
|
|
+ D Cv21 = set_zero<D>();
|
|
|
+ D Cv22 = set_zero<D>();
|
|
|
+ D Cv23 = set_zero<D>();
|
|
|
+ D Cv30 = set_zero<D>();
|
|
|
+ D Cv31 = set_zero<D>();
|
|
|
+ D Cv32 = set_zero<D>();
|
|
|
+ D Cv33 = set_zero<D>();
|
|
|
+
|
|
|
+ for (int64_t l = 0; l < k; l += vl) {
|
|
|
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
|
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
|
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
|
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
|
+
|
|
|
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
|
+ Cv00 = madd(Av0, Bv0, Cv00);
|
|
|
+ Cv01 = madd(Av1, Bv0, Cv01);
|
|
|
+ Cv02 = madd(Av2, Bv0, Cv02);
|
|
|
+ Cv03 = madd(Av3, Bv0, Cv03);
|
|
|
+
|
|
|
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
|
+ Cv10 = madd(Av0, Bv1, Cv10);
|
|
|
+ Cv11 = madd(Av1, Bv1, Cv11);
|
|
|
+ Cv12 = madd(Av2, Bv1, Cv12);
|
|
|
+ Cv13 = madd(Av3, Bv1, Cv13);
|
|
|
+
|
|
|
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
|
|
|
+ Cv20 = madd(Av0, Bv2, Cv20);
|
|
|
+ Cv21 = madd(Av1, Bv2, Cv21);
|
|
|
+ Cv22 = madd(Av2, Bv2, Cv22);
|
|
|
+ Cv23 = madd(Av3, Bv2, Cv23);
|
|
|
+
|
|
|
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
|
|
|
+ Cv30 = madd(Av0, Bv3, Cv30);
|
|
|
+ Cv31 = madd(Av1, Bv3, Cv31);
|
|
|
+ Cv32 = madd(Av2, Bv3, Cv32);
|
|
|
+ Cv33 = madd(Av3, Bv3, Cv33);
|
|
|
+ }
|
|
|
+
|
|
|
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
|
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
|
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
|
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
|
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
|
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
|
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
|
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
|
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
|
|
|
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
|
|
|
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
|
|
|
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
|
|
|
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
|
|
|
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
|
|
|
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
|
|
|
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
|
|
|
+ }
|
|
|
+
|
|
|
+ inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
|
|
|
+ size_t vl = vlmax<V>();
|
|
|
+ D Cv00 = set_zero<D>();
|
|
|
+ D Cv01 = set_zero<D>();
|
|
|
+ D Cv02 = set_zero<D>();
|
|
|
+ D Cv03 = set_zero<D>();
|
|
|
+ D Cv10 = set_zero<D>();
|
|
|
+ D Cv11 = set_zero<D>();
|
|
|
+ D Cv12 = set_zero<D>();
|
|
|
+ D Cv13 = set_zero<D>();
|
|
|
+ D Cv20 = set_zero<D>();
|
|
|
+ D Cv21 = set_zero<D>();
|
|
|
+ D Cv22 = set_zero<D>();
|
|
|
+ D Cv23 = set_zero<D>();
|
|
|
+
|
|
|
+ for (int64_t l = 0; l < k; l += vl) {
|
|
|
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
|
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
|
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
|
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
|
+
|
|
|
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
|
+ Cv00 = madd(Av0, Bv0, Cv00);
|
|
|
+ Cv01 = madd(Av1, Bv0, Cv01);
|
|
|
+ Cv02 = madd(Av2, Bv0, Cv02);
|
|
|
+ Cv03 = madd(Av3, Bv0, Cv03);
|
|
|
+
|
|
|
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
|
+ Cv10 = madd(Av0, Bv1, Cv10);
|
|
|
+ Cv11 = madd(Av1, Bv1, Cv11);
|
|
|
+ Cv12 = madd(Av2, Bv1, Cv12);
|
|
|
+ Cv13 = madd(Av3, Bv1, Cv13);
|
|
|
+
|
|
|
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
|
|
|
+ Cv20 = madd(Av0, Bv2, Cv20);
|
|
|
+ Cv21 = madd(Av1, Bv2, Cv21);
|
|
|
+ Cv22 = madd(Av2, Bv2, Cv22);
|
|
|
+ Cv23 = madd(Av3, Bv2, Cv23);
|
|
|
+ }
|
|
|
+
|
|
|
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
|
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
|
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
|
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
|
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
|
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
|
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
|
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
|
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
|
|
|
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
|
|
|
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
|
|
|
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
|
|
|
+ }
|
|
|
+
|
|
|
+ inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
|
|
|
+ size_t vl = vlmax<V>();
|
|
|
+ D Cv00 = set_zero<D>();
|
|
|
+ D Cv01 = set_zero<D>();
|
|
|
+ D Cv02 = set_zero<D>();
|
|
|
+ D Cv03 = set_zero<D>();
|
|
|
+ D Cv10 = set_zero<D>();
|
|
|
+ D Cv11 = set_zero<D>();
|
|
|
+ D Cv12 = set_zero<D>();
|
|
|
+ D Cv13 = set_zero<D>();
|
|
|
+
|
|
|
+ for (int64_t l = 0; l < k; l += vl) {
|
|
|
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
|
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
|
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
|
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
|
+
|
|
|
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
|
+ Cv00 = madd(Av0, Bv0, Cv00);
|
|
|
+ Cv01 = madd(Av1, Bv0, Cv01);
|
|
|
+ Cv02 = madd(Av2, Bv0, Cv02);
|
|
|
+ Cv03 = madd(Av3, Bv0, Cv03);
|
|
|
+
|
|
|
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
|
+ Cv10 = madd(Av0, Bv1, Cv10);
|
|
|
+ Cv11 = madd(Av1, Bv1, Cv11);
|
|
|
+ Cv12 = madd(Av2, Bv1, Cv12);
|
|
|
+ Cv13 = madd(Av3, Bv1, Cv13);
|
|
|
+ }
|
|
|
+
|
|
|
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
|
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
|
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
|
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
|
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
|
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
|
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
|
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
|
+ }
|
|
|
+
|
|
|
+ inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
|
|
|
+ size_t vl = vlmax<V>();
|
|
|
+ D Cv00 = set_zero<D>();
|
|
|
+ D Cv01 = set_zero<D>();
|
|
|
+ D Cv02 = set_zero<D>();
|
|
|
+ D Cv03 = set_zero<D>();
|
|
|
+
|
|
|
+ for (int64_t l = 0; l < k; l += vl) {
|
|
|
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
|
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
|
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
|
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
|
+
|
|
|
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
|
+ Cv00 = madd(Av0, Bv0, Cv00);
|
|
|
+ Cv01 = madd(Av1, Bv0, Cv01);
|
|
|
+ Cv02 = madd(Av2, Bv0, Cv02);
|
|
|
+ Cv03 = madd(Av3, Bv0, Cv03);
|
|
|
+ }
|
|
|
+
|
|
|
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
|
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
|
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
|
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
|
+ }
|
|
|
+
|
|
|
+ inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
|
|
|
+ size_t vl = vlmax<V>();
|
|
|
+ D Cv00 = set_zero<D>();
|
|
|
+ D Cv01 = set_zero<D>();
|
|
|
+ D Cv10 = set_zero<D>();
|
|
|
+ D Cv11 = set_zero<D>();
|
|
|
+
|
|
|
+ for (int64_t l = 0; l < k; l += vl) {
|
|
|
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
|
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
|
+
|
|
|
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
|
+ Cv00 = madd(Av0, Bv0, Cv00);
|
|
|
+ Cv01 = madd(Av1, Bv0, Cv01);
|
|
|
+
|
|
|
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
|
+ Cv10 = madd(Av0, Bv1, Cv10);
|
|
|
+ Cv11 = madd(Av1, Bv1, Cv11);
|
|
|
+ }
|
|
|
+
|
|
|
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
|
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
|
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
|
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
|
+ }
|
|
|
+
|
|
|
+ inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
|
|
|
+ size_t vl = vlmax<V>();
|
|
|
+ D Cv00 = set_zero<D>();
|
|
|
+ D Cv01 = set_zero<D>();
|
|
|
+
|
|
|
+ for (int64_t l = 0; l < k; l += vl) {
|
|
|
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
|
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
|
+
|
|
|
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
|
+ Cv00 = madd(Av0, Bv0, Cv00);
|
|
|
+ Cv01 = madd(Av1, Bv0, Cv01);
|
|
|
+ }
|
|
|
+
|
|
|
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
|
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
|
+ }
|
|
|
+
|
|
|
+ template <int RM, int RN>
|
|
|
+ inline void gemm_bloc(int64_t ii, int64_t jj) {
|
|
|
+ if constexpr (RM == 4) {
|
|
|
+ if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
|
|
|
+ if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
|
|
|
+ if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
|
|
|
+ if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
|
|
|
+ if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
|
|
|
+ if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
|
|
|
+ } else if constexpr (RM == 2) {
|
|
|
+ if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
|
|
|
+ if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ template <int RM, int RN, int BM>
|
|
|
+ NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
|
|
|
+ GGML_ASSERT(m % (RM * BM) == 0);
|
|
|
+ const int64_t ytiles = m / (RM * BM);
|
|
|
+ const int64_t xtiles = (n + RN -1) / RN;
|
|
|
+ const int64_t jj_RN = (xtiles - (xtiles * RN - n));
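+ // The first jj_RN column tiles are RN columns wide; the remaining tiles fall back to RN-1 columns so that every column of C is covered exactly once.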
|
|
|
+
|
|
|
+ // "round" the block size to the "nearest" BN
|
|
|
+ const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
|
|
|
+ const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
|
|
|
+ const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
|
|
|
+ const int64_t nb_job = ytiles * NB_BN;
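+ // One job is one RM*BM row band crossed with one block of column tiles; jobs are handed out through the threadpool chunk counter below.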
|
|
|
+
|
|
|
+ if (params->ith == 0) {
|
|
|
+ GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
|
|
|
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
|
|
|
+ ggml_threadpool_chunk_set(params->threadpool, params->nth);
|
|
|
+ }
|
|
|
+
|
|
|
+ ggml_barrier(params->threadpool);
|
|
|
+
|
|
|
+ int64_t job = params->ith;
|
|
|
+ while (job < nb_job) {
|
|
|
+ const int64_t ii = (job % ytiles) * RM * BM;
|
|
|
+ const int64_t jb = job / ytiles;
|
|
|
+ const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
|
|
|
+ const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
|
|
|
+
|
|
|
+ const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
|
|
|
+ const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
|
|
|
+ const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
|
|
|
+
|
|
|
+ for (int64_t bi = 0; bi < BM * RM; bi += RM) {
|
|
|
+ int64_t jj = jj0;
|
|
|
+ for (; jj < jj1; jj += RN) {
|
|
|
+ gemm_bloc<RM, RN>(ii + bi, jj);
|
|
|
+ }
|
|
|
+ if constexpr (RN > 1) {
|
|
|
+ for (; jj < jj2; jj += RN - 1) {
|
|
|
+ gemm_bloc<RM, RN-1>(ii + bi, jj);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ GGML_ASSERT(jj == jj2);
|
|
|
+ }
|
|
|
+
|
|
|
+ job = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
|
+ }
|
|
|
+
|
|
|
+ ggml_barrier(params->threadpool);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const ggml_compute_params * params;
|
|
|
+ const TA *const A;
|
|
|
+ const TB *const B;
|
|
|
+ TC *const C;
|
|
|
+ const int64_t k;
|
|
|
+ const int64_t lda;
|
|
|
+ const int64_t ldb;
|
|
|
+ const int64_t ldc;
|
|
|
+};
|
|
|
+#endif
|
|
|
+
|
|
|
//////////////////////////////////////////////////////////////////////////////////////////
|
|
|
// QUANT ZERO MATRIX MULTIPLICATION
|
|
|
|
|
|
@@ -2657,6 +3369,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
params->ith, params->nth};
|
|
|
tb.matmul(m, n);
|
|
|
return true;
|
|
|
+#elif defined(__riscv_zvfh)
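+ // FP32 RVV path: the vector/accumulator types follow the configured LMUL; matmul() returns false for shapes the kernels cannot handle so the caller falls back to the generic code.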
|
|
|
+ #if LMUL == 1
|
|
|
+ tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
|
|
|
+ k, (const float *)A, lda,
|
|
|
+ (const float *)B, ldb,
|
|
|
+ (float *)C, ldc};
|
|
|
+ #elif LMUL == 2
|
|
|
+ tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
|
|
|
+ k, (const float *)A, lda,
|
|
|
+ (const float *)B, ldb,
|
|
|
+ (float *)C, ldc};
|
|
|
+ #else // LMUL = 4
|
|
|
+ tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
|
|
|
+ k, (const float *)A, lda,
|
|
|
+ (const float *)B, ldb,
|
|
|
+ (float *)C, ldc};
|
|
|
+ #endif
|
|
|
+ return tb.matmul(m, n);
|
|
|
#else
|
|
|
return false;
|
|
|
#endif
|
|
|
@@ -2699,6 +3429,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
tb.matmul(m, n);
|
|
|
return true;
|
|
|
}
|
|
|
+#elif defined(__riscv_zvfbfwma)
|
|
|
+ #if LMUL == 1
|
|
|
+ tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
|
+ k, (const ggml_bf16_t *)A, lda,
|
|
|
+ (const ggml_bf16_t *)B, ldb,
|
|
|
+ (float *)C, ldc};
|
|
|
+ #elif LMUL == 2
|
|
|
+ tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
|
+ k, (const ggml_bf16_t *)A, lda,
|
|
|
+ (const ggml_bf16_t *)B, ldb,
|
|
|
+ (float *)C, ldc};
|
|
|
+ #else // LMUL = 4
|
|
|
+ tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
|
+ k, (const ggml_bf16_t *)A, lda,
|
|
|
+ (const ggml_bf16_t *)B, ldb,
|
|
|
+ (float *)C, ldc};
|
|
|
+ #endif
|
|
|
+ return tb.matmul(m, n);
|
|
|
#endif
|
|
|
return false;
|
|
|
}
|
|
|
@@ -2748,6 +3496,26 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
(float *)C, ldc};
|
|
|
return tb.matmul(m, n);
|
|
|
}
|
|
|
+#elif defined(__riscv_zvfh)
|
|
|
+ if (Btype == GGML_TYPE_F16) {
|
|
|
+ #if LMUL == 1
|
|
|
+ tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
|
|
|
+ k, (const ggml_fp16_t *)A, lda,
|
|
|
+ (const ggml_fp16_t *)B, ldb,
|
|
|
+ (float *)C, ldc};
|
|
|
+ #elif LMUL == 2
|
|
|
+ tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
|
|
|
+ k, (const ggml_fp16_t *)A, lda,
|
|
|
+ (const ggml_fp16_t *)B, ldb,
|
|
|
+ (float *)C, ldc};
|
|
|
+ #else // LMUL = 4
|
|
|
+ tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
|
|
|
+ k, (const ggml_fp16_t *)A, lda,
|
|
|
+ (const ggml_fp16_t *)B, ldb,
|
|
|
+ (float *)C, ldc};
|
|
|
+ #endif
|
|
|
+ return tb.matmul(m, n);
|
|
|
+ }
|
|
|
#endif
|
|
|
return false;
|
|
|
}
|