
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//
//                   _   _          ___ _      _   ___
//                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
//                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
//                   \__|_|_||_\_, |___/____/_/ \_\___/
//                             |__/
//
//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound on rounding
// errors, which grow along with k. Our goal is to maximally exploit the
// hardware for performance, and then use whatever resources remain to
// improve numerical accuracy.
//
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
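//
// As a mental model, every kernel below computes the same result as the
// following scalar reference loop (a minimal sketch, assuming the strides
// and the Aᵀ * B convention used throughout this file):
//
//     for (int i = 0; i < m; ++i)        // rows of A and C
//         for (int j = 0; j < n; ++j) {  // columns of B and C
//             float sum = 0;
//             for (int l = 0; l < k; ++l)
//                 sum += A[lda * i + l] * B[ldb * j + l];
//             C[ldc * j + i] = sum;      // C is stored column major
//         }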
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"

#include "sgemm.h"
#include <algorithm>
#include "ggml-impl.h"
#include "ggml-quants.h"

#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
#else
#define NOINLINE __attribute__((__noinline__))
#endif

#if defined(__ARM_NEON) || defined(__AVX512F__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
#endif

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

namespace {

inline float unhalf(ggml_fp16_t d) {
    return GGML_FP16_TO_FP32(d);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED ARITHMETIC OPERATIONS

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
#endif // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
#endif // __AVX__

#if defined(__AVX512F__)
inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
#endif // __AVX512F__

#if defined(__ARM_NEON)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD

/**
 * Computes a * b + c.
 */
template <typename T, typename U>
inline U madd(T a, T b, U c) {
    return add(mul(a, b), c);
}

#if defined(__FMA__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512F__)
template <>
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
    return _mm512_fmadd_ps(a, b, c);
}
#endif
#endif

#if defined(__ARM_FEATURE_FMA)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
    return vfmaq_f32(c, b, a);
}
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
template <>
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
    return vfmaq_f16(c, b, a);
}
#endif
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM

#if defined(__ARM_NEON)
inline float hsum(float32x4_t x) {
    return vaddvq_f32(x);
}
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
inline float hsum(float16x8_t x) {
    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
                                vcvt_f32_f16(vget_high_f16(x))));
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
    x = _mm_add_ss(x, _mm_movehdup_ps(x));
#else
    __m128 t;
    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
    x = _mm_add_ps(x, t);
    t = _mm_movehl_ps(t, x);
    x = _mm_add_ss(x, t);
#endif
    return _mm_cvtss_f32(x);
}
#endif

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m256 x) {
    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
                           _mm256_castps256_ps128(x)));
}
#endif // __AVX__

#if defined(__AVX512F__)
inline float hsum(__m512 x) {
    return _mm512_reduce_add_ps(x);
}
#endif // __AVX512F__
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING

template <typename T, typename U> T load(const U *);

#if defined(__ARM_NEON)
template <> inline float32x4_t load(const float *p) {
    return vld1q_f32(p);
}
#if !defined(_MSC_VER)
template <> inline float16x8_t load(const ggml_fp16_t *p) {
    return vld1q_f16((const float16_t *)p);
}
template <> inline float32x4_t load(const ggml_fp16_t *p) {
    return vcvt_f32_f16(vld1_f16((const float16_t *)p));
}
#endif // _MSC_VER
#endif // __ARM_NEON

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) {
    return _mm_loadu_ps(p);
}
#endif // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m256 load(const float *p) {
    return _mm256_loadu_ps(p);
}
#endif // __AVX__

#if defined(__F16C__)
template <> inline __m256 load(const ggml_fp16_t *p) {
    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
}
#endif // __F16C__

#if defined(__AVX512F__)
template <> inline __m512 load(const float *p) {
    return _mm512_loadu_ps(p);
}
template <> inline __m512 load(const ggml_fp16_t *p) {
    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
}
#endif // __AVX512F__
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION

template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
  public:
    tinyBLAS(int k,
             const TA *A, int lda,
             const TB *B, int ldb,
             TC *C, int ldc,
             int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int m, int n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
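    // mnpack() clamps the remaining rows and columns to the largest supported
    // register tile, dispatches to the matching gemm<RM, RN> kernel, and then
    // recurses on the strips left over along the right and bottom edges, so
    // every element of C is produced by exactly one tile.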
    NOINLINE void mnpack(int m0, int m, int n0, int n) {
        int mc, nc, mp, np;
        switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
#if VECTOR_REGISTERS == 32
        case 0x55:
            mc = 5;
            nc = 5;
            gemm<5, 5>(m0, m, n0, n);
            break;
        case 0x45:
            mc = 4;
            nc = 5;
            gemm<4, 5>(m0, m, n0, n);
            break;
        case 0x54:
            mc = 5;
            nc = 4;
            gemm<5, 4>(m0, m, n0, n);
            break;
        case 0x44:
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
            break;
        case 0x53:
            mc = 5;
            nc = 3;
            gemm<5, 3>(m0, m, n0, n);
            break;
        case 0x35:
            mc = 3;
            nc = 5;
            gemm<3, 5>(m0, m, n0, n);
            break;
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
#else
        case 0x55:
        case 0x54:
        case 0x53:
        case 0x45:
        case 0x44:
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
        case 0x35:
#endif
        case 0x34:
            mc = 3;
            nc = 4;
            gemm<3, 4>(m0, m, n0, n);
            break;
        case 0x52:
            mc = 5;
            nc = 2;
            gemm<5, 2>(m0, m, n0, n);
            break;
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x25:
            mc = 2;
            nc = 5;
            gemm<2, 5>(m0, m, n0, n);
            break;
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x51:
            mc = 5;
            nc = 1;
            gemm<5, 1>(m0, m, n0, n);
            break;
        case 0x41:
            mc = 4;
            nc = 1;
            gemm<4, 1>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x15:
            mc = 1;
            nc = 5;
            gemm<1, 5>(m0, m, n0, n);
            break;
        case 0x14:
            mc = 1;
            nc = 4;
            gemm<1, 4>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }
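
    // gemm<RM, RN> enumerates the RM x RN output tiles in row-major tile order
    // and splits them into contiguous chunks of ceil(tiles / nth); thread ith
    // handles chunk ith, so no synchronization between threads is required.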
    template <int RM, int RN>
    NOINLINE void gemm(int m0, int m, int n0, int n) {
        int ytiles = (m - m0) / RM;
        int xtiles = (n - n0) / RN;
        int tiles = xtiles * ytiles;
        int duty = (tiles + nth - 1) / nth;
        int start = duty * ith;
        int end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int job = start; job < end; ++job) {
            int ii = m0 + job / xtiles * RM;
            int jj = n0 + job % xtiles * RN;
            D Cv[RN][RM] = {};
            for (int l = 0; l < k; l += KN)
                for (int j = 0; j < RN; ++j)
                    for (int i = 0; i < RM; ++i)
                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
                                        load<V>(B + ldb * (jj + j) + l),
                                        Cv[j][i]);
            for (int j = 0; j < RN; ++j)
                for (int i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int k;
    const int lda;
    const int ldb;
    const int ldc;
    const int ith;
    const int nth;
};
//////////////////////////////////////////////////////////////////////////////////////////
// QUANT ZERO MATRIX MULTIPLICATION

#if defined(__ARM_FEATURE_DOTPROD)
template <typename TA>
class tinyBLAS_Q0_ARM {
  public:
    tinyBLAS_Q0_ARM(int k,
                    const TA *A, int lda,
                    const block_q8_0 *B, int ldb,
                    float *C, int ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int m, int n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
    NOINLINE void mnpack(int m0, int m, int n0, int n) {
        int mc, nc, mp, np;
        switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) {
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }
    template <int RM, int RN>
    NOINLINE void gemm(int m0, int m, int n0, int n) {
        int ytiles = (m - m0) / RM;
        int xtiles = (n - n0) / RN;
        int tiles = xtiles * ytiles;
        int duty = (tiles + nth - 1) / nth;
        int start = duty * ith;
        int end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int job = start; job < end; ++job) {
            int ii = m0 + job / xtiles * RM;
            int jj = n0 + job % xtiles * RN;
            float32x4_t Cv[RN][RM] = {};
            for (int l = 0; l < k; ++l)
                for (int j = 0; j < RN; ++j)
                    for (int i = 0; i < RM; ++i)
                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
                                               vcvtq_f32_s32(vdotq_s32(
                                                   vdotq_s32(vdupq_n_s32(0),
                                                             load_lo(A + lda * (ii + i) + l),
                                                             load_lo(B + ldb * (jj + j) + l)),
                                                   load_hi(A + lda * (ii + i) + l),
                                                   load_hi(B + ldb * (jj + j) + l))),
                                               unhalf(A[lda * (ii + i) + l].d) *
                                               unhalf(B[ldb * (jj + j) + l].d));
            for (int j = 0; j < RN; ++j)
                for (int i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }
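
    // Each quantized block holds 32 values. load_lo()/load_hi() return the
    // first and second 16 of them as int8 lanes; for q4_0 the packed nibbles
    // are masked (or shifted) out and re-centered by subtracting 8, matching
    // how ggml dequantizes that format.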
    inline int8x16_t load_lo(const block_q8_0 *b) {
        return vld1q_s8(b->qs);
    }

    inline int8x16_t load_hi(const block_q8_0 *b) {
        return vld1q_s8(b->qs + 16);
    }

    inline int8x16_t load_lo(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
                                                     vdupq_n_u8(0x0f))),
                        vdupq_n_s8(0x8));
    }

    inline int8x16_t load_hi(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
                        vdupq_n_s8(0x8));
    }

    const TA *const A;
    const block_q8_0 *const B;
    float *const C;
    const int k;
    const int lda;
    const int ldb;
    const int ldc;
    const int ith;
    const int nth;
};
#endif // __ARM_FEATURE_DOTPROD
#if defined(__AVX2__) || defined(__AVX512F__)
template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX2 {
  public:
    tinyBLAS_Q0_AVX2(int k,
                     const TA *A, int lda,
                     const TB *B, int ldb,
                     TC *C, int ldc,
                     int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int m, int n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
    void mnpack(int m0, int m, int n0, int n) {
        int mc, nc, mp, np;
        switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
#if VECTOR_REGISTERS == 32
        case 0x44:
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
            break;
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
        case 0x34:
            mc = 3;
            nc = 4;
            gemm<3, 4>(m0, m, n0, n);
            break;
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
#else
        case 0x44:
        case 0x43:
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x34:
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
        case 0x33:
#endif
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x41:
            mc = 4;
            nc = 1;
            gemm<4, 1>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x14:
            mc = 1;
            nc = 4;
            gemm<1, 4>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }
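
    // _mm256_maddubs_epi16 (and the VNNI dpbusd path in updot) multiplies an
    // unsigned operand by a signed one, so the kernel feeds it |A| together
    // with B carrying A's sign, both computed with _mm256_sign_epi8. The
    // per-lane products equal signed A times signed B, and the resulting
    // int32 dot products are scaled by the two block deltas and accumulated
    // in float.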
    template <int RM, int RN>
    NOINLINE void gemm(int m0, int m, int n0, int n) {
        int ytiles = (m - m0) / RM;
        int xtiles = (n - n0) / RN;
        int tiles = xtiles * ytiles;
        int duty = (tiles + nth - 1) / nth;
        int start = duty * ith;
        int end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int job = start; job < end; ++job) {
            int ii = m0 + job / xtiles * RM;
            int jj = n0 + job % xtiles * RN;
            __m256 Cv[RN][RM] = {};
            for (int l = 0; l < k; ++l)
                for (int j = 0; j < RN; ++j)
                    for (int i = 0; i < RM; ++i)
                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                       unhalf(B[ldb * (jj + j) + l].d)),
                                        updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                               load(A + lda * (ii + i) + l)),
                                              _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
                                                               load(A + lda * (ii + i) + l))),
                                        Cv[j][i]);
            for (int j = 0; j < RN; ++j)
                for (int i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }
    inline __m256i load(const block_q8_0 *b) {
        return _mm256_loadu_si256((const __m256i *)b->qs);
    }

    inline __m256i load(const block_q4_0 *b) {
        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
    }

    inline __m256 updot(__m256i u, __m256i s) {
        __m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#else
        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif
        return _mm256_cvtepi32_ps(res);
    }

    static inline __m256i denibble(const uint8_t *p) {
        __m128i x = _mm_loadu_si128((const __m128i *)p);
        return _mm256_and_si256(_mm256_set1_epi8(15),
                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
                                                        _mm_srli_epi16(x, 4), 1));
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int k;
    const int lda;
    const int ldb;
    const int ldc;
    const int ith;
    const int nth;
};
#endif // __AVX2__

} // namespace
/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a suitable handwritten kernel is available.
 * Otherwise the caller should fall back to a general matmul routine.
 *
 * For example, for single-threaded single-precision GEMM you can say
 *
 *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
 *                     0, 1, GGML_TASK_TYPE_COMPUTE,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param ith is thread id (must be less than `nth`)
 * @param nth is number of threads (must be greater than zero)
 * @param task is GGML task type
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @return true if this function was able to service the matmul request
 */
bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
                     int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
    assert(m >= 0);
    assert(n >= 0);
    assert(k >= 0);
    assert(lda >= k);
    assert(ldb >= k);
    assert(ldc >= m);
    assert(nth > 0);
    assert(ith < nth);
    assert(1ll * lda * m <= 0x7fffffff);
    assert(1ll * ldb * n <= 0x7fffffff);
    assert(1ll * ldc * n <= 0x7fffffff);
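    // The kernels index the matrices with 32-bit ints, so each stride times
    // its extent must fit in INT_MAX; the 1ll promotions above keep these
    // overflow checks from themselves overflowing.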
    if (Ctype != GGML_TYPE_F32)
        return false;

    switch (Atype) {

    case GGML_TYPE_F32: {
        if (Btype != GGML_TYPE_F32)
            return false;
#if defined(__AVX512F__)
        if (k % 16)
            return false;
        tinyBLAS<16, __m512, __m512, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__AVX__) || defined(__AVX2__)
        if (k % 8)
            return false;
        tinyBLAS<8, __m256, __m256, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_NEON)
        if (n < 4)
            return false;
        if (k % 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_F16: {
#if defined(__AVX512F__)
        if (k % 16)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
        if (k % 8)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
        if (n < 8)
            return false;
        if (k % 8)
            return false;
        if (Btype != GGML_TYPE_F16)
            return false;
        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const ggml_fp16_t *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
        if (k % 4)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q8_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__)
        tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q4_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__)
        tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    default:
        return false;
    }

    (void)m;
    (void)n;
    (void)k;
    (void)A;
    (void)lda;
    (void)B;
    (void)ldb;
    (void)C;
    (void)ldc;
    (void)ith;
    (void)nth;
    (void)task;
    (void)Atype;
    (void)Btype;
    (void)Ctype;
}