@@ -53,6 +53,8 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-quants.h"
 
+#include <atomic>
+
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
 #else
@@ -134,6 +136,16 @@ inline __m512 madd(__m512 a, __m512 b, __m512 c) {
     return _mm512_fmadd_ps(a, b, c);
 }
 #endif
+#if defined(__AVX512BF16__)
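+// bf16 fused multiply-add: each fp32 lane of the accumulator c accumulates the
+// products of two adjacent bf16 pairs from a and b (AVX512-BF16 vdpbf16ps).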
+template <>
+inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
+    return _mm512_dpbf16_ps(c, a, b);
+}
+template <>
+inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
+    return _mm256_dpbf16_ps(c, a, b);
+}
+#endif
 #endif
 
 #if defined(__ARM_FEATURE_FMA)
@@ -226,6 +238,13 @@ template <> inline __m256 load(const float *p) {
 }
 #endif // __AVX__
 
+#if defined(__AVX2__) || defined(__AVX512F__)
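+// Widening bf16 -> fp32 load: a bf16 value is the upper 16 bits of an fp32, so
+// zero-extending each element and shifting it left by 16 reconstructs the float exactly.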
+template <> inline __m256 load(const ggml_bf16_t *p) {
+    return _mm256_castsi256_ps(
+        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
+}
+#endif // __AVX2__
+
 #if defined(__F16C__)
 template <> inline __m256 load(const ggml_fp16_t *p) {
     return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
@@ -239,8 +258,27 @@ template <> inline __m512 load(const float *p) {
 template <> inline __m512 load(const ggml_fp16_t *p) {
     return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
 }
+template <> inline __m512 load(const ggml_bf16_t *p) {
+    return _mm512_castsi512_ps(
+        _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
+}
 #endif // __AVX512F__
 
+#if defined(__AVX512BF16__)
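+// Native bf16 vector loads: bf16 source data is reinterpreted as-is, while fp32 source
+// data is narrowed to bf16 with round-to-nearest-even (cvtne2ps/cvtneps).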
+template <> inline __m512bh load(const ggml_bf16_t *p) {
+    return (__m512bh)_mm512_loadu_ps((const float *)p);
+}
+template <> inline __m256bh load(const ggml_bf16_t *p) {
+    return (__m256bh)_mm256_loadu_ps((const float *)p);
+}
+template <> inline __m512bh load(const float *p) {
+    return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
+}
+template <> inline __m256bh load(const float *p) {
+    return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // CONSTANTS
 
@@ -252,199 +290,170 @@ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
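+// Blocking helpers: BLOCK_SIZE<M>(m) splits m into ceil(m/M) blocs and returns the
+// size of the larger blocs; BLOC_POS(ib, ibN, bloc_size) returns the offset of bloc ib
+// when the first ibN blocs hold bloc_size elements and the remaining ones hold one less.
+// e.g. BLOCK_SIZE<6>(16) == 6, and 16 columns are then covered as 6 + 5 + 5.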
+template <int M>
+static inline int64_t BLOCK_SIZE(size_t m) {
+    const int64_t NB_BLOC_M = (m + M - 1) / M;
+    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
+}
+
+static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
+    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
+}
+
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
-    tinyBLAS(int64_t k,
+    tinyBLAS(const ggml_compute_params * params, int64_t k,
              const TA *A, int64_t lda,
              const TB *B, int64_t ldb,
-             TC *C, int64_t ldc,
-             int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+             TC *C, int64_t ldc)
+        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
     }
 
-    void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
+    bool matmul(int64_t m, int64_t n) {
+        if (k % KN != 0)
+            return false;
+        // choose the N block size (SIZE_N) so that only tiles of width RN and RN-1 are needed
+#if VECTOR_REGISTERS == 32
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
+            return true;
+        }
+        if (m % 8 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
+            return true;
+        }
+#else // VECTOR_REGISTERS == 16
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
+            return true;
+        }
+        if (m % 8 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
+            return true;
+        }
+#endif
+        return false;
     }
 
   private:
-    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
-#if VECTOR_REGISTERS == 32
-        case 0x55:
-            mc = 5;
-            nc = 5;
-            gemm<5, 5>(m0, m, n0, n);
-            break;
-        case 0x45:
-            mc = 4;
-            nc = 5;
-            gemm<4, 5>(m0, m, n0, n);
-            break;
-        case 0x54:
-            mc = 5;
-            nc = 4;
-            gemm<5, 4>(m0, m, n0, n);
-            break;
-        case 0x44:
-            mc = 4;
-            nc = 4;
-            gemm<4, 4>(m0, m, n0, n);
-            break;
-        case 0x53:
-            mc = 5;
-            nc = 3;
-            gemm<5, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-            mc = 3;
-            nc = 5;
-            gemm<3, 5>(m0, m, n0, n);
-            break;
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-#else
-        case 0x55:
-        case 0x54:
-        case 0x53:
-        case 0x45:
-        case 0x44:
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-#endif
-        case 0x34:
-            mc = 3;
-            nc = 4;
-            gemm<3, 4>(m0, m, n0, n);
-            break;
-        case 0x52:
-            mc = 5;
-            nc = 2;
-            gemm<5, 2>(m0, m, n0, n);
-            break;
-        case 0x33:
-            mc = 3;
-            nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
-            break;
-        case 0x25:
-            mc = 2;
-            nc = 5;
-            gemm<2, 5>(m0, m, n0, n);
-            break;
-        case 0x42:
-            mc = 4;
-            nc = 2;
-            gemm<4, 2>(m0, m, n0, n);
-            break;
-        case 0x24:
-            mc = 2;
-            nc = 4;
-            gemm<2, 4>(m0, m, n0, n);
-            break;
-        case 0x32:
-            mc = 3;
-            nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
-            break;
-        case 0x23:
-            mc = 2;
-            nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
-            break;
-        case 0x51:
-            mc = 5;
-            nc = 1;
-            gemm<5, 1>(m0, m, n0, n);
-            break;
-        case 0x41:
-            mc = 4;
-            nc = 1;
-            gemm<4, 1>(m0, m, n0, n);
-            break;
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x15:
-            mc = 1;
-            nc = 5;
-            gemm<1, 5>(m0, m, n0, n);
-            break;
-        case 0x14:
-            mc = 1;
-            nc = 4;
-            gemm<1, 4>(m0, m, n0, n);
-            break;
-        case 0x31:
-            mc = 3;
-            nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-            mc = 1;
-            nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
-            break;
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
+    template <int RM, int RN, int BM>
+    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
+        if (SIZE_N == RN) {
+            return gemm<RM, RN, BM>(m, n, BN);
+        }
+        if constexpr (RN > 1) {
+            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
+        } else {
+            GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
+            GGML_ASSERT(false); // we have missed something.
         }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
     }
 
     template <int RM, int RN>
-    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t ytiles = (m - m0) / RM;
-        int64_t xtiles = (n - n0) / RN;
-        int64_t tiles = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (int64_t job = start; job < end; ++job) {
-            int64_t ii = m0 + job / xtiles * RM;
-            int64_t jj = n0 + job % xtiles * RN;
-            D Cv[RN][RM] = {};
-            for (int64_t l = 0; l < k; l += KN)
-                for (int64_t j = 0; j < RN; ++j)
-                    for (int64_t i = 0; i < RM; ++i)
-                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
-                                        load<V>(B + ldb * (jj + j) + l),
-                                        Cv[j][i]);
-            for (int64_t j = 0; j < RN; ++j)
-                for (int64_t i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+    inline void gemm_bloc(int64_t ii, int64_t jj) {
+        D Cv[RN][RM] = {};
+        for (int64_t l = 0; l < k; l += KN) {
+            // help the compiler with operation ordering: load the smaller operand set once and reuse it.
+            if constexpr (RM <= RN) {
+                V Av[RM];
+                for (int64_t i = 0; i < RM; ++i) {
+                    Av[i] = load<V>(A + lda * (ii + i) + l);
+                }
+                for (int64_t j = 0; j < RN; ++j) {
+                    V Bv = load<V>(B + ldb * (jj + j) + l);
+                    for (int64_t i = 0; i < RM; ++i) {
+                        Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
+                    }
+                }
+            } else {
+                V Bv[RN];
+                for (int64_t j = 0; j < RN; ++j) {
+                    Bv[j] = load<V>(B + ldb * (jj + j) + l);
+                }
+                for (int64_t i = 0; i < RM; ++i) {
+                    V Av = load<V>(A + lda * (ii + i) + l);
+                    for (int64_t j = 0; j < RN; ++j) {
+                        Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
+                    }
+                }
+            }
         }
+        for (int64_t j = 0; j < RN; ++j)
+            for (int64_t i = 0; i < RM; ++i)
+                C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
     }
 
+    template <int RM, int RN, int BM>
+    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
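+        // Shared work counter for dynamic scheduling: one counter per template
+        // instantiation, reset by thread 0 below before the barrier releases the other threads.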
+        static std::atomic<int64_t> current_chunk;
+
+        GGML_ASSERT(m % (RM * BM) == 0);
+        const int64_t ytiles = m / (RM * BM);
+        const int64_t xtiles = (n + RN - 1) / RN;
+        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
+
+        // "round" bloc_size to "nearest" BN
+        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
+        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
+        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
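+        // e.g. xtiles = 100, BN = 12: NB_BN = 8 column blocs, SIZE_BN = 13, jj_BN = 4,
+        // i.e. 4 blocs of 13 tiles and 4 blocs of 12 tiles (4*13 + 4*12 = 100).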
+        const int64_t nb_job = ytiles * NB_BN;
+
+        if (params->ith == 0) {
+            GGML_ASSERT(jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+            std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        int64_t job = params->ith;
+        while (job < nb_job) {
+            const int64_t ii = (job % ytiles) * RM * BM;
+            const int64_t jb = job / ytiles;
+            const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
+            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
+
+            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
+            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
+            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
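+            // Within this bloc, columns [jj0, jj1) are covered by RN-wide tiles and
+            // columns [jj1, jj2) by (RN-1)-wide tiles.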
+
+            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
+                int64_t jj = jj0;
+                for (; jj < jj1; jj += RN) {
+                    gemm_bloc<RM, RN>(ii + bi, jj);
+                }
+                if constexpr (RN > 1) {
+                    for (; jj < jj2; jj += RN - 1) {
+                        gemm_bloc<RM, RN-1>(ii + bi, jj);
+                    }
+                }
+                GGML_ASSERT(jj == jj2);
+            }
+
+            // claim the next chunk of work.
+            job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
+        }
+
+        ggml_barrier(params->threadpool);
+        return;
+    }
+
+    const ggml_compute_params * params;
     const TA *const A;
     const TB *const B;
     TC *const C;
@@ -452,8 +461,6 @@ class tinyBLAS {
     const int64_t lda;
     const int64_t ldb;
    const int64_t ldc;
-    const int ith;
-    const int nth;
 };
 
 //////////////////////////////////////////////////////////////////////////////////////////
@@ -1657,8 +1664,9 @@ class tinyBLAS_PPC {
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
-bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int Atype, int Btype, int Ctype) {
 
     assert(m >= 0);
     assert(n >= 0);
@@ -1666,8 +1674,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     assert(lda >= k);
     assert(ldb >= k);
     assert(ldc >= m);
-    assert(nth > 0);
-    assert(ith < nth);
+    assert(params->nth > 0);
+    assert(params->ith < params->nth);
 
     // only enable sgemm for prompt processing
     if (n < 2)
@@ -1682,37 +1690,25 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         if (Btype != GGML_TYPE_F32)
             return false;
 #if defined(__AVX512F__)
-        if (k % 16)
-            return false;
-        tinyBLAS<16, __m512, __m512, float, float, float> tb{
+        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__AVX__) || defined(__AVX2__)
-        if (k % 8)
-            return false;
-        tinyBLAS<8, __m256, __m256, float, float, float> tb{
+        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__ARM_NEON)
         if (n < 4)
             return false;
-        if (k % 4)
-            return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__MMA__)
         if (k % 8)
             return false;
@@ -1720,7 +1716,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const float *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1728,60 +1724,71 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
 #endif
     }
 
+    case GGML_TYPE_BF16: {
+#if defined(__AVX512BF16__)
+        if (Btype == GGML_TYPE_BF16) {
+            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
+#elif defined(__AVX512F__)
+        if (Btype == GGML_TYPE_BF16) {
+            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
+#elif defined(__AVX2__)
+        if (Btype == GGML_TYPE_BF16) {
+            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
+#endif
+        return false;
+    }
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
-        if (k % 16)
-            return false;
-        if (Btype != GGML_TYPE_F32)
-            return false;
-        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+        if (Btype == GGML_TYPE_F16) {
+            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
+                (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
-        if (k % 8)
-            return false;
-        if (Btype != GGML_TYPE_F32)
-            return false;
-        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+        if (Btype == GGML_TYPE_F16) {
+            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
+                (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
         if (n < 8)
             return false;
-        if (k % 8)
-            return false;
-        if (Btype != GGML_TYPE_F16)
-            return false;
-        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const ggml_fp16_t *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+        if (Btype == GGML_TYPE_F16) {
+            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+                k, (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
-        if (k % 4)
-            return false;
-        if (Btype != GGML_TYPE_F32)
-            return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
+        if (Btype == GGML_TYPE_F32) {
+            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
+                k, (const ggml_fp16_t *)A, lda,
+                (const float *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #endif
+        return false;
     }
 
     case GGML_TYPE_Q8_0: {
@@ -1792,7 +1799,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1800,7 +1807,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1816,7 +1823,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1824,7 +1831,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1840,7 +1847,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q5_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1856,7 +1863,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_iq4_nl *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1868,6 +1875,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         return false;
     }
 
+    (void)params;
     (void)m;
     (void)n;
     (void)k;
@@ -1877,8 +1885,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     (void)ldb;
     (void)C;
     (void)ldc;
-    (void)ith;
-    (void)nth;
     (void)Atype;
     (void)Btype;
     (void)Ctype;