// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//
//                                _   _          ___ _      _   ___
//                               | |_(_)_ _ _  _| _ ) |    /_\ / __|
//                               |  _| | ' \ || | _ \ |__ / _ \\__ \.
//                                \__|_|_||_\_, |___/____/_/ \_\___/
//                                          |__/
//
//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
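//
// In concrete terms, every output element is a dot product of a row of
// A with a row of B, i.e. for l = 0…k-1
//
//     C[ldc*j + i] = Σ A[lda*i + l] * B[ldb*j + l]
//
// which is what the inner loops of tinyBLAS::gemm() below compute.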
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal's to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"

#include "sgemm.h"
#include "ggml-impl.h"
#include "ggml-quants.h"

#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
#else
#define NOINLINE __attribute__((__noinline__))
#endif

#if defined(__ARM_NEON) || defined(__AVX512F__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
#endif

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
namespace {

inline float unhalf(ggml_fp16_t d) {
    return GGML_FP16_TO_FP32(d);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED ARITHMETIC OPERATIONS

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
#endif // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
#endif // __AVX__

#if defined(__AVX512F__)
inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
#endif // __AVX512F__

#if defined(__ARM_NEON)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD

/**
 * Computes a * b + c.
 */
template <typename T, typename U>
inline U madd(T a, T b, U c) {
    return add(mul(a, b), c);
}

#if defined(__FMA__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512F__)
template <>
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
    return _mm512_fmadd_ps(a, b, c);
}
#endif
#endif

#if defined(__ARM_FEATURE_FMA)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
    return vfmaq_f32(c, b, a);
}
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
template <>
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
    return vfmaq_f16(c, b, a);
}
#endif
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM

#if defined(__ARM_NEON)
inline float hsum(float32x4_t x) {
    return vaddvq_f32(x);
}
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
inline float hsum(float16x8_t x) {
    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
                                vcvt_f32_f16(vget_high_f16(x))));
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
    x = _mm_add_ss(x, _mm_movehdup_ps(x));
#else
    __m128 t;
    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
    x = _mm_add_ps(x, t);
    t = _mm_movehl_ps(t, x);
    x = _mm_add_ss(x, t);
#endif
    return _mm_cvtss_f32(x);
}
#endif

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m256 x) {
    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
                           _mm256_castps256_ps128(x)));
}
#endif // __AVX__

#if defined(__AVX512F__)
inline float hsum(__m512 x) {
    return _mm512_reduce_add_ps(x);
}
#endif // __AVX512F__

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING

template <typename T, typename U> T load(const U *);

#if defined(__ARM_NEON)
template <> inline float32x4_t load(const float *p) {
    return vld1q_f32(p);
}
#if !defined(_MSC_VER)
template <> inline float16x8_t load(const ggml_fp16_t *p) {
    return vld1q_f16((const float16_t *)p);
}
template <> inline float32x4_t load(const ggml_fp16_t *p) {
    return vcvt_f32_f16(vld1_f16((const float16_t *)p));
}
#endif // _MSC_VER
#endif // __ARM_NEON

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) {
    return _mm_loadu_ps(p);
}
#endif // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m256 load(const float *p) {
    return _mm256_loadu_ps(p);
}
#endif // __AVX__

#if defined(__F16C__)
template <> inline __m256 load(const ggml_fp16_t *p) {
    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
}
#endif // __F16C__

#if defined(__AVX512F__)
template <> inline __m512 load(const float *p) {
    return _mm512_loadu_ps(p);
}
template <> inline __m512 load(const ggml_fp16_t *p) {
    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
}
#endif // __AVX512F__

////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION

template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
  public:
    tinyBLAS(int64_t k,
             const TA *A, int64_t lda,
             const TB *B, int64_t ldb,
             TC *C, int64_t ldc,
             int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
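    // Dispatch on how many rows and columns remain. The switch key packs
    // MIN(m - m0, 5) into the high nibble and MIN(n - n0, 5) into the low
    // nibble, so e.g. 0x45 selects the 4x5 register-blocked kernel. After
    // the chosen kernel has covered the largest multiple of its tile size,
    // the two recursive calls sweep up the leftover rows and columns with
    // progressively smaller kernels.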
    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
#if VECTOR_REGISTERS == 32
        case 0x55:
            mc = 5;
            nc = 5;
            gemm<5, 5>(m0, m, n0, n);
            break;
        case 0x45:
            mc = 4;
            nc = 5;
            gemm<4, 5>(m0, m, n0, n);
            break;
        case 0x54:
            mc = 5;
            nc = 4;
            gemm<5, 4>(m0, m, n0, n);
            break;
        case 0x44:
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
            break;
        case 0x53:
            mc = 5;
            nc = 3;
            gemm<5, 3>(m0, m, n0, n);
            break;
        case 0x35:
            mc = 3;
            nc = 5;
            gemm<3, 5>(m0, m, n0, n);
            break;
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
#else
        case 0x55:
        case 0x54:
        case 0x53:
        case 0x45:
        case 0x44:
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
        case 0x35:
#endif
        case 0x34:
            mc = 3;
            nc = 4;
            gemm<3, 4>(m0, m, n0, n);
            break;
        case 0x52:
            mc = 5;
            nc = 2;
            gemm<5, 2>(m0, m, n0, n);
            break;
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x25:
            mc = 2;
            nc = 5;
            gemm<2, 5>(m0, m, n0, n);
            break;
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x51:
            mc = 5;
            nc = 1;
            gemm<5, 1>(m0, m, n0, n);
            break;
        case 0x41:
            mc = 4;
            nc = 1;
            gemm<4, 1>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x15:
            mc = 1;
            nc = 5;
            gemm<1, 5>(m0, m, n0, n);
            break;
        case 0x14:
            mc = 1;
            nc = 4;
            gemm<1, 4>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }
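
    // Computes one RM x RN register-blocked tile of C per unit of work.
    // Tiles are numbered row-major over the ytiles x xtiles grid and split
    // evenly between the nth threads; this thread handles tiles [start,
    // end). Each tile keeps RM*RN vector accumulators live in registers
    // while streaming over the shared k dimension KN lanes at a time.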
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            D Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; l += KN)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i)
                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
                                        load<V>(B + ldb * (jj + j) + l),
                                        Cv[j][i]);
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};

//////////////////////////////////////////////////////////////////////////////////////////
// QUANT ZERO MATRIX MULTIPLICATION

#if defined(__ARM_FEATURE_DOTPROD)
template <typename TA>
class tinyBLAS_Q0_ARM {
  public:
    tinyBLAS_Q0_ARM(int64_t k,
                    const TA *A, int64_t lda,
                    const block_q8_0 *B, int64_t ldb,
                    float *C, int64_t ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
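    // Same dispatch strategy as tinyBLAS::mnpack() above, with register
    // blocking capped at 3x3 tiles.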
    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }
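
    // Each block_q8_0 / block_q4_0 holds 32 quantized values plus an fp16
    // scale d. The two nested vdotq_s32 calls accumulate the int8 dot
    // products of the low and high 16-byte halves of a block, and
    // vmlaq_n_f32 folds that sum into the fp32 accumulators scaled by
    // d(A) * d(B).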
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            float32x4_t Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i)
                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
                                               vcvtq_f32_s32(vdotq_s32(
                                                   vdotq_s32(vdupq_n_s32(0),
                                                             load_lo(A + lda * (ii + i) + l),
                                                             load_lo(B + ldb * (jj + j) + l)),
                                                   load_hi(A + lda * (ii + i) + l),
                                                   load_hi(B + ldb * (jj + j) + l))),
                                               unhalf(A[lda * (ii + i) + l].d) *
                                               unhalf(B[ldb * (jj + j) + l].d));
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    inline int8x16_t load_lo(const block_q8_0 *b) {
        return vld1q_s8(b->qs);
    }

    inline int8x16_t load_hi(const block_q8_0 *b) {
        return vld1q_s8(b->qs + 16);
    }
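
    // q4_0 packs 32 values as nibbles: the low nibbles hold values 0..15
    // and the high nibbles hold values 16..31, stored with a +8 bias, so
    // subtracting 8 recovers the signed range [-8, 7].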
    inline int8x16_t load_lo(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
                                                     vdupq_n_u8(0x0f))),
                        vdupq_n_s8(0x8));
    }

    inline int8x16_t load_hi(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
                        vdupq_n_s8(0x8));
    }

    const TA *const A;
    const block_q8_0 *const B;
    float *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif // __ARM_FEATURE_DOTPROD

#if defined(__AVX2__) || defined(__AVX512F__)
template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX2 {
  public:
    tinyBLAS_Q0_AVX2(int64_t k,
                     const TA *A, int64_t lda,
                     const TB *B, int64_t ldb,
                     TC *C, int64_t ldc,
                     int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
#if VECTOR_REGISTERS == 32
        case 0x44:
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
            break;
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
        case 0x34:
            mc = 3;
            nc = 4;
            gemm<3, 4>(m0, m, n0, n);
            break;
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
#else
        case 0x44:
        case 0x43:
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x34:
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
        case 0x33:
#endif
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x41:
            mc = 4;
            nc = 1;
            gemm<4, 1>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x14:
            mc = 1;
            nc = 4;
            gemm<1, 4>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            __m256 Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i)
                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                       unhalf(B[ldb * (jj + j) + l].d)),
                                        updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                               load(A + lda * (ii + i) + l)),
                                              _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
                                                               load(A + lda * (ii + i) + l))),
                                        Cv[j][i]);
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    inline __m256i load(const block_q8_0 *b) {
        return _mm256_loadu_si256((const __m256i *)b->qs);
    }
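
    // Unpack the 32 q4_0 nibbles to signed bytes by removing the +8 bias.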
    inline __m256i load(const block_q4_0 *b) {
        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
    }
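
    // Multiplies 32 pairs of int8 values and horizontally sums them into
    // eight float lanes. _mm256_maddubs_epi16 (and the VNNI dpbusd form)
    // treats its first operand as unsigned, which is why the caller passes
    // u = |a| via _mm256_sign_epi8(a, a) and s = b with a's signs applied
    // via _mm256_sign_epi8(b, a); the products u*s still equal a*b.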
    inline __m256 updot(__m256i u, __m256i s) {
        __m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#else
        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif
        return _mm256_cvtepi32_ps(res);
    }
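
    // Expands 16 packed bytes into a 256-bit vector of 32 nibble values:
    // the low 128-bit lane gets the low nibbles and the high lane gets the
    // high nibbles, both masked down to 0..15.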
    static inline __m256i denibble(const uint8_t *p) {
        __m128i x = _mm_loadu_si128((const __m128i *)p);
        return _mm256_and_si256(_mm256_set1_epi8(15),
                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
                                                        _mm_srli_epi16(x, 4), 1));
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif // __AVX2__

} // namespace
/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel is available for the given
 * data types and CPU features. Otherwise the caller should fall back
 * to a general matmul routine.
 *
 * For example, for single-threaded single-precision GEMM you can say
 *
 *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
 *                     0, 1, GGML_TASK_TYPE_COMPUTE,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param ith is thread id (must be less than `nth`)
 * @param nth is number of threads (must be greater than zero)
 * @param task is GGML task type
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @return true if this function was able to service the matmul request
 */
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
                     int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {

    assert(m >= 0);
    assert(n >= 0);
    assert(k >= 0);
    assert(lda >= k);
    assert(ldb >= k);
    assert(ldc >= m);
    assert(nth > 0);
    assert(ith < nth);

    if (Ctype != GGML_TYPE_F32)
        return false;

    switch (Atype) {

    case GGML_TYPE_F32: {
        if (Btype != GGML_TYPE_F32)
            return false;
#if defined(__AVX512F__)
        if (k % 16)
            return false;
        tinyBLAS<16, __m512, __m512, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__AVX__) || defined(__AVX2__)
        if (k % 8)
            return false;
        tinyBLAS<8, __m256, __m256, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_NEON)
        if (n < 4)
            return false;
        if (k % 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_F16: {
#if defined(__AVX512F__)
        if (k % 16)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
        if (k % 8)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
        if (n < 8)
            return false;
        if (k % 8)
            return false;
        if (Btype != GGML_TYPE_F16)
            return false;
        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const ggml_fp16_t *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
        if (k % 4)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q8_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__)
        tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q4_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__)
        tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    default:
        return false;
    }

    (void)m;
    (void)n;
    (void)k;
    (void)A;
    (void)lda;
    (void)B;
    (void)ldb;
    (void)C;
    (void)ldc;
    (void)ith;
    (void)nth;
    (void)task;
    (void)Atype;
    (void)Btype;
    (void)Ctype;
}