// mmq.cuh

#pragma once

#include "common.cuh"
#include "vecdotq.cuh"
#include "mma.cuh"

#include <climits>
#include <cstdint>

#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.

typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);

struct block_q8_1_mmq {
    half2 ds[4];
    int8_t qs[4*QK8_1];
};
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
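
// block_q8_1_mmq stores the data of 4 consecutive q8_1 blocks: the 4 half2 ds values
// (per-block scale d and sum s) first, followed by the 4*QK8_1 int8 quants. The
// static_asserts above only guarantee that its size matches 4 packed block_q8_1,
// not that the layouts are identical.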

struct tile_x_sizes {
    int qs;
    int dm;
    int sc;
};

static constexpr int get_mmq_x_max_host(const int cc) {
    return int8_mma_available(cc) ? 128 :
#ifdef GGML_CUDA_FORCE_MMQ
        cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
#else
        cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
#endif // GGML_CUDA_FORCE_MMQ
}
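
// get_mmq_x_max_device() below is the device-side counterpart of get_mmq_x_max_host():
// it selects the tile size limit via preprocessor macros instead of the compute
// capability cc that the host function receives at runtime.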

static constexpr __device__ int get_mmq_x_max_device() {
#ifdef INT8_MMA_AVAILABLE
    return 128;
#else // INT8_MMA_AVAILABLE

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    return 128;
#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

#if __CUDA_ARCH__ >= CC_VOLTA
#ifdef GGML_CUDA_FORCE_MMQ
    return MMQ_DP4A_MAX_BATCH_SIZE;
#else // GGML_CUDA_FORCE_MMQ
    return 128;
#endif // GGML_CUDA_FORCE_MMQ
#else // __CUDA_ARCH__ >= CC_VOLTA
    return 64;
#endif // __CUDA_ARCH__ >= CC_VOLTA

#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#endif // INT8_MMA_AVAILABLE
}

static constexpr int get_mmq_y_host(const int cc) {
    return int8_mma_available(cc) || cc >= CC_VOLTA ? 128 : 64;
}

static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    return 128;
#else
#if __CUDA_ARCH__ >= CC_VOLTA
    return 128;
#else
    return 64;
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}

#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
#define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
#define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}

static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
    return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
        type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
        type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
        type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
        type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
        type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
        type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
        type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
        tile_x_sizes{0, 0, 0};
}

#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4)
#define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0 + 4)
#define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1 + 4)
#define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0 + 0)
#define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2)
#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)

static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
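
// The MMQ_MMA_TILE_X_K_* row strides are deliberately padded so that each value
// satisfies MMQ_MMA_TILE_X_K_* % 8 == 4 (enforced by the asserts above); a likely
// reason is to avoid shared memory bank conflicts between tile rows, though the
// motivation is not stated in this file.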

static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
    return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
        type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
        type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
        type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
        type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
        type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
        type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
        type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
        type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
        type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
        0;
}

#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
#define MMQ_NWARPS 8

static int mmq_get_granularity_host(const int mmq_x, const int cc) {
    return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
}

#ifdef INT8_MMA_AVAILABLE
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
    return mmq_x >= 48 ? 16 : 8;
}
#else
static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
    return 8;
}
#endif // INT8_MMA_AVAILABLE
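
// In the vec_dot_*_q8_1_mma functions below the granularity determines how many
// 16-row x minitiles a warp accumulates (rows_per_warp = 2*granularity,
// ntx = rows_per_warp/mma_C::I): 2 minitiles when int8 MMA is available and
// mmq_x >= 48, otherwise 1.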

// ------------------------------------------------------------

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + WARP_SIZE);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    const int kbx = threadIdx.x / QI4_0;
    const int kqsx = threadIdx.x % QI4_0;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;

#ifdef INT8_MMA_AVAILABLE
        x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
#else
        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
#endif // INT8_MMA_AVAILABLE
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
        int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;

#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + kbxd] = bxi->d;
#else
        x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }
}
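
// All load_tiles_* functions follow the pattern above: the warps of a thread block
// copy the quantized values of their rows into x_qs and the block scales into
// x_df/x_dm, using the MMQ_MMA_TILE_X_K_* row stride when INT8_MMA_AVAILABLE is
// defined and the padded dp4a layout (e.g. WARP_SIZE + 1) otherwise.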

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));

            int u[2*VDR_Q4_0_Q8_1_MMQ];

#pragma unroll
            for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
            }

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
                (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + WARP_SIZE;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    mma_A A[ntx];
    float dA[ntx][mma_C::ne/2];

    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int l = 0; l < mma_A::ne; ++l) {
            const int i = i0 + n*mma_A::I + mma_A::get_i(l);
            const int k = k0 + mma_A::get_k(l) % QI4_0;
            const int shift = 4*(mma_A::get_k(l) / QI4_0);

            A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0];
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
        mma_B B;
        float dB[mma_C::ne/2];

        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            mma_C C;
            C.mma_K8(A[n], B);

#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
            }
        }
    }
#else
    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
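
// The vec_dot_*_q8_1_mma functions share the structure above: load the A fragments
// and per-row scales (dA/dmA) once per warp, then for each block of columns load B
// and the per-column q8_1 scales (dB/dsB), run the int8 MMA, and scale the int32
// results into the float sum accumulators.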

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    const int kbx = threadIdx.x / QI4_1;
    const int kqsx = threadIdx.x % QI4_1;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;

#ifdef INT8_MMA_AVAILABLE
        x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
#else
        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
#endif // INT8_MMA_AVAILABLE
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
        int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;

#ifdef INT8_MMA_AVAILABLE
        x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + kbxd] = bxi->dm;
#else
        x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
#endif // INT8_MMA_AVAILABLE
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));

            int u[2*VDR_Q4_1_Q8_1_MMQ];

#pragma unroll
            for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];
            }

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
                (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_A_I16K4 mma_A_K4;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    mma_A A[ntx];
    half2 dmA[ntx][mma_C::ne/2];

    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
        ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1);
        A[n].x[2] = (A[n].x[0] >> 4) & 0x0F0F0F0F;
        A[n].x[3] = (A[n].x[1] >> 4) & 0x0F0F0F0F;
        A[n].x[0] &= 0x0F0F0F0F;
        A[n].x[1] &= 0x0F0F0F0F;

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1];
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
        mma_B B;
        half2 dsB[mma_C::ne/2];

        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
        }

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            mma_C C;
            C.mma_K8(A[n], B);

#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
            }
        }
    }
#else
    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    const int kbx = threadIdx.x / QI5_0;
    const int kqsx = threadIdx.x % QI5_0;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;

        const int ql = get_int_from_uint8(bxi->qs, kqsx);
        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));

        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010;        // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000;       // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000;       // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000;       // 3 -> 28
        qs0 = __vsubss4(qs0, 0x10101010);     // subtract 16

        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010;       // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000;        // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000;        // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000;        // 19 -> 28
        qs1 = __vsubss4(qs1, 0x10101010);     // subtract 16

#ifdef INT8_MMA_AVAILABLE
        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
#else
        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
#endif // INT8_MMA_AVAILABLE
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;

#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = bxi->d;
#else
        x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
                 x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    mma_A A[ntx];
    float dA[ntx][mma_C::ne/2];

    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0);

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;

            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0];
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
        mma_B B;
        float dB[mma_C::ne/2];

        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
        }

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            mma_C C;
            C.mma_K8(A[n], B);

#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
            }
        }
    }
#else
    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    const int kbx = threadIdx.x / QI5_1;
    const int kqsx = threadIdx.x % QI5_1;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;

        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));

        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010;  // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28

        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000;  // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000;  // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000;  // 19 -> 28

#ifdef INT8_MMA_AVAILABLE
        x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
        x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
#else
        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
#endif // INT8_MMA_AVAILABLE
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;

#ifdef INT8_MMA_AVAILABLE
        x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + kbxd] = bxi->dm;
#else
        x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
#endif // INT8_MMA_AVAILABLE
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
                 x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    mma_A A[ntx];
    half2 dmA[ntx][mma_C::ne/2];

    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1);

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;

            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1];
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
        mma_B B;
        half2 dsB[mma_C::ne/2];

        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
        }

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            mma_C C;
            C.mma_K8(A[n], B);

#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
            }
        }
    }
#else
    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_tile + WARP_SIZE);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    const int kbx = threadIdx.x / QI8_0;
    const int kqsx = threadIdx.x % QI8_0;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;

#ifdef INT8_MMA_AVAILABLE
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
#else
        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
#endif // INT8_MMA_AVAILABLE
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
        int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;

#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
        x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
                (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
                 y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
        }
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + WARP_SIZE;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    mma_A A[ntx];
    float dA[ntx][mma_C::ne/2];

    const int i0 = (threadIdx.y/ntx)*rows_per_warp;

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);

            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
        mma_B B;
        float dB[mma_C::ne/2];

        B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
        }

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            mma_C C;
            C.mma_K8(A[n], B);

#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
            }
        }
    }
#else
    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    const int kbx = threadIdx.x / QI2_K;
    const int kqsx = threadIdx.x % QI2_K;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;

        const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);

#pragma unroll
        for (int l = 0; l < QR2_K; ++l) {
            const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;

            int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);

            if (kqsx % QR2_K != 0) {
                continue;
            }

#ifdef INT8_MMA_AVAILABLE
            x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
#else
            x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
#endif // INT8_MMA_AVAILABLE
        }

        const int sc_m = bxi->scales[kqsx];
#ifdef FAST_FP16_AVAILABLE
        const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
#else
        const float2 bxi_dmf = __half22float2(bxi->dm);
        const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
#endif // FAST_FP16_AVAILABLE

#ifdef INT8_MMA_AVAILABLE
        x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik;
#else
        x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
#endif // INT8_MMA_AVAILABLE
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
                &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
                &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
        }
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K4 mma_A;
    typedef mma_int_B_J8K4 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

    mma_A A[ntx][2];
    float dA[ntx][mma_C::ne/2][2];
    float mA[ntx][mma_C::ne/2][2];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int l = 0; l < mma_A::ne; ++l) {
            const int i = i0 + n*mma_A::I + mma_A::get_i(l);
            const int shift = 2*mma_A::get_k(l);

            A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303;
            A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303;
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

#pragma unroll
            for (int kdm = 0; kdm < 2; ++kdm) {
                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + kdm]);

                dA[n][l][kdm] = dm.x;
                mA[n][l][kdm] = dm.y;
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
        mma_B B[2];
        float dB[mma_C::ne/2];

        B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K);
        B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
        }

        mma_C Cm[2];
        mma_A A1;
        A1.x[0] = 0x01010101;
        A1.x[1] = 0x01010101;
        Cm[0].mma_K4(A1, B[0]);
        Cm[1].mma_K4(A1, B[1]);

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            mma_C Cd[2];

            Cd[0].mma_K4(A[n][0], B[0]);
            Cd[1].mma_K4(A[n][1], B[1]);

#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                sum[(j0/mma_C::J + n)*mma_C::ne + l] += (
                    Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2];
            }
        }
    }
#else
    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + WARP_SIZE*2);
    int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
    int * x_sc = (int *) (x_df + txs.dm);
#endif // INT8_MMA_AVAILABLE

    const int kbx = threadIdx.x / QI3_K;
    const int kqsx = threadIdx.x % QI3_K;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;

        const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
        const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));

#pragma unroll
        for (int l = 0; l < QR3_K; ++l) {
            const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8;

            const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
            const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;

            int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);

            if (kqsx % 2 != 0) {
                continue;
            }

#ifdef INT8_MMA_AVAILABLE
            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k;
#else
            x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
#endif // INT8_MMA_AVAILABLE
        }
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
        int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;

#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + kbxd] = bxi->d;
#else
        x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);

        const int ksc = threadIdx.x % (QI3_K/4);

        const int ksc_low = ksc % (QI3_K/8);
        const int shift_low = 4 * (ksc / (QI3_K/8));
        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;

        const int ksc_high = QI3_K/8;
        const int shift_high = 2 * ksc;
        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;

        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);

#ifdef INT8_MMA_AVAILABLE
        x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc;
#else
        x_sc[i*(WARP_SIZE/4) + i/4 + threadIdx.x % (WARP_SIZE/4)] = sc;
#endif // INT8_MMA_AVAILABLE
    }
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int * x_sc = (const int *) x_df + txs.dm;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const int kbx = k0 / QI3_K;
            const int ky = (k0 % QI3_K) * QR3_K;

            const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
                &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
                x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K4 mma_A;
    typedef mma_int_B_J8K4 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
    const int * x_sc = (const int *) x_df + WARP_SIZE/QI3_K;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

    mma_A A[ntx][2];
    int scA[ntx][mma_C::ne/2][2];
    float dA[ntx][mma_C::ne/2];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int l = 0; l < mma_A::ne; ++l) {
            const int i = i0 + n*mma_A::I + mma_A::get_i(l);
            const int k = QR3_K*k0 + mma_A::get_k(l);

            A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F;
            A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
            A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404);
            A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404);
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

            const int kbx = k0 / QI3_K;
            const int ky = (k0 % QI3_K) * QR3_K;
            const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4;

            scA[n][l][0] = sc[0];
            scA[n][l][1] = sc[1];
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K];
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
        mma_B B[2];
        float dB[mma_C::ne/2];

        B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K);
        B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
        }

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            mma_C C[2];
            C[0].mma_K4(A[n][0], B[0]);
            C[1].mma_K4(A[n][1], B[1]);

#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2];
            }
        }
    }
#else
    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
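
// Loads a tile of Q4_K data into shared memory: the 4-bit quants go into x_qs, the per-block
// scale/min pair (half2 dm) into x_dm, and the packed 6-bit scales/mins are repacked into x_sc
// as 8-bit values so the dot product kernels can index them directly.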
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
    int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
    int * x_sc = (int *) (x_dm + txs.dm);
#endif // INT8_MMA_AVAILABLE

    const int kbx = 0;            // threadIdx.x / QI4_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI4_K

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;

#ifdef INT8_MMA_AVAILABLE
        x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
#else
        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
#endif // INT8_MMA_AVAILABLE
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K;  // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;

#ifdef INT8_MMA_AVAILABLE
        x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + kbxd] = bxi->dm;
#else
        x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm;
#endif // INT8_MMA_AVAILABLE
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);

        const int * scales = (const int *) bxi->scales;
        const int ksc = threadIdx.x % (WARP_SIZE/8);
        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030;                       // upper 2 bits

#ifdef INT8_MMA_AVAILABLE
        x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8;
#else
        x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
#endif // INT8_MMA_AVAILABLE
    }
}
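
// Q4_K dot products: the dp4a path reads the repacked 8-bit scales/mins from x_sc; the mma path
// additionally splits each loaded int into its low and high nibbles to form two A fragments.
// The per-block min contribution is accumulated separately (tmpm) and subtracted at the end
// using the high half of dm.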
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * x_sc = (const int *) x_dm + txs.dm;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
                &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
                x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
    const int * x_sc = (const int *) x_dm + WARP_SIZE/QI4_K;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

    mma_A A[ntx][2];
    int scA[ntx][mma_C::ne/2][2];
    int mA[ntx][mma_C::ne/2][2];
    half2 dmA[ntx][mma_C::ne/2];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) {
            A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K);

#pragma unroll
            for (int l = 0; l < mma_A::ne; ++l) {
                A[n][kvdr/4 + 1].x[l] = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F;
                A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F;
            }
        }

#pragma unroll
        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
#pragma unroll
            for (int l = 0; l < mma_C::ne/2; ++l) {
                const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);

                const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8);
                const uint8_t * m = sc + 8;

                scA[n][l][kvdr/4] = sc[kvdr/4];
                mA[n][l][kvdr/4] = m[kvdr/4];
            }
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);

            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K];
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
        float tmpd[ntx][mma_C::ne] = {{0.0f}};
        float tmpm[ntx][mma_C::ne] = {{0.0f}};

#pragma unroll
        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
            mma_B B;
            half2 dsB[mma_C::ne/2];

            B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);

#pragma unroll
            for (int l = 0; l < mma_C::ne/2; ++l) {
                const int j = j0 + mma_C::get_j(l);

                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                mma_C C;
                C.mma_K8(A[n][kvdr/4], B);

#pragma unroll
                for (int l = 0; l < mma_C::ne; ++l) {
                    tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]);
                    tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]);
                }
            }
        }

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
            }
        }
    }
#else
    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
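
// Loads a tile of Q5_K data: the 5th bit is taken from qh and OR'd into the 4-bit values from qs,
// so x_qs holds plain 5-bit (0..31) integers; dm and the repacked scales/mins are stored the same
// way as for Q4_K.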
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
    int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
    int * x_sc = (int *) (x_dm + txs.dm);
#endif // INT8_MMA_AVAILABLE

    const int kbx = 0;            // threadIdx.x / QI5_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI5_K

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
        const int ky = QR5_K*kqsx;

        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;

        const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);

#ifdef INT8_MMA_AVAILABLE
        x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0;
        x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1;
#else
        x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
#endif // INT8_MMA_AVAILABLE
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K;  // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;

#ifdef INT8_MMA_AVAILABLE
        x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + kbxd] = bxi->dm;
#else
        x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm;
#endif // INT8_MMA_AVAILABLE
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8);

        const int * scales = (const int *) bxi->scales;
        const int ksc = threadIdx.x % (WARP_SIZE/8);
        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030;                       // upper 2 bits

#ifdef INT8_MMA_AVAILABLE
        x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8;
#else
        x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
#endif // INT8_MMA_AVAILABLE
    }
}
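
// Q5_K dot products: structurally the same as the Q4_K versions, but the A fragments are loaded
// directly from the already expanded 5-bit values in x_qs (QR5_K stride) instead of being split
// into nibbles.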
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * x_sc = (const int *) x_dm + txs.dm;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
                &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
                x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
    const int * x_sc = (const int *) x_dm + WARP_SIZE/QI5_K;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

    mma_A A[ntx][2];
    int scA[ntx][mma_C::ne/2][2];
    int mA[ntx][mma_C::ne/2][2];
    half2 dmA[ntx][mma_C::ne/2];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
            A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K);

#pragma unroll
            for (int l = 0; l < mma_C::ne/2; ++l) {
                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

                const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8);
                const uint8_t * m = sc + 8;

                scA[n][l][kvdr/4] = sc[kvdr/4];
                mA[n][l][kvdr/4] = m[kvdr/4];
            }
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K];
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
        float tmpd[ntx][mma_C::ne] = {{0.0f}};
        float tmpm[ntx][mma_C::ne] = {{0.0f}};

#pragma unroll
        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
            mma_B B;
            half2 dsB[mma_C::ne/2];

            B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);

#pragma unroll
            for (int l = 0; l < mma_C::ne/2; ++l) {
                const int j = j0 + mma_C::get_j(l);

                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                mma_C C;
                C.mma_K8(A[n][kvdr/4], B);

#pragma unroll
                for (int l = 0; l < mma_C::ne; ++l) {
                    tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]);
                    tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]);
                }
            }
        }

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
            }
        }
    }
#else
    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
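
// Loads a tile of Q6_K data: 6-bit values are assembled from ql (low 4 bits) and qh (upper 2 bits)
// and re-centered by subtracting 32 already at load time; x_df holds the per-block scale and x_sc
// the int8 sub-block scales.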
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + WARP_SIZE*2);
    int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
    int * x_sc = (int *) (x_df + txs.dm);
#endif // INT8_MMA_AVAILABLE

    const int kbx = 0;            // threadIdx.x / QI6_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI6_K

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
        const int ky = QR6_K*kqsx;

        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;

        const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);

#ifdef INT8_MMA_AVAILABLE
        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
#else
        x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
        x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
#endif // INT8_MMA_AVAILABLE
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;

#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
#else
        x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;

#ifdef INT8_MMA_AVAILABLE
        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
#else
        x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
#endif // INT8_MMA_AVAILABLE
    }
}
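
// Q6_K dot products: dp4a path below, mma path after it. The mma path loads two K4 A fragments
// per kvdr step, scales the two partial products with the matching int8 sub-block scales, and
// only applies the per-block scale dA when writing to sum.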
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int * x_sc = (const int *) x_df + txs.dm;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
                &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
                x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K4 mma_A;
    typedef mma_int_B_J8K4 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
    const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

    mma_A A[ntx][4];
    int scA[ntx][mma_C::ne/2][4];
    float dA[ntx][mma_C::ne/2];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
            A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0), MMQ_MMA_TILE_X_K_Q6_K);
            A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);

#pragma unroll
            for (int l = 0; l < mma_C::ne/2; ++l) {
                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

                const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]);

                scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0];
                scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1];
            }
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K];
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
        float tmp[ntx][mma_C::ne] = {{0.0f}};

#pragma unroll
        for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
            mma_B B[2];
            float dB[mma_C::ne/2];

            const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE;
            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k0B, MMQ_TILE_Y_K);
            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K);

#pragma unroll
            for (int l = 0; l < mma_C::ne/2; ++l) {
                const int j = j0 + mma_C::get_j(l);

                dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                mma_C C[2];
                C[0].mma_K4(A[n][kvdr/2 + 0], B[0]);
                C[1].mma_K4(A[n][kvdr/2 + 1], B[1]);

#pragma unroll
                for (int l = 0; l < mma_C::ne; ++l) {
                    tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2];
                }
            }
        }

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
            }
        }
    }
#else
    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
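
// Write-back of the accumulated sums to dst with bounds checks for partial tiles:
// mmq_write_back_dp4a mirrors the per-thread dp4a sum layout, mmq_write_back_mma mirrors the
// tensor core fragment layout.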
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(
    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j > j_max) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            if (need_check && i > i_max) {
                continue;
            }

            dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
        }
    }
}
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(
    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
    typedef mma_int_C_I16J8 mma_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
#ifdef INT8_MMA_AVAILABLE
    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
#endif // INT8_MMA_AVAILABLE

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);

                if (j > j_max) {
                    continue;
                }

                const int i = i0 + n*mma_C::I + mma_C::get_i(l);

                if (need_check && i > i_max) {
                    continue;
                }

                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
            }
        }
    }
}

// -------------------------------------------------------------------------------------------------------------------------------------
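
// Compile-time dispatch table: for each quantization type this selects the matching tile loader
// and the mma/dp4a dot product kernels, plus the vdr constant that sets the k0 step size of the
// inner dot product loop in mul_mat_q_process_tile.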
template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
struct mmq_type_traits;

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
    static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
    static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
    static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
    static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
    static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
    static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
    static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q3_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
    static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
    static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
    static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
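
// True for the x types whose dot product uses the per-block sum of the q8_1 values
// (to account for a constant offset or a per-block min term), false otherwise.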
static bool mmq_need_sum(const ggml_type type_x) {
    switch (type_x) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
            return true;
        case GGML_TYPE_Q5_0:
            return false;
        case GGML_TYPE_Q5_1:
            return true;
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
            return false;
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
            return true;
        case GGML_TYPE_Q6_K:
            return false;
        default:
            GGML_ASSERT(false);
            break;
    }
    return false;
}
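
// Computes one mmq_y x mmq_x output tile for the k-block range [kb0_start, kb0_stop):
// the x tile is loaded into shared memory once per k step, the matching slice of y is copied in
// per qr pass, and the accumulated sums are either written directly to dst or, for a partial
// range under stream-k, to the per-block fixup buffer.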
template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
static __device__ void mul_mat_q_process_tile(
    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
    const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
    const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {

    constexpr int qk = ggml_cuda_type_traits<type>::qk;
    constexpr int qr = ggml_cuda_type_traits<type>::qr;
    constexpr int qi = ggml_cuda_type_traits<type>::qi;
    constexpr int mmq_y = get_mmq_y_device();
    constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;

    extern __shared__ char data_mul_mat_q[];
    int * tile_y = (int *) data_mul_mat_q;
    int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);

#ifdef INT8_MMA_AVAILABLE
    constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
    constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
    constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE

    constexpr int blocks_per_warp = WARP_SIZE / qi;

    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};

    const int tile_x_max_i = ne01 - it*mmq_y - 1;
    const int tile_y_max_j = ne11 - jt*mmq_x - 1;

    const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));

    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {

        load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);

#pragma unroll
        for (int kr = 0; kr < qr; ++kr) {
            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;

                tile_y[l] = by0[l];
            }

            __syncthreads();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
                vec_dot(tile_x, tile_y, sum, k0);
            }

            __syncthreads();
        }
    }

    if (fixup) {
        write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
    } else {
        write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
    }
}
// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
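// Each CUDA block is assigned a contiguous range [kbc, kbc_stop) of the flattened (jt, it, kb0)
// index space of size ntx*nty*blocks_per_ne00 and decomposes it as
//     jt  = kbc / (blocks_per_ne00*nty);
//     it  = (kbc - jt*blocks_per_ne00*nty) / blocks_per_ne00;
//     kb0 = kbc % blocks_per_ne00;
// Illustrative example (values chosen for this comment only): with nty == 2 and
// blocks_per_ne00 == 8, kbc == 21 maps to jt == 1, it == 0, kb0 == 5.
// A block that ends mid-tile writes its partial sums to the fixup buffer; they are added to dst
// afterwards by mul_mat_q_stream_k_fixup.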
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#else
#if __CUDA_ARCH__ >= CC_VOLTA
    __launch_bounds__(WARP_SIZE*nwarps, 1)
#else
    __launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q(
    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {

    // Skip unused template specializations for faster compilation:
    if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int qk = ggml_cuda_type_traits<type>::qk;
    constexpr int qi = ggml_cuda_type_traits<type>::qi;
    constexpr int mmq_y = get_mmq_y_device();

    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
    {
        constexpr bool fixup = false;
        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
             blockIdx.x, blockIdx.y, 0, ne00/qk);
        return;
    }
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA

    const int64_t blocks_per_ne00 = ne00 / qk;
    constexpr int blocks_per_warp = WARP_SIZE / qi;

    const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
    const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y

    // kbc == k block continuous, current index in continuous ijk space.
    int64_t kbc = GGML_PAD((int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
    const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);

    // kb0 == k index when doing the matrix multiplication for an output tile.
    int kb0_start = kbc % blocks_per_ne00;
    int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
        const int jt = kbc / (blocks_per_ne00*nty);                        // j index of current tile.
        const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.

        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
             it, jt, kb0_start, kb0_stop);

        kbc += blocks_per_ne00;
        kbc -= kbc % blocks_per_ne00;

        kb0_start = 0;
        kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
    }

    if (kbc >= kbc_stop) {
        return;
    }

    const int jt = kbc / (blocks_per_ne00*nty);
    const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
        (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
         it, jt, kb0_start, kb0_stop);
}
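
// Adds the partial sums that stream-k blocks wrote to the fixup buffer onto the corresponding
// output tile. Each (blockIdx.x, blockIdx.y) pair owns one output tile and scans the MMQ blocks
// whose kbc range could have ended inside it.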
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
static __global__ void mul_mat_q_stream_k_fixup(
    float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {

    constexpr int mmq_y = get_mmq_y_device();
    constexpr int qk = ggml_cuda_type_traits<type>::qk;
    constexpr int qi = ggml_cuda_type_traits<type>::qi;
    constexpr int blocks_per_warp = WARP_SIZE / qi;
    const int64_t blocks_per_ne00 = ne00 / qk;

    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};

    const int ntx = (ne11 + mmq_x - 1) / mmq_x;
    const int nty = (ne01 + mmq_y - 1) / mmq_y;

    bool any_fixup = false;

    const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq / (gridDim.y*gridDim.x);
    const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;

    for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
        const int64_t kbc = GGML_PAD((int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
        const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);

        // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
        if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
            continue;
        }

        const int jt = kbc_stop / (blocks_per_ne00*nty);
        const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;

        // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
        if (it != blockIdx.x || jt != blockIdx.y) {
            continue;
        }

        any_fixup = true;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;

                sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
            }
        }
    }

    if (!any_fixup) {
        return;
    }

    dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;

    const int i_max = ne01 - blockIdx.x*mmq_y - 1;
    const int j_max = ne11 - blockIdx.y*mmq_x - 1;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j > j_max) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            if (need_check && i > i_max) {
                continue;
            }

            dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
        }
    }
}
struct mmq_args {
    const char * x; const char * y; float * dst;
    int64_t ne00; int64_t ne01; int64_t stride01;
    int64_t ne10; int64_t ne11; int64_t stride11;
    int64_t ne0;
};
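
// Host-side launch logic below: mmq_get_shmem returns the dynamic shared memory size needed for
// one x tile (mma or dp4a layout) plus the padded y tile.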
template<ggml_type type>
static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
    const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
    const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
    const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
    const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
    return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
}
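
// Launches the mul_mat_q kernel for a fixed mmq_x: raises the per-device dynamic shared memory
// limit once (skipped on HIP), then either uses one CUDA block per output tile or, with stream-k
// on NVIDIA Volta or newer, one block per SM followed by the fixup kernel.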
template <ggml_type type, int mmq_x>
static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
    const int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;
    const int nsm = ggml_cuda_info().devices[id].nsm;
    const int mmq_y = get_mmq_y_host(cc);

    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);

    const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
    static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
    if (!shmem_limit_raised[id]) {
        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
        shmem_limit_raised[id] = true;
    }
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

    const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
    const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
    const dim3 block_nums_xy_tiling(nty, ntx, 1);

    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
    if (!use_stream_k) {
        if (args.ne01 % mmq_y == 0) {
            constexpr bool need_check = false;
            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
        } else {
            constexpr bool need_check = true;
            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
        }
        return;
    }

    const dim3 block_nums_mmq(nsm, 1, 1);

    ggml_cuda_pool & pool = ctx.pool();
    ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);

    if (args.ne01 % mmq_y == 0) {
        constexpr bool need_check = false;

        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);

        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
    } else {
        constexpr bool need_check = true;

        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);

        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
    }
}
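
// Picks the mmq_x tile width for this matrix shape: among the candidates that satisfy the
// granularity and shared memory constraints, the one with the smallest number of tile
// parts/waves wins, and the matching launch_mul_mat_q instantiation is dispatched.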
template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
    const int id = ggml_cuda_get_device();
    const int nsm = ggml_cuda_info().devices[id].nsm;
    const int cc = ggml_cuda_info().devices[id].cc;
    const int smpbo = ggml_cuda_info().devices[id].smpbo;

    const int mmq_x_max = get_mmq_x_max_host(cc);
    const int mmq_y = get_mmq_y_host(cc);
    const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;

    int mmq_x_best = 0;
    int nparts_best = INT_MAX;

    for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
        const int granularity = mmq_get_granularity_host(mmq_x, cc);

        if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
            continue;
        }

        const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
        const int nwaves_xy_tiling = ntiles_x*block_num_y;
        const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;

        if (nparts < nparts_best) {
            mmq_x_best = mmq_x;
            nparts_best = nparts;
        }
    }

    switch (mmq_x_best) {
        case 8:
            launch_mul_mat_q<type, 8>(ctx, args, stream);
            break;
        case 16:
            launch_mul_mat_q<type, 16>(ctx, args, stream);
            break;
        case 24:
            launch_mul_mat_q<type, 24>(ctx, args, stream);
            break;
        case 32:
            launch_mul_mat_q<type, 32>(ctx, args, stream);
            break;
        case 40:
            launch_mul_mat_q<type, 40>(ctx, args, stream);
            break;
        case 48:
            launch_mul_mat_q<type, 48>(ctx, args, stream);
            break;
        case 56:
            launch_mul_mat_q<type, 56>(ctx, args, stream);
            break;
        case 64:
            launch_mul_mat_q<type, 64>(ctx, args, stream);
            break;
        case 72:
            launch_mul_mat_q<type, 72>(ctx, args, stream);
            break;
        case 80:
            launch_mul_mat_q<type, 80>(ctx, args, stream);
            break;
        case 88:
            launch_mul_mat_q<type, 88>(ctx, args, stream);
            break;
        case 96:
            launch_mul_mat_q<type, 96>(ctx, args, stream);
            break;
        case 104:
            launch_mul_mat_q<type, 104>(ctx, args, stream);
            break;
        case 112:
            launch_mul_mat_q<type, 112>(ctx, args, stream);
            break;
        case 120:
            launch_mul_mat_q<type, 120>(ctx, args, stream);
            break;
        case 128:
            launch_mul_mat_q<type, 128>(ctx, args, stream);
            break;
        default:
            fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
            GGML_ASSERT(false);
            break;
    }
}
#define DECL_MMQ_CASE(type) \
    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \

extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);

// -------------------------------------------------------------------------------------------------------------------------

void ggml_cuda_op_mul_mat_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream);

bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);