sgemm.cpp

  1. // Copyright 2024 Mozilla Foundation
  2. //
  3. // Permission is hereby granted, free of charge, to any person obtaining
  4. // a copy of this software and associated documentation files (the
  5. // "Software"), to deal in the Software without restriction, including
  6. // without limitation the rights to use, copy, modify, merge, publish,
  7. // distribute, sublicense, and/or sell copies of the Software, and to
  8. // permit persons to whom the Software is furnished to do so, subject to
  9. // the following conditions:
  10. //
  11. // The above copyright notice and this permission notice shall be
  12. // included in all copies or substantial portions of the Software.
  13. //
  14. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  15. // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  16. // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  17. // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  18. // BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  19. // ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  20. // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  21. // SOFTWARE.
  22. //
  23. //                                 _   _          ___ _      _   ___
  24. //                                | |_(_)_ _ _  _| _ ) |    /_\ / __|
  25. //                                |  _| | ' \ || | _ \ |__ / _ \\__ \.
  26. //                                 \__|_|_||_\_, |___/____/_/ \_\___/
  27. //                                            |__/
  28. //
  29. // BASIC LINEAR ALGEBRA SUBPROGRAMS
  30. //
  31. //
  32. // This file implements multithreaded CPU matrix multiplication for the
  33. // common contiguous use case C = Aᵀ * B. These kernels are designed to
  34. // have excellent performance[1] for matrices that fit in the CPU cache
  35. // without imposing any overhead such as cache filling or malloc calls.
  36. //
  37. // This implementation does not guarantee any upper bound on rounding
  38. // error, which grows along with k. Our goal is to maximally exploit the
  39. // hardware for performance, and then use whatever resources remain for
  40. // improving numerical accuracy.
  41. //
  42. // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
  43. // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
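//
// Concretely, for the contiguous C = Aᵀ * B case handled here, each output
// element is a dot product of a row of A with a row of B:
//
//     C[ldc*j + i] = Σ_{l=0..k-1} A[lda*i + l] * B[ldb*j + l]
//
// which is what gemm_bloc() below accumulates with madd() and reduces with
// hsum().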
  44. #if defined(__GNUC__)
  45. #pragma GCC diagnostic ignored "-Wpedantic"
  46. #pragma GCC diagnostic ignored "-Wignored-attributes"
  47. #endif
  48. #include "sgemm.h"
  49. #include "ggml-impl.h"
  50. #include "ggml-cpu-impl.h"
  51. #include "ggml-quants.h"
  52. #include "simd-mappings.h"
  53. #include <array>
  54. #include <type_traits>
  55. #ifdef _MSC_VER
  56. #define NOINLINE __declspec(noinline)
  57. #else
  58. #define NOINLINE __attribute__((__noinline__))
  59. #endif
  60. #if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
  61. #define VECTOR_REGISTERS 32
  62. #else
  63. #define VECTOR_REGISTERS 16
  64. #endif
  65. #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
  66. namespace {
  67. inline float unhalf(ggml_fp16_t d) {
  68. return GGML_CPU_FP16_TO_FP32(d);
  69. }
  70. ////////////////////////////////////////////////////////////////////////////////////////////////////
  71. // VECTORIZED ARITHMETIC OPERATIONS
  72. #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  73. inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
  74. inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
  75. inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
  76. #endif // __SSE__
  77. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  78. inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
  79. inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
  80. inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
  81. #endif // __AVX__
  82. #if defined(__AVX512F__)
  83. inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
  84. inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
  85. inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
  86. #endif // __AVX512F__
  87. #if defined(__ARM_NEON)
  88. inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
  89. inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
  90. inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
  91. #endif // __ARM_NEON
  92. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
  93. inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
  94. inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
  95. inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
  96. #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  97. #if defined(__VXE__) || defined(__VXE2__)
  98. inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
  99. inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
  100. inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
  101. #endif
  102. #if defined(__MMA__)
  103. #include "sgemm-ppc.h"
  104. #endif
  105. ////////////////////////////////////////////////////////////////////////////////////////////////////
  106. // VECTORIZED FUSED MULTIPLY ADD
  107. /**
  108. * Computes a * b + c.
  109. */
  110. template <typename T, typename U>
  111. inline U madd(T a, T b, U c) {
  112. return add(mul(a, b), c);
  113. }
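// The generic madd() above simply composes mul() and add(); the
// specializations below use a single fused multiply-add instruction where the
// target provides one.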
  114. #if defined(__FMA__)
  115. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  116. template <>
  117. inline __m256 madd(__m256 a, __m256 b, __m256 c) {
  118. return _mm256_fmadd_ps(a, b, c);
  119. }
  120. #endif
  121. #if defined(__AVX512F__)
  122. template <>
  123. inline __m512 madd(__m512 a, __m512 b, __m512 c) {
  124. return _mm512_fmadd_ps(a, b, c);
  125. }
  126. #endif
  127. #if defined(__AVX512BF16__)
  128. template <>
  129. inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
  130. return _mm512_dpbf16_ps(c, a, b);
  131. }
  132. template <>
  133. inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
  134. return _mm256_dpbf16_ps(c, a, b);
  135. }
  136. #endif
  137. #endif
  138. #if defined(__ARM_FEATURE_FMA)
  139. template <>
  140. inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
  141. return vfmaq_f32(c, b, a);
  142. }
  143. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
  144. template <>
  145. inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
  146. return vfmaq_f16(c, b, a);
  147. }
  148. #endif
  149. #endif
  150. #if defined(__VXE__) || defined(__VXE2__)
  151. template <>
  152. inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
  153. return vec_madd(a, b, c);
  154. }
  155. #endif
  156. ////////////////////////////////////////////////////////////////////////////////////////////////////
  157. // VECTORIZED HORIZONTAL SUM
  158. #if defined(__ARM_NEON)
  159. inline float hsum(float32x4_t x) {
  160. return vaddvq_f32(x);
  161. }
  162. #endif // __ARM_NEON
  163. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
  164. inline float hsum(float16x8_t x) {
  165. return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
  166. vcvt_f32_f16(vget_high_f16(x))));
  167. }
  168. #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  169. #if defined(__VXE__) || defined(__VXE2__)
  170. inline float hsum(float32x4_t x) {
  171. float32x4_t tmp = x + vec_reve(x);
  172. return tmp[0] + tmp[1];
  173. }
  174. #endif
  175. #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  176. inline float hsum(__m128 x) {
  177. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  178. x = _mm_add_ps(x, _mm_movehl_ps(x, x));
  179. x = _mm_add_ss(x, _mm_movehdup_ps(x));
  180. #else
  181. __m128 t;
  182. t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
  183. x = _mm_add_ps(x, t);
  184. t = _mm_movehl_ps(t, x);
  185. x = _mm_add_ss(x, t);
  186. #endif
  187. return _mm_cvtss_f32(x);
  188. }
  189. #endif
  190. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  191. inline float hsum(__m256 x) {
  192. return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
  193. _mm256_castps256_ps128(x)));
  194. }
  195. #endif // __AVX__
  196. #if defined(__AVX512F__)
  197. inline float hsum(__m512 x) {
  198. return _mm512_reduce_add_ps(x);
  199. }
  200. #endif // __AVX512F__
  201. ////////////////////////////////////////////////////////////////////////////////////////////////////
  202. // VECTORIZED MEMORY LOADING
  203. template <typename T, typename U> T load(const U *);
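// load<V>(p) reads one SIMD register's worth of elements starting at p,
// widening fp16/bf16 inputs to float lanes where the destination vector type
// requires it.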
  204. #if defined(__ARM_NEON)
  205. template <> inline float32x4_t load(const float *p) {
  206. return vld1q_f32(p);
  207. }
  208. #if !defined(_MSC_VER)
  209. // FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  210. template <> inline float16x8_t load(const ggml_fp16_t *p) {
  211. return vld1q_f16((const float16_t *)p);
  212. }
  213. template <> inline float32x4_t load(const ggml_fp16_t *p) {
  214. return vcvt_f32_f16(vld1_f16((const float16_t *)p));
  215. }
  216. #endif // _MSC_VER
  217. #endif // __ARM_NEON
  218. #if defined(__VXE__) || defined(__VXE2__)
  219. template <> inline float32x4_t load(const ggml_fp16_t * p) {
  220. float tmp[4];
  221. for (int i = 0; i < 4; i++) {
  222. tmp[i] = GGML_CPU_FP16_TO_FP32(p[i]);
  223. }
  224. return vec_xl(0, (const float *)(tmp));
  225. }
  226. template <> inline float32x4_t load(const float * p) {
  227. return vec_xl(0, p);
  228. }
  229. #endif
  230. #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  231. template <> inline __m128 load(const float *p) {
  232. return _mm_loadu_ps(p);
  233. }
  234. #endif // __SSE__
  235. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  236. template <> inline __m256 load(const float *p) {
  237. return _mm256_loadu_ps(p);
  238. }
  239. #endif // __AVX__
  240. #if defined(__AVX2__) || defined(__AVX512F__)
  241. template <> inline __m256 load(const ggml_bf16_t *p) {
  242. return _mm256_castsi256_ps(
  243. _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
  244. }
  245. #endif // __AVX2__
  246. #if defined(__F16C__)
  247. template <> inline __m256 load(const ggml_fp16_t *p) {
  248. return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
  249. }
  250. #endif // __F16C__
  251. #if defined(__AVX512F__)
  252. template <> inline __m512 load(const float *p) {
  253. return _mm512_loadu_ps(p);
  254. }
  255. template <> inline __m512 load(const ggml_fp16_t *p) {
  256. return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
  257. }
  258. template <> inline __m512 load(const ggml_bf16_t *p) {
  259. return _mm512_castsi512_ps(
  260. _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
  261. }
  262. #endif // __AVX512F__
  263. #if defined(__AVX512BF16__)
  264. template <> inline __m512bh load(const ggml_bf16_t *p) {
  265. return (__m512bh)_mm512_loadu_ps((const float *)p);
  266. }
  267. template <> inline __m256bh load(const ggml_bf16_t *p) {
  268. return (__m256bh)_mm256_loadu_ps((const float *)p);
  269. }
  270. template <> inline __m512bh load(const float *p) {
  271. return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
  272. }
  273. template <> inline __m256bh load(const float *p) {
  274. return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
  275. }
  276. #endif
  277. ////////////////////////////////////////////////////////////////////////////////////////////////////
  278. // FLOATING POINT MATRIX MULTIPLICATION
  279. template <int M>
  280. static inline int64_t BLOCK_SIZE(size_t m) {
  281. const int64_t NB_BLOC_M = (m + M - 1) / M;
  282. return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
  283. }
  284. static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
  285. return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
  286. }
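// For example, BLOCK_SIZE<6>(16) computes NB_BLOC_M = ceil(16/6) = 3 and,
// since 16 % 3 != 0, returns 16/3 + 1 = 6, i.e. 16 columns are split into
// blocks of 6, 5 and 5. BLOC_POS(ib, ibN, bloc_size) then gives the start of
// block ib when the first ibN blocks have full size, e.g.
// BLOC_POS(2, 1, 6) = 1*6 + 1*5 = 11.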
  287. template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
  288. class tinyBLAS {
  289. public:
  290. tinyBLAS(const ggml_compute_params * params, int64_t k,
  291. const TA *A, int64_t lda,
  292. const TB *B, int64_t ldb,
  293. TC *C, int64_t ldc)
  294. : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
  295. }
  296. bool matmul(int64_t m, int64_t n) {
  297. if (k % KN != 0)
  298. return false;
  299. // choose the block decomposition so that only tiles of size RN and RN-1 are needed
  300. #if VECTOR_REGISTERS == 32
  301. if (m % 16 == 0 && (m/16 >= params->nth)) {
  302. const int64_t SIZE_N = BLOCK_SIZE<6>(n);
  303. mnpack<4, 6, 4>(m, n, SIZE_N, 12);
  304. return true;
  305. }
  306. if (m % 8 == 0 ) {
  307. const int64_t SIZE_N = BLOCK_SIZE<6>(n);
  308. mnpack<4, 6, 2>(m, n, SIZE_N, 12);
  309. return true;
  310. }
  311. if (m % 4 == 0) {
  312. const int64_t SIZE_N = BLOCK_SIZE<6>(n);
  313. mnpack<4, 6, 1>(m, n, SIZE_N, 12);
  314. return true;
  315. }
  316. #else // VECTOR_REGISTERS == 16
  317. if (m % 16 == 0 && (m/16 >= params->nth)) {
  318. const int64_t SIZE_N = BLOCK_SIZE<3>(n);
  319. mnpack<4, 3, 4>(m, n, SIZE_N, 24);
  320. return true;
  321. }
  322. if (m % 8 == 0 ) {
  323. const int64_t SIZE_N = BLOCK_SIZE<3>(n);
  324. mnpack<4, 3, 2>(m, n, SIZE_N, 24);
  325. return true;
  326. }
  327. if (m % 4 == 0) {
  328. const int64_t SIZE_N = BLOCK_SIZE<3>(n);
  329. mnpack<4, 3, 1>(m, n, SIZE_N, 24);
  330. return true;
  331. }
  332. #endif
  333. return false;
  334. }
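// Note on the tile sizes chosen above: the largest inner tile,
// gemm_bloc<4, 6>, keeps 6*4 = 24 accumulators plus 4 A vectors and 1 B
// vector live (about 29 registers), which fits the 32-register targets, while
// gemm_bloc<4, 3> keeps 12 + 3 + 1 = 16 for the 16-register targets
// (assuming the compiler keeps all of them in vector registers).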
  335. private:
  336. template <int RM, int RN, int BM>
  337. inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
  338. if (SIZE_N == RN) {
  339. return gemm<RM, RN, BM>(m, n, BN);
  340. }
  341. if constexpr (RN > 1) {
  342. return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
  343. } else {
  344. GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
  345. GGML_ASSERT(false); // we have missed something.
  346. }
  347. }
  348. template <int RM, int RN>
  349. inline void gemm_bloc(int64_t ii, int64_t jj) {
  350. D Cv[RN][RM] = {};
  351. for (int64_t l = 0; l < k; l += KN) {
  352. // help the compiler with operation ordering.
  353. if constexpr (RM <= RN) {
  354. V Av[RM];
  355. for (int64_t i = 0; i < RM; ++i) {
  356. Av[i] = load<V>(A + lda * (ii + i) + l);
  357. }
  358. for (int64_t j = 0; j < RN; ++j) {
  359. V Bv = load<V>(B + ldb * (jj + j) + l);
  360. for (int64_t i = 0; i < RM; ++i) {
  361. Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
  362. }
  363. }
  364. } else {
  365. V Bv[RN];
  366. for (int64_t j = 0; j < RN; ++j) {
  367. Bv[j] = load<V>(B + ldb * (jj + j) + l);
  368. }
  369. for (int64_t i = 0; i < RM; ++i) {
  370. V Av = load<V>(A + lda * (ii + i) + l);
  371. for (int64_t j = 0; j < RN; ++j) {
  372. Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
  373. }
  374. }
  375. }
  376. }
  377. for (int64_t j = 0; j < RN; ++j)
  378. for (int64_t i = 0; i < RM; ++i)
  379. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  380. }
  381. template <int RM, int RN, int BM>
  382. NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
  383. GGML_ASSERT(m % (RM * BM) == 0);
  384. const int64_t ytiles = m / (RM * BM);
  385. const int64_t xtiles = (n + RN -1) / RN;
  386. const int64_t jj_RN = (xtiles - (xtiles * RN - n));
  387. // "round" bloc_size to "nearest" BN
  388. const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
  389. const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
  390. const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
  391. const int64_t nb_job = ytiles * NB_BN;
  392. if (params->ith == 0) {
  393. GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
  394. // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
  395. ggml_threadpool_chunk_set(params->threadpool, params->nth);
  396. }
  397. ggml_barrier(params->threadpool);
  398. int64_t job = params->ith;
  399. while (job < nb_job) {
  400. const int64_t ii = (job % ytiles) * RM * BM;
  401. const int64_t jb = job / ytiles;
  402. const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
  403. const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
  404. const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
  405. const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
  406. const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
  407. for (int64_t bi = 0; bi < BM * RM; bi += RM) {
  408. int64_t jj = jj0;
  409. for (; jj < jj1; jj += RN) {
  410. gemm_bloc<RM, RN>(ii + bi, jj);
  411. }
  412. if constexpr (RN > 1) {
  413. for (; jj < jj2; jj += RN - 1) {
  414. gemm_bloc<RM, RN-1>(ii + bi, jj);
  415. }
  416. }
  417. GGML_ASSERT(jj == jj2);
  418. }
  419. job = ggml_threadpool_chunk_add(params->threadpool, 1);
  420. }
  421. ggml_barrier(params->threadpool);
  422. return;
  423. }
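// In the loop above, job % ytiles selects a stripe of RM*BM rows (ii) and
// job / ytiles selects a block of column tiles (jb); finished threads fetch
// the next job from the shared chunk counter, so the tail of the work is
// balanced dynamically rather than pre-assigned per thread.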
  424. const ggml_compute_params * params;
  425. const TA *const A;
  426. const TB *const B;
  427. TC *const C;
  428. const int64_t k;
  429. const int64_t lda;
  430. const int64_t ldb;
  431. const int64_t ldc;
  432. };
  433. //////////////////////////////////////////////////////////////////////////////////////////
  434. // QUANT ZERO MATRIX MULTIPLICATION
  435. #if defined(__ARM_FEATURE_DOTPROD)
  436. template <typename TA>
  437. class tinyBLAS_Q0_ARM {
  438. public:
  439. tinyBLAS_Q0_ARM(int64_t k,
  440. const TA *A, int64_t lda,
  441. const block_q8_0 *B, int64_t ldb,
  442. float *C, int64_t ldc,
  443. int ith, int nth)
  444. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  445. }
  446. void matmul(int64_t m, int64_t n) {
  447. mnpack(0, m, 0, n);
  448. }
  449. private:
  450. NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  451. int64_t mc, nc, mp, np;
  452. switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
  453. case 0x33:
  454. mc = 3;
  455. nc = 3;
  456. gemm<3, 3>(m0, m, n0, n);
  457. break;
  458. case 0x32:
  459. mc = 3;
  460. nc = 2;
  461. gemm<3, 2>(m0, m, n0, n);
  462. break;
  463. case 0x23:
  464. mc = 2;
  465. nc = 3;
  466. gemm<2, 3>(m0, m, n0, n);
  467. break;
  468. case 0x22:
  469. mc = 2;
  470. nc = 2;
  471. gemm<2, 2>(m0, m, n0, n);
  472. break;
  473. case 0x31:
  474. mc = 3;
  475. nc = 1;
  476. gemm<3, 1>(m0, m, n0, n);
  477. break;
  478. case 0x13:
  479. mc = 1;
  480. nc = 3;
  481. gemm<1, 3>(m0, m, n0, n);
  482. break;
  483. case 0x21:
  484. mc = 2;
  485. nc = 1;
  486. gemm<2, 1>(m0, m, n0, n);
  487. break;
  488. case 0x12:
  489. mc = 1;
  490. nc = 2;
  491. gemm<1, 2>(m0, m, n0, n);
  492. break;
  493. case 0x11:
  494. mc = 1;
  495. nc = 1;
  496. gemm<1, 1>(m0, m, n0, n);
  497. break;
  498. default:
  499. return;
  500. }
  501. mp = m0 + (m - m0) / mc * mc;
  502. np = n0 + (n - n0) / nc * nc;
  503. mnpack(mp, m, n0, np);
  504. mnpack(m0, m, np, n);
  505. }
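// The switch key above packs the remaining row and column counts, each
// clamped to 3, into one hex digit apiece: e.g. case 0x23 dispatches a
// 2-row by 3-column kernel. After the largest kernel that fits has run, the
// two tail-recursive mnpack() calls mop up the right-hand and bottom
// remainders.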
  506. template <int RM, int RN>
  507. NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  508. int64_t ytiles = (m - m0) / RM;
  509. int64_t xtiles = (n - n0) / RN;
  510. int64_t tiles = xtiles * ytiles;
  511. int64_t duty = (tiles + nth - 1) / nth;
  512. int64_t start = duty * ith;
  513. int64_t end = start + duty;
  514. if (end > tiles)
  515. end = tiles;
  516. for (int64_t job = start; job < end; ++job) {
  517. int64_t ii = m0 + job / xtiles * RM;
  518. int64_t jj = n0 + job % xtiles * RN;
  519. float32x4_t Cv[RN][RM] = {};
  520. for (int64_t l = 0; l < k; ++l)
  521. for (int64_t j = 0; j < RN; ++j)
  522. for (int64_t i = 0; i < RM; ++i)
  523. Cv[j][i] = vmlaq_n_f32(Cv[j][i],
  524. vcvtq_f32_s32(vdotq_s32(
  525. vdotq_s32(vdupq_n_s32(0),
  526. load_lo(A + lda * (ii + i) + l),
  527. load_lo(B + ldb * (jj + j) + l)),
  528. load_hi(A + lda * (ii + i) + l),
  529. load_hi(B + ldb * (jj + j) + l))),
  530. unhalf(A[lda * (ii + i) + l].d) *
  531. unhalf(B[ldb * (jj + j) + l].d));
  532. for (int64_t j = 0; j < RN; ++j)
  533. for (int64_t i = 0; i < RM; ++i)
  534. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  535. }
  536. }
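// Each q4_0/q8_0 block holds 32 quantized values; load_lo()/load_hi() below
// fetch the two 16-byte halves, the chained vdotq_s32 calls accumulate all 32
// int8 products into four int32 lanes, and vmlaq_n_f32 scales that partial
// dot product by the product of the two block scales (d) before adding it to
// the accumulator.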
  537. inline int8x16_t load_lo(const block_q8_0 *b) {
  538. return vld1q_s8(b->qs);
  539. }
  540. inline int8x16_t load_hi(const block_q8_0 *b) {
  541. return vld1q_s8(b->qs + 16);
  542. }
  543. inline int8x16_t load_lo(const block_q4_0 *b) {
  544. return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
  545. vdupq_n_u8(0x0f))),
  546. vdupq_n_s8(0x8));
  547. }
  548. inline int8x16_t load_hi(const block_q4_0 *b) {
  549. return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
  550. vdupq_n_s8(0x8));
  551. }
  552. const TA *const A;
  553. const block_q8_0 *const B;
  554. float *const C;
  555. const int64_t k;
  556. const int64_t lda;
  557. const int64_t ldb;
  558. const int64_t ldc;
  559. const int ith;
  560. const int nth;
  561. };
  562. #endif // __ARM_FEATURE_DOTPROD
  563. #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
  564. template <typename TA, typename TB, typename TC>
  565. class tinyBLAS_Q0_AVX {
  566. public:
  567. tinyBLAS_Q0_AVX(int64_t k,
  568. const TA *A, int64_t lda,
  569. const TB *B, int64_t ldb,
  570. TC *C, int64_t ldc,
  571. int ith, int nth)
  572. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  573. const int8_t kvalues_iq4nl[16] = {
  574. -127, -104, -83, -65,
  575. -49, -35, -22, -10,
  576. 1, 13, 25, 38,
  577. 53, 69, 89, 113
  578. };
  579. iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
  580. }
  581. void matmul(int64_t m, int64_t n) {
  582. mnpack(0, m, 0, n);
  583. }
  584. private:
  585. void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  586. int64_t mc, nc, mp, np;
  587. switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
  588. #if VECTOR_REGISTERS == 32
  589. case 0x44:
  590. mc = 4;
  591. nc = 4;
  592. #if defined(__AVX2__) && defined(__F16C__)
  593. gemm4xN<4>(m0, m, n0, n);
  594. #else
  595. gemm<4, 4>(m0, m, n0, n);
  596. #endif
  597. break;
  598. case 0x43:
  599. mc = 4;
  600. nc = 3;
  601. #if defined(__AVX2__) && defined(__F16C__)
  602. gemm4xN<3>(m0, m, n0, n);
  603. #else
  604. gemm<4, 3>(m0, m, n0, n);
  605. #endif
  606. break;
  607. case 0x34:
  608. mc = 3;
  609. nc = 4;
  610. #if defined(__AVX2__) && defined(__F16C__)
  611. gemmMx4<3>(m0, m, n0, n);
  612. #else
  613. gemm<3, 4>(m0, m, n0, n);
  614. #endif
  615. break;
  616. case 0x33:
  617. mc = 3;
  618. nc = 3;
  619. gemm<3, 3>(m0, m, n0, n);
  620. break;
  621. case 0x42:
  622. mc = 4;
  623. nc = 2;
  624. #if defined(__AVX2__) && defined(__F16C__)
  625. gemm4xN<2>(m0, m, n0, n);
  626. #else
  627. gemm<4, 2>(m0, m, n0, n);
  628. #endif
  629. break;
  630. case 0x24:
  631. mc = 2;
  632. nc = 4;
  633. #if defined(__AVX2__) && defined(__F16C__)
  634. gemmMx4<2>(m0, m, n0, n);
  635. #else
  636. gemm<2, 4>(m0, m, n0, n);
  637. #endif
  638. break;
  639. #else
  640. case 0x44:
  641. case 0x43:
  642. case 0x42:
  643. mc = 4;
  644. nc = 2;
  645. #if defined(__AVX2__) && defined(__F16C__)
  646. gemm4xN<2>(m0, m, n0, n);
  647. #else
  648. gemm<4, 2>(m0, m, n0, n);
  649. #endif
  650. break;
  651. case 0x34:
  652. case 0x24:
  653. mc = 2;
  654. nc = 4;
  655. #if defined(__AVX2__) && defined(__F16C__)
  656. gemmMx4<2>(m0, m, n0, n);
  657. #else
  658. gemm<2, 4>(m0, m, n0, n);
  659. #endif
  660. break;
  661. case 0x33:
  662. #endif
  663. case 0x32:
  664. mc = 3;
  665. nc = 2;
  666. gemm<3, 2>(m0, m, n0, n);
  667. break;
  668. case 0x23:
  669. mc = 2;
  670. nc = 3;
  671. gemm<2, 3>(m0, m, n0, n);
  672. break;
  673. case 0x41:
  674. mc = 4;
  675. nc = 1;
  676. #if defined(__AVX2__) && defined(__F16C__)
  677. gemm4xN<1>(m0, m, n0, n);
  678. #else
  679. gemm<4, 1>(m0, m, n0, n);
  680. #endif
  681. break;
  682. case 0x22:
  683. mc = 2;
  684. nc = 2;
  685. gemm<2, 2>(m0, m, n0, n);
  686. break;
  687. case 0x14:
  688. mc = 1;
  689. nc = 4;
  690. #if defined(__AVX2__) && defined(__F16C__)
  691. gemmMx4<1>(m0, m, n0, n);
  692. #else
  693. gemm<1, 4>(m0, m, n0, n);
  694. #endif
  695. break;
  696. case 0x31:
  697. mc = 3;
  698. nc = 1;
  699. gemm<3, 1>(m0, m, n0, n);
  700. break;
  701. case 0x13:
  702. mc = 1;
  703. nc = 3;
  704. gemm<1, 3>(m0, m, n0, n);
  705. break;
  706. case 0x21:
  707. mc = 2;
  708. nc = 1;
  709. gemm<2, 1>(m0, m, n0, n);
  710. break;
  711. case 0x12:
  712. mc = 1;
  713. nc = 2;
  714. gemm<1, 2>(m0, m, n0, n);
  715. break;
  716. case 0x11:
  717. mc = 1;
  718. nc = 1;
  719. gemm<1, 1>(m0, m, n0, n);
  720. break;
  721. default:
  722. return;
  723. }
  724. mp = m0 + (m - m0) / mc * mc;
  725. np = n0 + (n - n0) / nc * nc;
  726. mnpack(mp, m, n0, np);
  727. mnpack(m0, m, np, n);
  728. }
  729. #if defined(__AVX2__) && defined(__F16C__)
  730. // Templated functions for gemm of dimensions 4xN
  731. template <int RN>
  732. NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  733. int64_t ytiles = (m - m0) / 4;
  734. int64_t xtiles = (n - n0) / RN;
  735. int64_t tiles = xtiles * ytiles;
  736. int64_t duty = (tiles + nth - 1) / nth;
  737. int64_t start = duty * ith;
  738. int64_t end = start + duty;
  739. if (end > tiles)
  740. end = tiles;
  741. for (int64_t job = start; job < end; ++job) {
  742. int64_t ii = m0 + job / xtiles * 4;
  743. int64_t jj = n0 + job % xtiles * RN;
  744. __m256 Cv[RN][4] = {};
  745. for (int64_t l = 0; l < k; ++l) {
  746. uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
  747. // Convert delta values for four blocks to float values
  748. __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
  749. __m256i avec0 = load(A + lda * (ii + 0) + l);
  750. __m256i avec1 = load(A + lda * (ii + 1) + l);
  751. __m256i avec2 = load(A + lda * (ii + 2) + l);
  752. __m256i avec3 = load(A + lda * (ii + 3) + l);
  753. for (int64_t j = 0; j < RN; ++j) {
  754. __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
  755. // Compute the product of the delta values for the four blocks and replicate it across the 256-bit lane
  756. __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
  757. dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
  758. // Compute the dot products and scale them with the corresponding delta products
  759. Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
  760. updot(_mm256_sign_epi8(avec0, avec0),
  761. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
  762. Cv[j][0]);
  763. Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
  764. updot(_mm256_sign_epi8(avec1, avec1),
  765. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
  766. Cv[j][1]);
  767. Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
  768. updot(_mm256_sign_epi8(avec2, avec2),
  769. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
  770. Cv[j][2]);
  771. Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
  772. updot(_mm256_sign_epi8(avec3, avec3),
  773. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
  774. Cv[j][3]);
  775. }
  776. }
  777. for (int64_t j = 0; j < RN; ++j)
  778. for (int64_t i = 0; i < 4; ++i)
  779. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  780. }
  781. }
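// a_delta above packs the four fp16 scales d of rows ii..ii+3 into one
// 64-bit word so that a single _mm_cvtph_ps converts them; dvec replicates
// da*db into both 128-bit halves, and _mm256_shuffle_ps with immediates 0,
// 85, 170 and 255 then broadcasts lane 0, 1, 2 or 3 of that product across
// the whole vector for the corresponding row.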
  782. // Templated functions for gemm of dimensions Mx4
  783. template <int RM>
  784. NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  785. int64_t ytiles = (m - m0) / RM;
  786. int64_t xtiles = (n - n0) / 4;
  787. int64_t tiles = xtiles * ytiles;
  788. int64_t duty = (tiles + nth - 1) / nth;
  789. int64_t start = duty * ith;
  790. int64_t end = start + duty;
  791. if (end > tiles)
  792. end = tiles;
  793. for (int64_t job = start; job < end; ++job) {
  794. int64_t ii = m0 + job / xtiles * RM;
  795. int64_t jj = n0 + job % xtiles * 4;
  796. __m256 Cv[4][RM] = {};
  797. for (int64_t l = 0; l < k; ++l) {
  798. uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
  799. // Convert delta values for four blocks to float values
  800. __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
  801. __m256i bvec0 = load(B + ldb * (jj + 0) + l);
  802. __m256i bvec1 = load(B + ldb * (jj + 1) + l);
  803. __m256i bvec2 = load(B + ldb * (jj + 2) + l);
  804. __m256i bvec3 = load(B + ldb * (jj + 3) + l);
  805. for (int64_t i = 0; i < RM; ++i) {
  806. __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
  807. // Compute the product of the delta values for the four blocks and replicate it across the 256-bit lane
  808. __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
  809. dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
  810. // Compute the dot products and scale them with the corresponding delta products
  811. Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
  812. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  813. load(A + lda * (ii + i) + l)),
  814. _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
  815. Cv[0][i]);
  816. Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
  817. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  818. load(A + lda * (ii + i) + l)),
  819. _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
  820. Cv[1][i]);
  821. Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
  822. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  823. load(A + lda * (ii + i) + l)),
  824. _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
  825. Cv[2][i]);
  826. Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
  827. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  828. load(A + lda * (ii + i) + l)),
  829. _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
  830. Cv[3][i]);
  831. }
  832. }
  833. for (int64_t j = 0; j < 4; ++j)
  834. for (int64_t i = 0; i < RM; ++i)
  835. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  836. }
  837. }
  838. #endif
  839. template <int RM, int RN>
  840. NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  841. int64_t ytiles = (m - m0) / RM;
  842. int64_t xtiles = (n - n0) / RN;
  843. int64_t tiles = xtiles * ytiles;
  844. int64_t duty = (tiles + nth - 1) / nth;
  845. int64_t start = duty * ith;
  846. int64_t end = start + duty;
  847. if (end > tiles)
  848. end = tiles;
  849. for (int64_t job = start; job < end; ++job) {
  850. int64_t ii = m0 + job / xtiles * RM;
  851. int64_t jj = n0 + job % xtiles * RN;
  852. __m256 Cv[RN][RM] = {};
  853. for (int64_t l = 0; l < k; ++l)
  854. for (int64_t j = 0; j < RN; ++j)
  855. for (int64_t i = 0; i < RM; ++i) {
  856. #if defined(__AVX2__)
  857. __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  858. load(A + lda * (ii + i) + l)),
  859. _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
  860. load(A + lda * (ii + i) + l)));
  861. #else
  862. __m128i ali0 = load0(A + lda * (ii + i) + l);
  863. __m128i ali1 = load1(A + lda * (ii + i) + l);
  864. __m128i blj0 = load0(B + ldb * (jj + j) + l);
  865. __m128i blj1 = load1(B + ldb * (jj + j) + l);
  866. __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
  867. __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
  868. __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
  869. __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
  870. // updot
  871. const __m128i oneFill = _mm_set1_epi16(1);
  872. __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
  873. __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
  874. __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
  875. #endif
  876. Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
  877. unhalf(B[ldb * (jj + j) + l].d)),
  878. udTmp,
  879. Cv[j][i]);
  880. }
  881. for (int64_t j = 0; j < RN; ++j)
  882. for (int64_t i = 0; i < RM; ++i)
  883. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  884. }
  885. }
  886. inline __m256i load(const block_q8_0 *b) {
  887. return _mm256_loadu_si256((const __m256i *)b->qs);
  888. }
  889. inline __m128i load0(const block_q8_0 *b) {
  890. return _mm_loadu_si128((const __m128i *)b->qs);
  891. }
  892. inline __m128i load1(const block_q8_0 *b) {
  893. return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
  894. }
  895. inline __m256i load(const block_q4_0 *b) {
  896. return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
  897. }
  898. inline __m128i load0(const block_q4_0 *b) {
  899. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  900. return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
  901. }
  902. inline __m128i load1(const block_q4_0 *b) {
  903. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  904. return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
  905. }
  906. inline __m256i load(const block_q5_0 *b) {
  907. return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
  908. }
  909. inline __m128i load0(const block_q5_0* b) {
  910. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  911. uint32_t x32;
  912. memcpy(&x32, b->qh, sizeof(uint32_t));
  913. __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
  914. __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
  915. _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
  916. _mm_shuffle_epi8(_mm_set1_epi32(x32),
  917. _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
  918. bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
  919. return _mm_or_si128(qxl, bytesl);
  920. }
  921. inline __m128i load1(const block_q5_0* b) {
  922. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  923. uint32_t x32;
  924. memcpy(&x32, b->qh, sizeof(uint32_t));
  925. __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
  926. __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
  927. _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
  928. _mm_shuffle_epi8(_mm_set1_epi32(x32),
  929. _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
  930. bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
  931. return _mm_or_si128(qxh, bytesh);
  932. }
  933. inline __m256i load(const block_iq4_nl *b) {
  934. return MM256_SET_M128I(load1(b), load0(b));
  935. }
  936. inline __m128i load0(const block_iq4_nl *b) {
  937. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  938. return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
  939. }
  940. inline __m128i load1(const block_iq4_nl *b) {
  941. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  942. return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
  943. }
  944. inline __m256 updot(__m256i u, __m256i s) {
  945. __m256i res;
  946. #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
  947. res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
  948. #elif defined(__AVXVNNI__)
  949. res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
  950. #else
  951. res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
  952. #endif
  953. return _mm256_cvtepi32_ps(res);
  954. }
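// updot() relies on its callers' sign trick: passing _mm256_sign_epi8(a, a)
// (= |a|) as the unsigned operand and _mm256_sign_epi8(b, a) (= b carrying
// a's sign) as the signed operand means the unsigned*signed byte
// multiply-add performed by dpbusd/maddubs still yields the signed product
// a*b per lane.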
  955. static inline __m256i denibble(const uint8_t *p) {
  956. __m128i x = _mm_loadu_si128((const __m128i *)p);
  957. return _mm256_and_si256(_mm256_set1_epi8(15),
  958. _mm256_insertf128_si256(_mm256_castsi128_si256(x),
  959. _mm_srli_epi16(x, 4), 1));
  960. }
  961. static inline __m256i bittobyte(const uint8_t *p) {
  962. uint32_t x32;
  963. memcpy(&x32, p, sizeof(uint32_t));
  964. __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
  965. _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
  966. _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
  967. _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
  968. 0x0101010101010101, 0x0000000000000000))));
  969. return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
  970. }
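// For q5_0, OR-ing denibble(qs) with bittobyte(qh) reconstructs the signed
// value (16*high_bit + nibble) - 16 without an explicit subtraction:
// bittobyte() expands the 32 qh bits to bytes, the andnot leaves 0xF0 exactly
// in the lanes whose high bit is clear, and nibble | 0xF0 read as a signed
// byte is nibble - 16, while lanes whose high bit is set keep the plain
// nibble.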
  971. const TA *const A;
  972. const TB *const B;
  973. TC *const C;
  974. const int64_t k;
  975. const int64_t lda;
  976. const int64_t ldb;
  977. const int64_t ldc;
  978. const int ith;
  979. const int nth;
  980. __m128i iq4nlt;
  981. };
  982. #endif // __AVX__
  983. // PPC implementation
  984. #if defined(__MMA__)
  985. #define SAVE_ACC(ACC, ii, jj) \
  986. __builtin_mma_disassemble_acc(vec_C, ACC); \
  987. for (int I = 0; I < 4; I++) { \
  988. for (int J = 0; J < 4; J++) { \
  989. *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
  990. } \
  991. }
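// SAVE_ACC spills one 4x4 MMA accumulator tile: __builtin_mma_disassemble_acc()
// copies the accumulator into the four vec_C rows and the loops store element
// (I, J) to C[(jj + J) * ldc + (ii + I)].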
  992. template <typename TA, typename TB, typename TC>
  993. class tinyBLAS_BF16_PPC {
  994. public:
  995. tinyBLAS_BF16_PPC(int64_t k,
  996. const TA *A, int64_t lda,
  997. const TB *B, int64_t ldb,
  998. TC *C, int64_t ldc,
  999. int ith, int nth)
  1000. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  1001. }
  1002. void matmul(int64_t m, int64_t n) {
  1003. mnpack(0, m, 0, n);
  1004. }
  1005. private:
  1006. void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
  1007. vec_t t[8], s[8];
  1008. vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
  1009. vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
  1010. vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
  1011. vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
  1012. if (numVec == 2) {
  1013. t[0] = vec_perm(c[0], c[1], swiz1);
  1014. t[1] = vec_perm(c[2], c[3], swiz1);
  1015. s[0] = vec_perm(t[0], t[1], swiz3);
  1016. s[1] = vec_perm(t[0], t[1], swiz4);
  1017. vec_xst(s[0], 0, (vec_t*)vecOffset);
  1018. vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
  1019. } else if (numVec == 4) {
  1020. t[0] = vec_perm(c[0], c[1], swiz1);
  1021. t[1] = vec_perm(c[0], c[1], swiz2);
  1022. t[2] = vec_perm(c[2], c[3], swiz1);
  1023. t[3] = vec_perm(c[2], c[3], swiz2);
  1024. s[0] = vec_perm(t[0], t[2], swiz3);
  1025. s[1] = vec_perm(t[0], t[2], swiz4);
  1026. s[2] = vec_perm(t[1], t[3], swiz3);
  1027. s[3] = vec_perm(t[1], t[3], swiz4);
  1028. for (int i = 0; i < 4; ++i)
  1029. vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
  1030. } else if (numVec == 8) {
  1031. for (int i = 0; i < 4; i += 2) {
  1032. t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
  1033. t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
  1034. }
  1035. for (int i = 4; i < 8; i += 2) {
  1036. t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
  1037. t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
  1038. }
  1039. s[0] = vec_perm(t[0], t[2], swiz3);
  1040. s[1] = vec_perm(t[0], t[2], swiz4);
  1041. s[2] = vec_perm(t[1], t[3], swiz3);
  1042. s[3] = vec_perm(t[1], t[3], swiz4);
  1043. s[4] = vec_perm(t[4], t[6], swiz3);
  1044. s[5] = vec_perm(t[4], t[6], swiz4);
  1045. s[6] = vec_perm(t[5], t[7], swiz3);
  1046. s[7] = vec_perm(t[5], t[7], swiz4);
  1047. for (int i = 0; i < 8; ++i)
  1048. vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
  1049. }
  1050. }
  1051. void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
  1052. int64_t i, j;
  1053. TA *aoffset = NULL;
  1054. unsigned char *vecOffset = NULL;
  1055. TA * aoffsets[8];
  1056. vector unsigned char c_arr[8];
  1057. aoffset = const_cast<TA*>(a);
  1058. vecOffset = vec;
  1059. j = (rows >> 3);
  1060. if (j > 0) {
  1061. do {
  1062. if (cols == 4) {
  1063. aoffsets[0] = aoffset;
  1064. for (int it = 1; it < 4; ++it)
  1065. aoffsets[it] = aoffsets[it-1] + lda;
  1066. aoffset += 4 * lda;
  1067. for (int i = 0; i < 4; ++i)
  1068. c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
  1069. vector_permute_store(c_arr, 4, vecOffset);
  1070. for (int i = 0; i<4; i++)
  1071. aoffsets[i] = aoffsets[i]+lda;
  1072. vecOffset +=64;
  1073. }
  1074. i = (cols >> 3);
  1075. if (i > 0) {
  1076. aoffsets[0] = aoffset;
  1077. for (int it = 1; it < 8; ++it) {
  1078. aoffsets[it] = aoffsets[it-1] + lda;
  1079. }
  1080. aoffset += 8 * lda;
  1081. do {
  1082. for (int it = 0; it < 8; ++it)
  1083. c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
  1084. vector_permute_store(c_arr, 8, vecOffset);
  1085. for (int it = 0; it < 8; ++it)
  1086. aoffsets[it] = aoffsets[it] + 8*lda;
  1087. vecOffset += 128;
  1088. i--;
  1089. } while(i > 0);
  1090. }
  1091. j--;
  1092. } while(j > 0);
  1093. }
  1094. if (rows & 4) {
  1095. aoffsets[0] = aoffset;
  1096. for (int it = 1; it < 4; ++it)
  1097. aoffsets[it] = aoffsets[it-1] + lda;
  1098. aoffset += 4 * lda;
  1099. if (cols == 4) {
  1100. for (int it = 0; it < 4; ++it)
  1101. c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
  1102. vector_permute_store(c_arr, 2, vecOffset);
  1103. for (int it = 0; it< 4; it++)
  1104. aoffsets[it] = aoffsets[it] + lda;
  1105. vecOffset += 32;
  1106. }
  1107. i = (cols >> 3);
  1108. if (i > 0) {
  1109. do {
  1110. for (int it = 0; it < 4; ++it)
  1111. c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
  1112. vector_permute_store(c_arr, 4, vecOffset);
  1113. for (int it = 0; it< 4; it++)
  1114. aoffsets[it] = aoffsets[it] + 8*lda;
  1115. vecOffset += 64;
  1116. i--;
  1117. } while(i > 0);
  1118. }
  1119. }
  1120. if (rows & 3) {
  1121. aoffsets[0] = aoffset;
  1122. for (int it = 1; it < 4; ++it)
  1123. aoffsets[it] = aoffsets[it-1] + lda;
  1124. if (cols == 4) {
  1125. switch(rows) {
  1126. case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
  1127. case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
  1128. case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
  1129. break;
  1130. }
  1131. vector_permute_store(c_arr, 2, vecOffset);
  1132. for (int it = 0; it< 4; it++)
  1133. aoffsets[it] = aoffsets[it] + lda;
  1134. vecOffset += 32;
  1135. }
  1136. i = (cols >> 3);
  1137. if (i > 0) {
  1138. do {
  1139. switch(rows) {
  1140. case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
  1141. case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
  1142. case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
  1143. break;
  1144. }
  1145. vector_permute_store(c_arr, 4, vecOffset);
  1146. for (int it = 0; it <4; it++)
  1147. aoffsets[it] = aoffsets[it] + 8* lda;
  1148. vecOffset += 64;
  1149. i--;
  1150. } while(i > 0);
  1151. }
  1152. }
  1153. }
  1154. void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  1155. int64_t mc, nc, mp, np;
  1156. int m_rem = MIN(m - m0, 8);
  1157. int n_rem = MIN(n - n0, 8);
  1158. if (m_rem >= 8 && n_rem >= 8) {
  1159. mc = 8;
  1160. nc = 8;
  1161. gemm<8,8>(m0, m, n0, n);
  1162. } else if (m_rem >= 4 && n_rem >= 8) {
  1163. mc = 4;
  1164. nc = 8;
  1165. gemm<4,8>(m0, m, n0, n);
  1166. } else if (m_rem >=8 && n_rem >=4){
  1167. mc = 8;
  1168. nc = 4;
  1169. gemm<8,4>(m0, m, n0, n);
  1170. } else if ((m_rem < 4) && (n_rem >= 8)) {
  1171. nc = 8;
  1172. switch(m_rem) {
  1173. case 1:
  1174. mc = 1;
  1175. gemm_Mx8<1>(m0, m, n0, n);
  1176. break;
  1177. case 2:
  1178. mc = 2;
  1179. gemm_Mx8<2>(m0, m, n0, n);
  1180. break;
  1181. case 3:
  1182. mc = 3;
  1183. gemm_Mx8<3>(m0, m, n0, n);
  1184. break;
  1185. default:
  1186. return;
  1187. }
  1188. } else if (m_rem >= 4 && n_rem >= 4) {
  1189. mc = 4;
  1190. nc = 4;
  1191. gemm_small<4, 4>(m0, m, n0, n);
  1192. } else if ((m_rem > 4) && (n_rem < 4)) {
  1193. mc = 4;
  1194. switch(n_rem) {
  1195. case 1:
  1196. nc = 1;
  1197. gemm_small<4, 1>(m0, m, n0, n);
  1198. break;
  1199. case 2:
  1200. nc = 2;
  1201. gemm_small<4, 2>(m0, m, n0, n);
  1202. break;
  1203. case 3:
  1204. nc = 3;
  1205. gemm_small<4, 3>(m0, m, n0, n);
  1206. break;
  1207. default:
  1208. return;
  1209. }
  1210. } else {
  1211. switch((m_rem << 4) | n_rem) {
  1212. case 0x43:
  1213. mc = 4;
  1214. nc = 3;
  1215. gemm_small<4, 3>(m0, m, n0, n);
  1216. break;
  1217. case 0x42:
  1218. mc = 4;
  1219. nc = 2;
  1220. gemm_small<4, 2>(m0, m, n0, n);
  1221. break;
  1222. case 0x41:
  1223. mc = 4;
  1224. nc = 1;
  1225. gemm_small<4, 1>(m0, m, n0, n);
  1226. break;
  1227. case 0x34:
  1228. mc = 3;
  1229. nc = 4;
  1230. gemm_small<3, 4>(m0, m, n0, n);
  1231. break;
  1232. case 0x33:
  1233. mc = 3;
  1234. nc = 3;
  1235. gemm_small<3, 3>(m0, m, n0, n);
  1236. break;
  1237. case 0x32:
  1238. mc = 3;
  1239. nc = 2;
  1240. gemm_small<3, 2>(m0, m, n0, n);
  1241. break;
  1242. case 0x31:
  1243. mc = 3;
  1244. nc = 1;
  1245. gemm_small<3, 1>(m0, m, n0, n);
  1246. break;
  1247. case 0x24:
  1248. mc = 2;
  1249. nc = 4;
  1250. gemm_small<2,4>(m0, m, n0, n);
  1251. break;
  1252. case 0x23:
  1253. mc = 2;
  1254. nc = 3;
  1255. gemm_small<2, 3>(m0, m, n0, n);
  1256. break;
  1257. case 0x22:
  1258. mc = 2;
  1259. nc = 2;
  1260. gemm_small<2, 2>(m0, m, n0, n);
  1261. break;
  1262. case 0x21:
  1263. mc = 2;
  1264. nc = 1;
  1265. gemm_small<2, 1>(m0, m, n0, n);
  1266. break;
  1267. case 0x14:
  1268. mc = 1;
  1269. nc = 4;
  1270. gemm_small<1, 4>(m0, m, n0, n);
  1271. break;
  1272. case 0x13:
  1273. mc = 1;
  1274. nc = 3;
  1275. gemm_small<1, 3>(m0, m, n0, n);
  1276. break;
  1277. case 0x12:
  1278. mc = 1;
  1279. nc = 2;
  1280. gemm_small<1, 2>(m0, m, n0, n);
  1281. break;
  1282. case 0x11:
  1283. mc = 1;
  1284. nc = 1;
  1285. gemm_small<1, 1>(m0, m, n0, n);
  1286. break;
  1287. default:
  1288. return;
  1289. }
  1290. }
  1291. mp = m0 + (m - m0) / mc * mc;
  1292. np = n0 + (n - n0) / nc * nc;
  1293. mnpack(mp, m, n0, np);
  1294. mnpack(m0, m, np, n);
  1295. }
  1296. void KERNEL_4x8(int64_t ii, int64_t jj) {
  1297. vec_t vec_A[4], vec_B[8] , vec_C[4];
  1298. acc_t acc_0, acc_1;
  1299. __builtin_mma_xxsetaccz(&acc_0);
  1300. __builtin_mma_xxsetaccz(&acc_1);
  1301. for (int l = 0; l < k; l+=8) {
  1302. packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
  1303. packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
  1304. for (int x = 0; x < 4; x++) {
  1305. __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
  1306. __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
  1307. }
  1308. }
  1309. SAVE_ACC(&acc_0, ii, jj);
  1310. SAVE_ACC(&acc_1, ii, jj+4);
  1311. }
  1312. void KERNEL_8x4(int64_t ii, int64_t jj) {
  1313. vec_t vec_A[8], vec_B[4] , vec_C[4];
  1314. acc_t acc_0, acc_1;
  1315. __builtin_mma_xxsetaccz(&acc_0);
  1316. __builtin_mma_xxsetaccz(&acc_1);
  1317. for (int l = 0; l < k; l+=8) {
  1318. packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
  1319. packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
  1320. for (int x = 0; x < 4; x++) {
  1321. __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
  1322. __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
  1323. }
  1324. }
  1325. SAVE_ACC(&acc_0, ii, jj);
  1326. SAVE_ACC(&acc_1, ii+4, jj);
  1327. }
  1328. void KERNEL_8x8(int64_t ii, int64_t jj) {
  1329. vec_t vec_A[8], vec_B[8], vec_C[4];
  1330. acc_t acc_0, acc_1, acc_2, acc_3;
  1331. __builtin_mma_xxsetaccz(&acc_0);
  1332. __builtin_mma_xxsetaccz(&acc_1);
  1333. __builtin_mma_xxsetaccz(&acc_2);
  1334. __builtin_mma_xxsetaccz(&acc_3);
  1335. for (int l = 0; l < k; l+=8) {
  1336. packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
  1337. packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
  1338. for (int x = 0; x < 4; x++) {
  1339. __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
  1340. __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
  1341. __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
  1342. __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
  1343. }
  1344. }
  1345. SAVE_ACC(&acc_0, ii, jj);
  1346. SAVE_ACC(&acc_1, ii, jj+4);
  1347. SAVE_ACC(&acc_2, ii+4, jj);
  1348. SAVE_ACC(&acc_3, ii+4, jj+4);
  1349. }
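// Handles the remainder tiles that the 8x8/8x4/4x8 MMA kernels cannot cover.
// Work is split across threads the same way as in gemm(): the tile grid is
// linearized, each thread takes a contiguous slice of ceil(tiles / nth) jobs,
// and each job index is mapped back to its (ii, jj) tile origin with a
// divide/modulo by xtiles.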
  1350. template<int RM, int RN>
  1351. void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  1352. int64_t ytiles = (m - m0) / RM;
  1353. int64_t xtiles = (n - n0) / RN;
  1354. int64_t tiles = xtiles * ytiles;
  1355. int64_t duty = (tiles + nth - 1) / nth;
  1356. int64_t start = duty * ith;
  1357. int64_t end = start + duty;
  1358. if (end > tiles)
  1359. end = tiles;
  1360. for (int64_t job = start; job < end; ++job) {
  1361. int64_t ii = m0 + job / xtiles * RM;
  1362. int64_t jj = n0 + job % xtiles * RN;
  1363. vec_t vec_C[4];
  1364. acc_t acc_0;
  1365. __builtin_mma_xxsetaccz(&acc_0);
  1366. vec_t vec_A[2], vec_B[2];
  1367. for (int l=0; l<k; l+=4) {
  1368. packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
  1369. packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
  1370. for (int x = 0; x<2; x++) {
  1371. __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
  1372. }
  1373. }
  1374. __builtin_mma_disassemble_acc(vec_C, &acc_0);
  1375. for (int I = 0; I < RM; I++) {
  1376. for (int J = 0; J < RN; J++) {
  1377. *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
  1378. }
  1379. }
  1380. }
  1381. }
  1382. template<int RM>
  1383. void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  1384. int RN = 8;
  1385. int64_t ytiles = (m - m0) / RM;
  1386. int64_t xtiles = (n - n0) / RN;
  1387. int64_t tiles = xtiles * ytiles;
  1388. int64_t duty = (tiles + nth - 1) / nth;
  1389. int64_t start = duty * ith;
  1390. int64_t end = start + duty;
  1391. if (end > tiles)
  1392. end = tiles;
  1393. for (int64_t job = start; job < end; ++job) {
  1394. int64_t ii = m0 + job / xtiles * RM;
  1395. int64_t jj = n0 + job % xtiles * RN;
  1396. vec_t vec_C[4];
  1397. acc_t acc_0, acc_1;
  1398. __builtin_mma_xxsetaccz(&acc_0);
  1399. __builtin_mma_xxsetaccz(&acc_1);
  1400. vec_t vec_A[4], vec_B[8];
  1401. for (int l=0; l<k; l+=8) {
  1402. packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
  1403. packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
  1404. for (int x = 0; x<4; x++) {
  1405. __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
  1406. __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
  1407. }
  1408. }
  1409. __builtin_mma_disassemble_acc(vec_C, &acc_0);
  1410. for (int I = 0; I < RM; I++) {
  1411. for (int J = 0; J < 4; J++) {
  1412. *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
  1413. }
  1414. }
  1415. __builtin_mma_disassemble_acc(vec_C, &acc_1);
  1416. for (int I = 0; I < RM; I++) {
  1417. for (int J = 0; J < 4; J++) {
  1418. *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
  1419. }
  1420. }
  1421. }
  1422. }
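// Compile-time dispatch from the (RM, RN) tile shape to the matching
// hand-written kernel; shapes without a kernel trip the assert.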
  1423. template<int RM, int RN>
  1424. inline void kernel(int64_t ii, int64_t jj) {
  1425. if constexpr(RM == 4 && RN == 8) {
  1426. KERNEL_4x8(ii,jj);
  1427. } else if constexpr(RM == 8 && RN == 8) {
  1428. KERNEL_8x8(ii,jj);
  1429. } else if constexpr(RM == 8 && RN == 4) {
  1430. KERNEL_8x4(ii,jj);
  1431. } else {
  1432. assert(false && "RN/RM values not supported");
  1433. }
  1434. }
  1435. template <int RM, int RN>
  1436. NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  1437. int64_t ytiles = (m - m0) / RM;
  1438. int64_t xtiles = (n - n0) / RN;
  1439. int64_t tiles = xtiles * ytiles;
  1440. int64_t duty = (tiles + nth - 1) / nth;
  1441. int64_t start = duty * ith;
  1442. int64_t end = start + duty;
  1443. if (end > tiles)
  1444. end = tiles;
  1445. for (int64_t job = start; job < end; ++job) {
  1446. int64_t ii = m0 + job / xtiles * RM;
  1447. int64_t jj = n0 + job % xtiles * RN;
  1448. kernel<RM, RN>(ii, jj);
  1449. }
  1450. }
  1451. const TA *const A;
  1452. const TB *const B;
  1453. TC *C;
  1454. const int64_t k;
  1455. const int64_t lda;
  1456. const int64_t ldb;
  1457. const int64_t ldc;
  1458. const int ith;
  1459. const int nth;
  1460. };
  1461. template <typename TA>
  1462. tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
  1463. const TA *A, int64_t lda,
  1464. const block_q8_0 *B, int64_t ldb,
  1465. float *C, int64_t ldc,
  1466. int ith, int nth)
  1467. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  1468. kc = 64;
  1469. }
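// Entry point for the quantized path. The cache-blocking sizes default to
// mc = nc = kc = 64; for narrow problems (n a multiple of 8 and smaller than
// nc) nc is clamped to n and mc/kc drop to 32. When m, n and k are exact
// multiples of the chosen (power-of-two) block sizes, the blocked
// matmul_tiled_q0() fast path runs; otherwise mnpack() handles the general
// case recursively.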
  1470. template<typename TA>
  1471. void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
  1472. int mc = 64; int nc = 64;
  1473. if (n % 8 == 0 && n < nc) {
  1474. nc = n;
  1475. mc = 32 ;
  1476. kc = 32;
  1477. }
  1478. const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
  1479. if (is_aligned) {
  1480. this->matmul_tiled_q0(m, n, mc, nc, kc);
  1481. } else {
  1482. mnpack(0, m, 0, n);
  1483. }
  1484. }
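// Unpacks rows of q4_0 blocks into signed 8-bit quants laid out for the
// xvi8ger4pp kernels, processing 8 rows at a time, then a 4-row group, then a
// 1-3 row tail (whose switch falls through on purpose). process_q4_elements(),
// defined earlier in this file, turns the raw nibbles loaded into c[1] into
// signed quants spread across c[0]/c[1] and records each row's quant sum in
// comparray for the zero-point correction applied later in the kernels;
// vector_permute_store() then interleaves four rows at a time into the packed
// buffer.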
  1485. template<typename TA>
  1486. template<int size>
  1487. void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
  1488. int64_t i, j;
  1489. TA *aoffset = NULL;
  1490. int8_t *vecOffset = NULL;
  1491. TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
  1492. TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
  1493. vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
  1494. vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
  1495. aoffset = const_cast<TA*>(a);
  1496. vecOffset = vec;
  1497. j = (rows >> 3);
  1498. if (j > 0) {
  1499. do {
  1500. aoffset1 = aoffset;
  1501. aoffset2 = aoffset1 + lda;
  1502. aoffset3 = aoffset2 + lda;
  1503. aoffset4 = aoffset3 + lda;
  1504. aoffset5 = aoffset4 + lda;
  1505. aoffset6 = aoffset5 + lda;
  1506. aoffset7 = aoffset6 + lda;
  1507. aoffset8 = aoffset7 + lda;
  1508. aoffset += 8 * lda;
  1509. i = (cols >> 2);
  1510. if (i > 0) {
  1511. do {
  1512. c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
  1513. c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
  1514. c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
  1515. c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
  1516. c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
  1517. c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
  1518. c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
  1519. c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
  1520. process_q4_elements(c1, &comparray[0]);
  1521. process_q4_elements(c2, &comparray[1]);
  1522. process_q4_elements(c3, &comparray[2]);
  1523. process_q4_elements(c4, &comparray[3]);
  1524. process_q4_elements(c5, &comparray[4]);
  1525. process_q4_elements(c6, &comparray[5]);
  1526. process_q4_elements(c7, &comparray[6]);
  1527. process_q4_elements(c8, &comparray[7]);
  1528. vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
  1529. vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
  1530. vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
  1531. vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
  1532. aoffset1 += lda;
  1533. aoffset2 += lda;
  1534. aoffset3 += lda;
  1535. aoffset4 += lda;
  1536. aoffset5 += lda;
  1537. aoffset6 += lda;
  1538. aoffset7 += lda;
  1539. aoffset8 += lda;
  1540. vecOffset += 256;
  1541. i--;
  1542. } while (i > 0);
  1543. }
  1544. j--;
  1545. } while (j > 0);
  1546. }
  1547. if (rows & 4) {
  1548. aoffset1 = aoffset;
  1549. aoffset2 = aoffset1 + lda;
  1550. aoffset3 = aoffset2 + lda;
  1551. aoffset4 = aoffset3 + lda;
  1552. aoffset += 4 * lda;
  1553. i = (cols >> 2);
  1554. if (i > 0) {
  1555. do {
  1556. c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
  1557. c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
  1558. c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
  1559. c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
  1560. process_q4_elements(c1, &comparray[0]);
  1561. process_q4_elements(c2, &comparray[1]);
  1562. process_q4_elements(c3, &comparray[2]);
  1563. process_q4_elements(c4, &comparray[3]);
  1564. vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
  1565. vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
  1566. aoffset1 += lda;
  1567. aoffset2 += lda;
  1568. aoffset3 += lda;
  1569. aoffset4 += lda;
  1570. vecOffset += 128;
  1571. i--;
  1572. } while (i > 0);
  1573. }
  1574. }
  1575. if (rows & 3) {
  1576. aoffset1 = aoffset;
  1577. aoffset2 = aoffset1 + lda;
  1578. aoffset3 = aoffset2 + lda;
  1579. i = (cols >> 2);
  1580. if (i > 0) {
  1581. do {
  1582. switch(rows) {
  1583. case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
  1584. case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
  1585. case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
  1586. break;
  1587. }
  1588. process_q4_elements(c1, &comparray[0]);
  1589. process_q4_elements(c2, &comparray[1]);
  1590. process_q4_elements(c3, &comparray[2]);
  1591. process_q4_elements(c4, &comparray[3]);
  1592. vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
  1593. vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
  1594. aoffset1 += lda;
  1595. aoffset2 += lda;
  1596. aoffset3 += lda;
  1597. vecOffset += 128;
  1598. i--;
  1599. } while(i > 0);
  1600. }
  1601. }
  1602. }
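// Packs rows of q8_0 quants for the xvi8ger4pp kernels. Each row's 32 quants
// are loaded as a 32-byte vector pair with __builtin_vsx_lxvp and interleaved
// four rows at a time by vector_permute_store(); flip is set when packing the
// B operand (see the comment above the KERNEL_* routines below). As in the
// q4_0 packer, the 1-3 row tail switch relies on deliberate fall-through.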
  1603. template<typename TA>
  1604. template<typename VA, typename VB>
  1605. void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
  1606. int64_t i, j;
  1607. block_q8_0 *aoffset = NULL;
  1608. VA *vecOffset = NULL;
  1609. block_q8_0* aoffsets[8];
  1610. __vector_pair arr[8];
  1611. VB c[8][2] = {0};
  1612. VB c1[8] = {0}; VB c2[8] = {0};
  1613. aoffset = const_cast<block_q8_0*>(a);
  1614. vecOffset = vec;
  1615. j = (rows >> 3);
  1616. if (j > 0) {
  1617. do {
  1618. aoffsets[0] = aoffset;
  1619. for (int it = 1; it < 8; it++)
  1620. aoffsets[it] = aoffsets[it-1] + lda;
  1621. aoffset += 8 * lda;
  1622. i = (cols >> 3);
  1623. if (i > 0) {
  1624. do {
  1625. for (int it = 0; it < 8; it++) {
  1626. arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
  1627. __builtin_vsx_disassemble_pair(c[it], &arr[it]);
  1628. c1[it] = c[it][0];
  1629. c2[it] = c[it][1];
  1630. }
  1631. vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
  1632. vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
  1633. vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
  1634. vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
  1635. for (int it = 0; it < 8; it++)
  1636. aoffsets[it] += lda;
  1637. vecOffset += 256;
  1638. i--;
  1639. } while(i > 0);
  1640. }
  1641. j--;
  1642. } while(j > 0);
  1643. }
  1644. if (rows & 4) {
  1645. aoffsets[0] = aoffset;
  1646. for (int it = 1; it < 4; it++ )
  1647. aoffsets[it] = aoffsets[it-1] + lda;
  1648. aoffset += 4 * lda;
  1649. i = (cols >> 3);
  1650. if (i > 0) {
  1651. do {
  1652. for (int it = 0; it < 4; it++) {
  1653. arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
  1654. __builtin_vsx_disassemble_pair(c[it], &arr[it]);
  1655. c1[it] = c[it][0];
  1656. c2[it] = c[it][1];
  1657. }
  1658. vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
  1659. vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
  1660. for (int it = 0; it < 4; it++) {
  1661. aoffsets[it] += lda;
  1662. }
  1663. vecOffset += 128;
  1664. i--;
  1665. } while(i > 0);
  1666. }
  1667. }
  1668. if (rows & 3) {
  1669. aoffsets[0] = aoffset;
  1670. for (int it = 1; it < 3; it++ )
  1671. aoffsets[it] = aoffsets[it-1] + lda;
  1672. i = (cols >> 3);
  1673. if (i > 0) {
  1674. do {
  1675. switch(rows) {
  1676. case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
  1677. __builtin_vsx_disassemble_pair(c[2], &arr[2]);
  1678. c1[2] = c[2][0]; c2[2] = c[2][1];
  1679. case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
  1680. __builtin_vsx_disassemble_pair(c[1], &arr[1]);
  1681. c1[1] = c[1][0]; c2[1] = c[1][1];
  1682. case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
  1683. __builtin_vsx_disassemble_pair(c[0], &arr[0]);
  1684. c1[0] = c[0][0]; c2[0] = c[0][1];
  1685. break;
  1686. }
  1687. vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
  1688. vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
  1689. for (int it = 0; it < 3; it++)
  1690. aoffsets[it] += lda;
  1691. vecOffset += 128;
  1692. i--;
  1693. } while(i > 0);
  1694. }
  1695. }
  1696. }
  1697. template<typename TA>
  1698. void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  1699. int m_rem = MIN(m - m0, 16);
  1700. int n_rem = MIN(n - n0, 16);
  1701. int mc = 0, nc = 0;
  1702. if (m_rem >= 8 && n_rem >= 8) {
  1703. mc = 8;
  1704. nc = 8;
  1705. gemm<8, 8>(m0, m, n0, n);
  1706. } else if (m_rem >= 4 && n_rem >= 8) {
  1707. mc = 4;
  1708. nc = 8;
  1709. gemm<4, 8>(m0, m, n0, n);
  1710. } else if (m_rem >= 8 && n_rem >= 4) {
  1711. mc = 8;
  1712. nc = 4;
  1713. gemm<8, 4>(m0, m, n0, n);
  1714. } else if (m_rem >= 4 && n_rem >= 4) {
  1715. mc = 4;
  1716. nc = 4;
  1717. gemm_small(m0, m, n0, n, mc, nc);
  1718. } else {
  1719. mc = (m_rem >= 4) ? 4 : m_rem;
  1720. nc = (n_rem >= 4) ? 4 : n_rem;
  1721. if (mc == 0 || nc == 0)
  1722. return;
  1723. gemm_small(m0, m, n0, n, mc, nc);
  1724. }
  1725. int64_t mp = m0 + ((m - m0) / mc) * mc;
  1726. int64_t np = n0 + ((n - n0) / nc) * nc;
  1727. mnpack(mp, m, n0, np);
  1728. mnpack(m0, m, np, n);
  1729. }
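// Quantized MMA kernels (A is q4_0 or q8_0, B is q8_0). For every block along
// k they (1) pack the A and B quants (B with flip=true), (2) accumulate the
// int8 products with __builtin_mma_xvi8ger4pp, and (3) build the per-output
// scale products d_A * d_B from the block scales into vs. The epilogue
// (compute() for these fixed-size kernels, written out inline in gemm_small()
// below) converts the integer accumulator to float, subtracts
// 128 * rowsum(A), held in comparray, to compensate for the bias applied to B
// during packing, and multiply-accumulates the result with vs into fin_res.
// For q4_0 the row sums of A come from packNormalInt4(); for q8_0 they are
// recomputed here from the raw quants.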
  1730. template<typename TA>
  1731. void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
  1732. vec_t vec_A[8], vec_B[16] = {0};
  1733. acc_t acc_0, acc_1;
  1734. std::array<int, 4> comparray {};
  1735. vector float fin_res[8] = {0};
  1736. vector float vs[8] = {0};
  1737. bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
  1738. for (int l = 0; l < k; l++) {
  1739. __builtin_mma_xxsetaccz(&acc_0);
  1740. __builtin_mma_xxsetaccz(&acc_1);
  1741. if (std::is_same_v<TA, block_q4_0>) {
  1742. packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
  1743. } else {
  1744. packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
  1745. }
  1746. packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
  1747. for(int x = 0; x < 8; x++) {
  1748. __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
  1749. __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
  1750. }
  1751. for (int I = 0; I<4; I++) {
  1752. for (int J = 0; J<4; J++) {
  1753. *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
  1754. *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
  1755. }
  1756. }
  1757. if (!isAblock_q4) {
  1758. auto aoffset = A+(ii*lda)+l;
  1759. for (int i = 0; i < 4; i++) {
  1760. comparray[i] = 0;
  1761. int ca = 0;
  1762. auto *at = aoffset->qs;
  1763. for (int j = 0; j < 32; j++)
  1764. ca += (int)*at++;
  1765. comparray[i] = ca;
  1766. aoffset += lda;
  1767. }
  1768. }
  1769. compute(&acc_0, 0, 0, comparray, vs, fin_res);
  1770. compute(&acc_1, 0, 4, comparray, vs, fin_res);
  1771. }
  1772. save_res(ii, jj, 0, fin_res);
  1773. save_res(ii, jj+4, 4, fin_res);
  1774. }
  1775. template<typename TA>
  1776. void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
  1777. vec_t vec_A[16], vec_B[8] = {0};
  1778. acc_t acc_0, acc_1;
  1779. std::array<int, 8> comparray {};
  1780. vector float fin_res[8] = {0};
  1781. vector float vs[8] = {0};
  1782. bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
  1783. for (int l = 0; l < k; l++) {
  1784. __builtin_mma_xxsetaccz(&acc_0);
  1785. __builtin_mma_xxsetaccz(&acc_1);
  1786. if (std::is_same_v<TA, block_q4_0>) {
  1787. packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
  1788. } else {
  1789. packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
  1790. }
  1791. packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
  1792. for(int x = 0; x < 8; x++) {
  1793. __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
  1794. __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
  1795. }
  1796. for (int I = 0; I<8; I++) {
  1797. for (int J = 0; J<4; J++) {
  1798. *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
  1799. }
  1800. }
  1801. if (!isAblock_q4) {
  1802. auto aoffset = A+(ii*lda)+l;
  1803. for (int i = 0; i < 8; i++) {
  1804. comparray[i] = 0;
  1805. int ca = 0;
  1806. auto *at = aoffset->qs;
  1807. for (int j = 0; j < 32; j++)
  1808. ca += (int)*at++;
  1809. comparray[i] = ca;
  1810. aoffset += lda;
  1811. }
  1812. }
  1813. compute(&acc_0, 0, 0, comparray, vs, fin_res);
  1814. compute(&acc_1, 4, 4, comparray, vs, fin_res);
  1815. }
  1816. save_res(ii, jj, 0, fin_res);
  1817. save_res(ii+4, jj, 4, fin_res);
  1818. }
  1819. template<typename TA>
  1820. void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
  1821. vec_t vec_A[16], vec_B[16] = {0};
  1822. acc_t acc_0, acc_1, acc_2, acc_3;
  1823. acc_t acc_4, acc_5, acc_6, acc_7;
  1824. std::array<int, 8> comparray {};
  1825. vector float fin_res[16] = {0};
  1826. vector float vs[16] = {0};
  1827. bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
  1828. for (int l = 0; l < k; l++) {
  1829. __builtin_mma_xxsetaccz(&acc_0);
  1830. __builtin_mma_xxsetaccz(&acc_1);
  1831. __builtin_mma_xxsetaccz(&acc_2);
  1832. __builtin_mma_xxsetaccz(&acc_3);
  1833. if (std::is_same_v<TA, block_q4_0>) {
  1834. packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
  1835. } else {
  1836. packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
  1837. }
  1838. packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
  1839. for(int x = 0; x < 8; x++) {
  1840. __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
  1841. __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
  1842. __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
  1843. __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
  1844. }
  1845. for (int I = 0; I<8; I++) {
  1846. for (int J = 0; J<4; J++) {
  1847. *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
  1848. *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
  1849. }
  1850. }
  1851. if (!isAblock_q4) {
  1852. auto aoffset = A+(ii*lda)+l;
  1853. for (int i = 0; i < 8; i++) {
  1854. comparray[i] = 0;
  1855. int ca = 0;
  1856. auto *at = aoffset->qs;
  1857. for (int j = 0; j < 32; j++)
  1858. ca += (int)*at++;
  1859. comparray[i] = ca;
  1860. aoffset += lda;
  1861. }
  1862. }
  1863. compute(&acc_0, 0, 0, comparray, vs, fin_res);
  1864. compute(&acc_1, 4, 4, comparray, vs, fin_res);
  1865. compute(&acc_2, 0, 8, comparray, vs, fin_res);
  1866. compute(&acc_3, 4, 12, comparray, vs, fin_res);
  1867. }
  1868. save_res(ii, jj, 0, fin_res);
  1869. save_res(ii+4, jj, 4, fin_res);
  1870. save_res(ii, jj+4, 8, fin_res);
  1871. save_res(ii+4, jj+4, 12, fin_res);
  1872. }
  1873. template<typename TA>
  1874. void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
  1875. int64_t ytiles = (m - m0) / RM;
  1876. int64_t xtiles = (n - n0) / RN;
  1877. int64_t tiles = xtiles * ytiles;
  1878. int64_t duty = (tiles + nth - 1) / nth;
  1879. int64_t start = duty * ith;
  1880. int64_t end = start + duty;
  1881. vec_t vec_A[8] = {0}, vec_B[8] = {0};
  1882. vector signed int vec_C[4];
  1883. acc_t acc_0;
  1884. bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
  1885. if (end > tiles)
  1886. end = tiles;
  1887. for (int64_t job = start; job < end; ++job) {
  1888. int64_t ii = m0 + job / xtiles * RM;
  1889. int64_t jj = n0 + job % xtiles * RN;
  1890. std::array<int, 4> comparray{};
  1891. vector float res[4] = {0};
  1892. vector float fin_res[4] = {0};
  1893. vector float vs[4] = {0};
  1894. vector float CA[4] = {0};
  1895. __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
  1896. __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
  1897. for (int l = 0; l < k; l++) {
  1898. __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
  1899. __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
  1900. __builtin_mma_xxsetaccz(&acc_0);
  1901. if (isAblock_q4) {
  1902. packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
  1903. } else {
  1904. packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
  1905. }
  1906. packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
  1907. for(int x = 0; x < 8; x+=4) {
  1908. __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
  1909. __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
  1910. __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
  1911. __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
  1912. }
  1913. for (int I = 0; I<RM; I++) {
  1914. for (int J = 0; J<RN; J++) {
  1915. *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
  1916. }
  1917. }
  1918. __builtin_mma_disassemble_acc(vec_C, &acc_0);
  1919. if (!isAblock_q4) {
  1920. auto aoffset = A+(ii*lda)+l;
  1921. for (int i = 0; i < RM; i++) {
  1922. comparray[i] = 0;
  1923. int ca = 0;
  1924. auto *at = aoffset->qs;
  1925. for (int j = 0; j < 32; j++)
  1926. ca += (int)*at++;
  1927. comparray[i] = ca;
  1928. aoffset += lda;
  1929. }
  1930. }
  1931. for (int i = 0; i < RM; i++) {
  1932. CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
  1933. res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
  1934. fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
  1935. }
  1936. }
  1937. save_res(ii, jj, 0, fin_res, RM, RN);
  1938. }
  1939. }
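// Thread-parallel driver for the fixed-shape quantized kernels: linearizes the
// tile grid, hands each thread a contiguous range of ceil(tiles / nth) tiles,
// and invokes kernel<RM, RN>() on every tile it owns.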
  1940. template<typename TA>
  1941. template <int RM, int RN>
  1942. NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  1943. int64_t ytiles = (m - m0) / RM;
  1944. int64_t xtiles = (n - n0) / RN;
  1945. int64_t tiles = xtiles * ytiles;
  1946. int64_t duty = (tiles + nth - 1) / nth;
  1947. int64_t start = duty * ith;
  1948. int64_t end = start + duty;
  1949. if (end > tiles)
  1950. end = tiles;
  1951. for (int64_t job = start; job < end; ++job) {
  1952. int64_t ii = m0 + job / xtiles * RM;
  1953. int64_t jj = n0 + job % xtiles * RN;
  1954. this->kernel<RM, RN>(ii, jj);
  1955. }
  1956. }
  1957. template class tinyBLAS_Q0_PPC<block_q4_0>;
  1958. template class tinyBLAS_Q0_PPC<block_q8_0>;
  1959. class tinyBLAS_PPC {
  1960. public:
  1961. tinyBLAS_PPC(int64_t k,
  1962. const float * A, int64_t lda,
  1963. const float * B, int64_t ldb,
  1964. float * C, int64_t ldc,
  1965. int ith, int nth)
  1966. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  1967. }
  1968. void matmul(int64_t m, int64_t n) {
  1969. int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
  1970. if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
  1971. matmul_tiled(m, n, mc, nc, kc);
  1972. } else {
  1973. mnpack(0, m, 0, n);
  1974. }
  1975. }
  1976. private:
  1977. inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
  1978. vec_t vec_C[4];
  1979. __builtin_mma_disassemble_acc(vec_C, ACC);
  1980. for (int I = 0; I < 4; I++) {
  1981. for (int J = 0; J < 4; J++) {
  1982. *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
  1983. }
  1984. }
  1985. }
  1986. inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
  1987. vec_t vec_C[4];
  1988. __builtin_mma_disassemble_acc(vec_C, ACC);
  1989. for (int I = 0; I < 4; I++) {
  1990. for (int J = 0; J < 4; J++) {
  1991. float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
  1992. *c_ptr += *((float *)&vec_C[I]+J);
  1993. }
  1994. }
  1995. }
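// Float transpose helpers built from VSX merge/permute primitives:
// vec_mergeh/vec_mergel interleave pairs of rows and vec_xxpermdi stitches the
// 64-bit halves together. Conceptually, for vector_permute_store_4:
//
//   a0 a1 a2 a3            a0 b0 c0 d0
//   b0 b1 b2 b3    --->    a1 b1 c1 d1
//   c0 c1 c2 c3            a2 b2 c2 d2
//   d0 d1 d2 d3            a3 b3 c3 d3
//
// i.e. the four input vectors are written out transposed, four floats per
// stored vector. vector_permute_store_8 does the same for eight input rows,
// storing each transposed column as two vectors (rows 0-3, then rows 4-7).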
  1996. inline void vector_permute_store_4(vector float * src, float * vecOffset) {
  1997. vector float t1, t2, t3, t4, t5, t6, t7, t8;
  1998. t1 = vec_mergeh(src[0], src[1]);
  1999. t2 = vec_mergeh(src[2], src[3]);
  2000. t3 = vec_mergel(src[0], src[1]);
  2001. t4 = vec_mergel(src[2], src[3]);
  2002. t5 = vec_xxpermdi(t1, t2, 0);
  2003. t6 = vec_xxpermdi(t1, t2, 3);
  2004. t7 = vec_xxpermdi(t3, t4, 0);
  2005. t8 = vec_xxpermdi(t3, t4, 3);
  2006. vec_xst(t5, 0, vecOffset);
  2007. vec_xst(t6, 0, vecOffset + 4);
  2008. vec_xst(t7, 0, vecOffset + 8);
  2009. vec_xst(t8, 0, vecOffset + 12);
  2010. }
  2011. inline void vector_permute_store_8(vector float * src, float * vecOffset) {
  2012. vector float t1, t2, t3, t4, t5, t6, t7, t8;
  2013. t1 = vec_mergeh(src[0], src[1]);
  2014. t2 = vec_mergeh(src[2], src[3]);
  2015. t3 = vec_mergeh(src[4], src[5]);
  2016. t4 = vec_mergeh(src[6], src[7]);
  2017. t5 = vec_xxpermdi(t1, t2, 0);
  2018. t6 = vec_xxpermdi(t3, t4, 0);
  2019. t7 = vec_xxpermdi(t1, t2, 3);
  2020. t8 = vec_xxpermdi(t3, t4, 3);
  2021. vec_xst(t5, 0, vecOffset);
  2022. vec_xst(t6, 0, vecOffset + 4);
  2023. vec_xst(t7, 0, vecOffset + 8);
  2024. vec_xst(t8, 0, vecOffset + 12);
  2025. t1 = vec_mergel(src[0], src[1]);
  2026. t2 = vec_mergel(src[2], src[3]);
  2027. t3 = vec_mergel(src[4], src[5]);
  2028. t4 = vec_mergel(src[6], src[7]);
  2029. t5 = vec_xxpermdi(t1, t2, 0);
  2030. t6 = vec_xxpermdi(t3, t4, 0);
  2031. t7 = vec_xxpermdi(t1, t2, 3);
  2032. t8 = vec_xxpermdi(t3, t4, 3);
  2033. vec_xst(t5, 0, vecOffset + 16);
  2034. vec_xst(t6, 0, vecOffset + 20);
  2035. vec_xst(t7, 0, vecOffset + 24);
  2036. vec_xst(t8, 0, vecOffset + 28);
  2037. }
  2038. void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
  2039. int64_t i, j;
  2040. float * aoffsets[8];
  2041. float * aoffset = NULL, * boffset = NULL;
  2042. __vector_pair arr[8];
  2043. vector float c[8][2] = {0};
  2044. vector float c1[8] = {0};
  2045. vector float c2[8] = {0};
  2046. aoffset = const_cast<float *>(a);
  2047. boffset = vec;
  2048. j = (rows >> 3);
  2049. if (j > 0) {
  2050. do {
  2051. aoffsets[0] = aoffset;
  2052. for (int it = 1; it < 8; it++)
  2053. aoffsets[it] = aoffsets[it-1] + lda;
  2054. aoffset += 8 * lda;
  2055. i = (cols >> 3);
  2056. if (i > 0) {
  2057. do {
  2058. for (int it = 0; it < 8; it++) {
  2059. arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
  2060. __builtin_vsx_disassemble_pair(c[it], &arr[it]);
  2061. c1[it] = c[it][0];
  2062. c2[it] = c[it][1];
  2063. }
  2064. vector_permute_store_8(c1, boffset);
  2065. vector_permute_store_8(c2, boffset + 32);
  2066. boffset += 64;
  2067. i--;
  2068. if (i > 0) {
  2069. for (int it = 0; it < 8; it++) {
  2070. aoffsets[it] = aoffsets[it] + 8;
  2071. }
  2072. }
  2073. } while(i > 0);
  2074. }
  2075. if (cols & 4) {
  2076. for (int it = 0; it < 8 ; it++)
  2077. c1[it] = vec_xl(0, aoffsets[it]);
  2078. vector_permute_store_8(c1, boffset);
  2079. }
  2080. j--;
  2081. } while(j > 0);
  2082. }
  2083. if (rows & 4) {
  2084. aoffsets[0] = aoffset;
  2085. for (int it = 1; it < 4; it++)
  2086. aoffsets[it] = aoffsets[it-1] + lda;
  2087. aoffset += 4 * lda;
  2088. i = (cols >> 3);
  2089. if (i > 0) {
  2090. do {
  2091. for (int it = 0; it < 4; it++) {
  2092. arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
  2093. __builtin_vsx_disassemble_pair(c[it], &arr[it]);
  2094. c1[it] = c[it][0];
  2095. c2[it] = c[it][1];
  2096. }
  2097. vector_permute_store_4(c1, boffset);
  2098. vector_permute_store_4(c2, boffset + 16);
  2099. for (int it = 0; it < 4; it++)
  2100. aoffsets[it] += 8 * lda;
  2101. boffset += 32;
  2102. i--;
  2103. } while(i > 0);
  2104. }
  2105. if (cols & 4) {
  2106. for (int it = 0; it < 4; it++)
  2107. c1[it] = vec_xl(0, aoffsets[it]);
  2108. vector_permute_store_4(c1, boffset);
  2109. }
  2110. }
  2111. if (rows & 3) {
  2112. aoffsets[0] = aoffset;
  2113. for (int it = 1; it < 3; it++)
  2114. aoffsets[it] = aoffsets[it-1] + lda;
  2115. if (cols & 4) {
  2116. for (int it = 0; it < 3; it++)
  2117. c1[it] = vec_xl(0, aoffsets[it]);
  2118. vector_permute_store_4(c1, boffset);
  2119. }
  2120. }
  2121. }
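// FP32 MMA kernels. packTranspose() reorders each operand so that every packed
// vector holds a single k element across four rows (or columns), making each
// __builtin_mma_xvf32gerpp call a rank-1 fp32 outer-product update of a 4x4
// accumulator; looping over k completes the dot products and save_acc()
// writes each finished tile into C.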
  2122. void KERNEL_4x4(int64_t ii, int64_t jj) {
vec_t vec_A[4], vec_B[4];
  2124. acc_t acc_0;
  2125. __builtin_mma_xxsetaccz(&acc_0);
  2126. for (int l = 0; l < k; l += 4) {
  2127. packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
  2128. packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
  2129. __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  2130. __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
  2131. __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
  2132. __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
  2133. }
  2134. save_acc(&acc_0, ii, jj);
  2135. }
  2136. void KERNEL_4x8(int64_t ii, int64_t jj) {
vec_t vec_A[4], vec_B[8];
  2138. acc_t acc_0, acc_1;
  2139. __builtin_mma_xxsetaccz(&acc_0);
  2140. __builtin_mma_xxsetaccz(&acc_1);
  2141. for (int64_t l = 0; l < k; l += 4) {
  2142. packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
  2143. packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
  2144. __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
  2145. __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
  2146. __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
  2147. __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
  2148. __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
  2149. __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
  2150. __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
  2151. __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
  2152. }
  2153. save_acc(&acc_0, ii, jj);
  2154. save_acc(&acc_1, ii, jj + 4);
  2155. }
  2156. void KERNEL_8x4(int64_t ii, int64_t jj) {
vec_t vec_A[8], vec_B[4];
  2158. acc_t acc_0, acc_1;
  2159. __builtin_mma_xxsetaccz(&acc_0);
  2160. __builtin_mma_xxsetaccz(&acc_1);
  2161. for (int64_t l = 0; l < k; l += 4) {
  2162. packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
  2163. packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
  2164. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
  2165. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
  2166. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
  2167. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
  2168. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
  2169. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
  2170. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
  2171. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
  2172. }
  2173. save_acc(&acc_0, ii, jj);
  2174. save_acc(&acc_1, ii + 4, jj);
  2175. }
  2176. void KERNEL_8x8(int64_t ii, int64_t jj) {
vec_t vec_A[16], vec_B[16];
  2178. acc_t acc_0, acc_1, acc_2, acc_3;
  2179. __builtin_mma_xxsetaccz(&acc_0);
  2180. __builtin_mma_xxsetaccz(&acc_1);
  2181. __builtin_mma_xxsetaccz(&acc_2);
  2182. __builtin_mma_xxsetaccz(&acc_3);
  2183. for (int l = 0; l < k; l+=8) {
  2184. packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
  2185. packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
  2186. for(int x = 0; x < 16; x+=2) {
  2187. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
  2188. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
  2189. __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
  2190. __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
  2191. }
  2192. }
  2193. save_acc(&acc_0, ii, jj);
  2194. save_acc(&acc_1, ii, jj + 4);
  2195. save_acc(&acc_2, ii + 4, jj);
  2196. save_acc(&acc_3, ii + 4, jj + 4);
  2197. }
  2198. inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
  2199. for (int x = 0; x < 16; x += 2) {
  2200. __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
  2201. __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
  2202. __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
  2203. __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
  2204. __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
  2205. __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
  2206. __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
  2207. __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
  2208. }
  2209. }
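// Macro-tile kernel used by matmul_tiled(): covers an mc x nc block of C in
// 16x8 sub-tiles, consuming eight k elements per MMA_16x8() call. The packed
// buffers are laid out as [8-row group][8-wide k-slice][16 vec_t], so
//     A_base_addr = (mc / 8) * (i / 8) * 16
// selects the group holding rows ii+i .. ii+i+7, the (l / 8) * 16 term walks
// along k, and the second half of the 16-row sub-tile sits one full group,
// (mc / 8) * 16 vec_t, further on (this indexing assumes the equal mc/nc/kc
// block sizes that matmul_tiled() uses). On the first k-block (kk == 0) the
// accumulators overwrite C; later k-blocks are accumulated via add_save_acc().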
  2210. void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
  2211. for (int64_t i = 0; i < mc; i += 16) {
  2212. int A_base_addr = (mc / 8) * (i / 8) * 16;
  2213. for (int64_t j = 0; j < nc; j += 8) {
  2214. int B_base_addr = (nc / 8) * (j / 8) * 16;
  2215. acc_t acc[8];
  2216. vec_t A0_block[16]; vec_t A1_block[16];
  2217. for (int x = 0; x < 8; x++)
  2218. __builtin_mma_xxsetaccz(&acc[x]);
  2219. for (int64_t l = 0; l < kc; l += 8) {
  2220. int A0_block_idx = A_base_addr + (l / 8) * 16;
  2221. int A1_block_idx = A0_block_idx + (mc / 8) * 16;
  2222. int B_block_idx = B_base_addr + (l / 8) * 16;
  2223. vec_t* A0_block = &vec_A[A0_block_idx];
  2224. vec_t* A1_block = &vec_A[A1_block_idx];
  2225. vec_t* B_block = &vec_B[B_block_idx];
  2226. MMA_16x8(A0_block, A1_block, B_block, acc);
  2227. }
  2228. if (kk == 0) {
  2229. save_acc(&acc[0], ii + i, jj + j);
  2230. save_acc(&acc[1], ii + i, jj + j + 4);
  2231. save_acc(&acc[2], ii + i + 4, jj + j);
  2232. save_acc(&acc[3], ii + i + 4, jj + j + 4);
  2233. save_acc(&acc[4], ii + i + 8, jj + j);
  2234. save_acc(&acc[5], ii + i + 8, jj + j + 4);
  2235. save_acc(&acc[6], ii + i + 12, jj + j);
  2236. save_acc(&acc[7], ii + i + 12, jj + j + 4);
  2237. } else {
  2238. add_save_acc(&acc[0], ii + i, jj + j);
  2239. add_save_acc(&acc[1], ii + i, jj + j + 4);
  2240. add_save_acc(&acc[2], ii + i + 4, jj + j);
  2241. add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
  2242. add_save_acc(&acc[4], ii + i + 8, jj + j);
  2243. add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
  2244. add_save_acc(&acc[6], ii + i + 12, jj + j);
  2245. add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
  2246. }
  2247. }
  2248. }
  2249. }
  2250. void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
  2251. int64_t ytiles = m / mc;
  2252. int64_t xtiles = n / nc;
  2253. int64_t tiles = xtiles * ytiles;
  2254. int64_t duty = (tiles + nth - 1) / nth;
  2255. int64_t start = duty * ith;
  2256. int64_t end = start + duty;
  2257. if (end > tiles) {
  2258. end = tiles;
  2259. }
  2260. for (int64_t job = start; job < end; ++job) {
  2261. int64_t ii = (job / xtiles) * mc;
  2262. int64_t jj = (job % xtiles) * nc;
  2263. for (int64_t kk = 0; kk < k; kk += kc) {
  2264. vec_t A_pack[kc * mc / 4];
  2265. vec_t B_pack[kc * nc / 4];
  2266. packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
  2267. packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
  2268. KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
  2269. }
  2270. }
  2271. }
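// Edge/fallback path: recursively carve the output into the largest panel a
// hand-written kernel covers (8x8, 8x4, 4x8 or 4x4), run gemm<>() on it, then
// recurse on the remaining right and bottom strips; anything narrower than
// 4x4 is finished off by gemm_small().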
  2272. void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  2273. int m_rem = MIN(m - m0, 8);
  2274. int n_rem = MIN(n - n0, 8);
  2275. int mc = 0, nc = 0;
  2276. if (m_rem >= 8 && n_rem >= 8) {
  2277. mc = 8;
  2278. nc = 8;
  2279. gemm<8, 8>(m0, m, n0, n);
  2280. } else if (m_rem >= 4 && n_rem >= 8) {
  2281. mc = 4;
  2282. nc = 8;
  2283. gemm<4, 8>(m0, m, n0, n);
  2284. } else if (m_rem >= 8 && n_rem >= 4) {
  2285. mc = 8;
  2286. nc = 4;
  2287. gemm<8, 4>(m0, m, n0, n);
  2288. } else if (m_rem >= 4 && n_rem >= 4) {
  2289. mc = 4;
  2290. nc = 4;
  2291. gemm<4, 4>(m0, m, n0, n);
  2292. } else {
  2293. mc = (m_rem >= 4) ? 4 : m_rem;
  2294. nc = (n_rem >= 4) ? 4 : n_rem;
  2295. if (mc == 0 || nc == 0)
  2296. return;
  2297. gemm_small(m0, m, n0, n, mc, nc);
  2298. }
  2299. int64_t mp = m0 + ((m - m0) / mc) * mc;
  2300. int64_t np = n0 + ((n - n0) / nc) * nc;
  2301. mnpack(mp, m, n0, np);
  2302. mnpack(m0, m, np, n);
  2303. }
  2304. void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
  2305. int64_t ytiles = (m - m0) / RM;
  2306. int64_t xtiles = (n - n0) / RN;
  2307. int64_t tiles = xtiles * ytiles;
  2308. int64_t duty = (tiles + nth - 1) / nth;
  2309. int64_t start = duty * ith;
  2310. int64_t end = start + duty;
  2311. if (end > tiles)
  2312. end = tiles;
  2313. for (int64_t job = start; job < end; ++job) {
  2314. int64_t ii = m0 + job / xtiles * RM;
  2315. int64_t jj = n0 + job % xtiles * RN;
  2316. vec_t vec_C[4];
  2317. acc_t acc_0;
  2318. __builtin_mma_xxsetaccz(&acc_0);
  2319. vec_t vec_A[4] = {0}, vec_B[4] = {0};
  2320. for (int l = 0; l < k; l += 4) {
/* The 'GEMV forwarding' concept is used in the first two branches below:
 * when one of the matrices has a single row/column, its elements are
 * broadcast directly instead of being prepacked with the packing routine.
 */
  2326. if (RM == 1) {
  2327. float * a = const_cast<float *>(A + (ii) * lda + l);
  2328. packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
  2329. vec_A[0] = (vec_t)vec_xl(0,a);
  2330. vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
  2331. vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
  2332. vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
  2333. } else if (RN == 1) {
  2334. packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
  2335. float * b = const_cast<float *>(B + (jj) * ldb + l);
  2336. vec_B[0] = (vec_t)vec_xl(0,b);
  2337. vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
  2338. vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
  2339. vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
  2340. } else {
  2341. packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
  2342. packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
  2343. }
  2344. __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  2345. __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
  2346. __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
  2347. __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
  2348. }
  2349. __builtin_mma_disassemble_acc(vec_C, &acc_0);
  2350. for (int I = 0; I < RM; I++) {
  2351. for (int J = 0; J < RN; J++) {
  2352. *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
  2353. }
  2354. }
  2355. }
  2356. }
  2357. template<int RM, int RN>
  2358. inline void kernel(int64_t ii, int64_t jj) {
  2359. if constexpr(RM == 4 && RN == 4) {
  2360. KERNEL_4x4(ii, jj);
  2361. } else if constexpr(RM == 4 && RN == 8) {
  2362. KERNEL_4x8(ii, jj);
  2363. } else if constexpr(RM == 8 && RN == 4) {
  2364. KERNEL_8x4(ii, jj);
  2365. } else if constexpr(RM == 8 && RN == 8) {
  2366. KERNEL_8x8(ii, jj);
  2367. } else {
  2368. static_assert(false, "RN/RM values not supported");
  2369. }
  2370. }
  2371. template <int RM, int RN>
  2372. NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  2373. int64_t ytiles = (m - m0) / RM;
  2374. int64_t xtiles = (n - n0) / RN;
  2375. int64_t tiles = xtiles * ytiles;
  2376. int64_t duty = (tiles + nth - 1) / nth;
  2377. int64_t start = duty * ith;
  2378. int64_t end = start + duty;
  2379. if (end > tiles)
  2380. end = tiles;
  2381. for (int64_t job = start; job < end; ++job) {
  2382. int64_t ii = m0 + job / xtiles * RM;
  2383. int64_t jj = n0 + job % xtiles * RN;
  2384. kernel<RM, RN>(ii, jj);
  2385. }
  2386. }
  2387. const float * const A;
  2388. const float * const B;
  2389. float * C;
  2390. const int64_t k;
  2391. const int64_t lda;
  2392. const int64_t ldb;
  2393. const int64_t ldc;
  2394. const int ith;
  2395. const int nth;
  2396. };
  2397. #endif
  2398. } // namespace
/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel is available for the
 * requested data types and CPU architecture. Otherwise the caller
 * should fall back to a general matmul routine.
 *
 * For example, for single-threaded single-precision GEMM you can say
 *
 *     llamafile_sgemm(params, m, n, k, A, lda, B, ldb, C, ldc,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * where `params->ith` is 0 and `params->nth` is 1.
 *
 * @param params is the ggml compute parameters, carrying the thread id
 *               `ith` (must be less than `nth`) and the number of
 *               threads `nth` (must be greater than zero)
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @return true if this function was able to service the matmul request
 */
  2429. bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
  2430. const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
  2431. int64_t ldc, int Atype, int Btype, int Ctype) {
  2432. assert(m >= 0);
  2433. assert(n >= 0);
  2434. assert(k >= 0);
  2435. assert(lda >= k);
  2436. assert(ldb >= k);
  2437. assert(ldc >= m);
  2438. assert(params->nth > 0);
  2439. assert(params->ith < params->nth);
// only enable sgemm for prompt processing (n >= 2) on non-MMA targets;
// single-column calls fall back to the general routine
  2441. #if !defined(__MMA__)
  2442. if (n < 2)
  2443. return false;
  2444. #endif
  2445. if (Ctype != GGML_TYPE_F32)
  2446. return false;
  2447. switch (Atype) {
  2448. case GGML_TYPE_F32: {
  2449. if (Btype != GGML_TYPE_F32)
  2450. return false;
  2451. #if defined(__AVX512F__)
  2452. tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
  2453. k, (const float *)A, lda,
  2454. (const float *)B, ldb,
  2455. (float *)C, ldc};
  2456. return tb.matmul(m, n);
  2457. #elif defined(__AVX__) || defined(__AVX2__)
  2458. tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
  2459. k, (const float *)A, lda,
  2460. (const float *)B, ldb,
  2461. (float *)C, ldc};
  2462. return tb.matmul(m, n);
  2463. #elif defined(__ARM_NEON)
  2464. if (n < 4)
  2465. return false;
  2466. tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
  2467. k, (const float *)A, lda,
  2468. (const float *)B, ldb,
  2469. (float *)C, ldc};
  2470. return tb.matmul(m, n);
  2471. #elif defined(__VXE__) || defined(__VXE2__)
  2472. if (n < 4)
  2473. return false;
  2474. tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
  2475. k, (const float *)A, lda,
  2476. (const float *)B, ldb,
  2477. (float *)C, ldc};
  2478. return tb.matmul(m, n);
  2479. #elif defined(__MMA__)
  2480. if (k % 8)
  2481. return false;
  2482. tinyBLAS_PPC tb{
  2483. k, (const float *)A, lda,
  2484. (const float *)B, ldb,
  2485. (float *)C, ldc,
  2486. params->ith, params->nth};
  2487. tb.matmul(m, n);
  2488. return true;
  2489. #else
  2490. return false;
  2491. #endif
  2492. }
  2493. case GGML_TYPE_BF16: {
  2494. #if defined(__AVX512BF16__)
  2495. if (Btype == GGML_TYPE_BF16) {
  2496. tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
  2497. (const ggml_bf16_t *)A, lda,
  2498. (const ggml_bf16_t *)B, ldb,
  2499. (float *)C, ldc};
  2500. return tb.matmul(m, n);
  2501. }
  2502. #elif defined(__AVX512F__)
  2503. if (Btype == GGML_TYPE_BF16) {
  2504. tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
  2505. (const ggml_bf16_t *)A, lda,
  2506. (const ggml_bf16_t *)B, ldb,
  2507. (float *)C, ldc};
  2508. return tb.matmul(m, n);
  2509. }
  2510. #elif defined(__AVX2__)
  2511. if (Btype == GGML_TYPE_BF16) {
  2512. tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
  2513. (const ggml_bf16_t *)A, lda,
  2514. (const ggml_bf16_t *)B, ldb,
  2515. (float *)C, ldc};
  2516. return tb.matmul(m, n);
  2517. }
  2518. #elif defined(__MMA__)
  2519. if ((k % 8))
  2520. return false;
  2521. if(Btype == GGML_TYPE_BF16) {
  2522. tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
  2523. (const ggml_bf16_t *)A, lda,
  2524. (const ggml_bf16_t *)B, ldb,
  2525. (float *)C, ldc,
  2526. params->ith, params->nth};
  2527. tb.matmul(m, n);
  2528. return true;
  2529. }
  2530. #endif
  2531. return false;
  2532. }
  2533. case GGML_TYPE_F16: {
  2534. #if defined(__AVX512F__)
  2535. if (Btype == GGML_TYPE_F16) {
  2536. tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
  2537. (const ggml_fp16_t *)A, lda,
  2538. (const ggml_fp16_t *)B, ldb,
  2539. (float *)C, ldc};
  2540. return tb.matmul(m, n);
  2541. }
  2542. #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
  2543. if (Btype == GGML_TYPE_F16) {
  2544. tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
  2545. (const ggml_fp16_t *)A, lda,
  2546. (const ggml_fp16_t *)B, ldb,
  2547. (float *)C, ldc};
  2548. return tb.matmul(m, n);
  2549. }
  2550. #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
  2551. if (n < 8)
  2552. return false;
  2553. if (Btype == GGML_TYPE_F16) {
  2554. tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
  2555. k, (const ggml_fp16_t *)A, lda,
  2556. (const ggml_fp16_t *)B, ldb,
  2557. (float *)C, ldc};
  2558. return tb.matmul(m, n);
  2559. }
  2560. #elif defined(__ARM_NEON) && !defined(_MSC_VER)
  2561. if (Btype == GGML_TYPE_F32) {
  2562. tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
  2563. k, (const ggml_fp16_t *)A, lda,
  2564. (const float *)B, ldb,
  2565. (float *)C, ldc};
  2566. return tb.matmul(m, n);
  2567. }
  2568. #elif defined(__VXE__) || defined(__VXE2__)
  2569. if (n < 4)
  2570. return false;
  2571. if (Btype == GGML_TYPE_F16) {
  2572. tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
  2573. k, (const ggml_fp16_t *)A, lda,
  2574. (const ggml_fp16_t *)B, ldb,
  2575. (float *)C, ldc};
  2576. return tb.matmul(m, n);
  2577. }
  2578. #endif
  2579. return false;
  2580. }
  2581. case GGML_TYPE_Q8_0: {
  2582. if (Btype != GGML_TYPE_Q8_0)
  2583. return false;
  2584. #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
  2585. tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
  2586. k, (const block_q8_0 *)A, lda,
  2587. (const block_q8_0 *)B, ldb,
  2588. (float *)C, ldc,
  2589. params->ith, params->nth};
  2590. tb.matmul(m, n);
  2591. return true;
  2592. #elif defined(__ARM_FEATURE_DOTPROD)
  2593. tinyBLAS_Q0_ARM<block_q8_0> tb{
  2594. k, (const block_q8_0 *)A, lda,
  2595. (const block_q8_0 *)B, ldb,
  2596. (float *)C, ldc,
  2597. params->ith, params->nth};
  2598. tb.matmul(m, n);
  2599. return true;
  2600. #elif defined(__MMA__)
// TODO: Remove this condition once GEMV forwarding is enabled.
  2602. if (n < 8 && n != 4)
  2603. return false;
  2604. if (m < 8 && m != 4)
  2605. return false;
  2606. tinyBLAS_Q0_PPC<block_q8_0> tb{
  2607. k, (const block_q8_0 *)A, lda,
  2608. (const block_q8_0 *)B, ldb,
  2609. (float *)C, ldc,
  2610. params->ith, params->nth};
  2611. tb.matmul(m, n);
  2612. return true;
  2613. #else
  2614. return false;
  2615. #endif
  2616. }
  2617. case GGML_TYPE_Q4_0: {
  2618. if (Btype != GGML_TYPE_Q8_0)
  2619. return false;
  2620. #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
  2621. tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
  2622. k, (const block_q4_0 *)A, lda,
  2623. (const block_q8_0 *)B, ldb,
  2624. (float *)C, ldc,
  2625. params->ith, params->nth};
  2626. tb.matmul(m, n);
  2627. return true;
  2628. #elif defined(__ARM_FEATURE_DOTPROD)
  2629. tinyBLAS_Q0_ARM<block_q4_0> tb{
  2630. k, (const block_q4_0 *)A, lda,
  2631. (const block_q8_0 *)B, ldb,
  2632. (float *)C, ldc,
  2633. params->ith, params->nth};
  2634. tb.matmul(m, n);
  2635. return true;
  2636. #elif defined(__MMA__)
// TODO: Remove this condition once GEMV forwarding is enabled.
  2638. if (n < 8 && n != 4)
  2639. return false;
  2640. if (m < 8 && m != 4)
  2641. return false;
  2642. tinyBLAS_Q0_PPC<block_q4_0> tb{
  2643. k, (const block_q4_0 *)A, lda,
  2644. (const block_q8_0 *)B, ldb,
  2645. (float *)C, ldc,
  2646. params->ith, params->nth};
  2647. tb.matmul(m, n);
  2648. return true;
  2649. #else
  2650. return false;
  2651. #endif
  2652. }
  2653. case GGML_TYPE_Q5_0: {
  2654. if (Btype != GGML_TYPE_Q8_0)
  2655. return false;
  2656. #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
  2657. tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
  2658. k, (const block_q5_0 *)A, lda,
  2659. (const block_q8_0 *)B, ldb,
  2660. (float *)C, ldc,
  2661. params->ith, params->nth};
  2662. tb.matmul(m, n);
  2663. return true;
  2664. #else
  2665. return false;
  2666. #endif
  2667. }
  2668. case GGML_TYPE_IQ4_NL: {
  2669. if (Btype != GGML_TYPE_Q8_0)
  2670. return false;
  2671. #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
  2672. tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
  2673. k, (const block_iq4_nl *)A, lda,
  2674. (const block_q8_0 *)B, ldb,
  2675. (float *)C, ldc,
  2676. params->ith, params->nth};
  2677. tb.matmul(m, n);
  2678. return true;
  2679. #else
  2680. return false;
  2681. #endif
  2682. }
  2683. default:
  2684. return false;
  2685. }
  2686. (void)params;
  2687. (void)m;
  2688. (void)n;
  2689. (void)k;
  2690. (void)A;
  2691. (void)lda;
  2692. (void)B;
  2693. (void)ldb;
  2694. (void)C;
  2695. (void)ldc;
  2696. (void)Atype;
  2697. (void)Btype;
  2698. (void)Ctype;
  2699. }