mmq.cu 85 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265
  1. #include "mmq.cuh"
  2. #include "vecdotq.cuh"
  // Function-pointer signatures shared by the per-quantization-type MMQ building blocks.

  // Hands out the __shared__ tiles for one quantization type through the four out-pointers
  // (types that do not need x_qh / x_sc leave those out-pointers untouched).
  using allocate_tiles_cuda_t = void (*)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);

  // Copies one tile of quantized x data from vx into the tiles handed out by an allocate_tiles_cuda_t.
  // i_max is the largest row index that may be read when the implementation is asked to bounds-check.
  using load_tiles_cuda_t = void (*)(
      const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row);

  // Dot product between tile row i of x and tile column j of y at tile position k; returns the partial sum.
  using vec_dot_q_mul_mat_cuda_t = float (*)(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k);

  // NOTE(review): not referenced in this part of the file; presumably a per-block (ib, iqs) dot callback
  // that accumulates into v — confirm against its users.
  using dot_kernel_k_t = void (*)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
  // Reserves the static shared-memory tiles used by the q4_0 MMQ path.
  //   x_ql - sized for mmq_y rows of WARP_SIZE + 1 ints (quant data plus one pad int per row);
  //   x_dm - float scales (WARP_SIZE/QI4_0 per row plus mmq_y/QI4_0 pad entries), handed out as half2 *
  //          only to fit the common signature; readers cast it back to float *.
  // x_qh and x_sc are not used by this type and are left untouched.
  template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
      GGML_UNUSED(x_qh);
      GGML_UNUSED(x_sc);

      __shared__ int   qs_tile[mmq_y * (WARP_SIZE + 1)];
      __shared__ float d_tile[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];

      *x_dm = reinterpret_cast<half2 *>(d_tile);
      *x_ql = qs_tile;
  }
  // Loads one mmq_y-row tile of q4_0 data from vx into the shared-memory tiles.
  // Rows are visited as row0 + i_offset with 0 <= i_offset < nwarps (i_offset is presumably the warp index).
  // When need_check is set, row indices are clamped to i_max, so no row past i_max is read from vx.
  //   x_ql - one packed-quant int per (row, k), row stride WARP_SIZE + 1;
  //   x_dm - reinterpreted as float: one scale (block d) per q4_0 block, WARP_SIZE/QI4_0 per row
  //          plus one pad entry every QI4_0 rows.
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
      const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      GGML_CUDA_ASSUME(i_offset >= 0);
      GGML_CUDA_ASSUME(i_offset < nwarps);
      GGML_CUDA_ASSUME(k >= 0);
      GGML_CUDA_ASSUME(k < WARP_SIZE);

      const block_q4_0 * blocks = static_cast<const block_q4_0 *>(vx);
      float * scales = reinterpret_cast<float *>(x_dm);

      // Pass 1: quantized values. Tile column k comes from block k / QI4_0, index k % QI4_0 of its qs.
      const int block_col    = k / QI4_0;
      const int int_in_block = k % QI4_0;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
          const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
          const block_q4_0 * src = blocks + row*blocks_per_row + block_col;
          x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8(src->qs, int_in_block);
      }

      // Pass 2: block scales. A tile row holds WARP_SIZE/QI4_0 scales, so one step covers QI4_0 rows per warp.
      const int scales_per_row = WARP_SIZE / QI4_0;
      const int scale_col      = k % scales_per_row;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI4_0) {
          const int row_unclamped = row0 + i_offset * QI4_0 + k / scales_per_row;
          const int row = need_check ? min(row_unclamped, i_max) : row_unclamped;
          const block_q4_0 * src = blocks + row*blocks_per_row + scale_col;
          scales[row * scales_per_row + row / QI4_0 + scale_col] = src->d;
      }
  }
  // Partial dot product of q4_0 tile row i (x) with q8_1 tile column j (y) at tile position k.
  // x_dm is read as float scales and y_ds as half2; x_qh and x_sc are unused for this type.
  static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      const float * x_scales = reinterpret_cast<const float *>(x_dm);
      const int * y_row = y_qs + j * WARP_SIZE;

      // Start position in the y row that lines up with x column k.
      const int y_base = QI8_1 * (k / (QI8_1/2)) + k % (QI8_1/2);

      // Each x int is paired with two y ints QI4_0 apart (presumably matching its low/high nibbles);
      // indices wrap within the WARP_SIZE-wide y row.
      int y_vals[2*VDR_Q4_0_Q8_1_MMQ];
  #pragma unroll
      for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
          y_vals[2*l + 0] = y_row[(y_base + l) % WARP_SIZE];
          y_vals[2*l + 1] = y_row[(y_base + l + QI4_0) % WARP_SIZE];
      }

      const float x_scale  = x_scales[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0];
      const half2 y_ds_val = y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)];

      return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>(&x_ql[i * (WARP_SIZE + 1) + k], y_vals, x_scale, y_ds_val);
  }
  // Reserves the static shared-memory tiles used by the q4_1 MMQ path.
  //   x_ql - sized for mmq_y rows of WARP_SIZE + 1 ints;
  //   x_dm - one half2 per q4_1 block (WARP_SIZE/QI4_1 per row plus mmq_y/QI4_1 pad entries).
  // x_qh and x_sc are not used by this type and are left untouched.
  template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      __shared__ int   qs_tile[mmq_y * (WARP_SIZE + 1)];
      __shared__ half2 dm_tile[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];

      *x_dm = dm_tile;
      *x_ql = qs_tile;
  }
  // Loads one mmq_y-row tile of q4_1 data from vx into the shared-memory tiles.
  // Rows are visited as row0 + i_offset with 0 <= i_offset < nwarps (i_offset is presumably the warp index).
  // When need_check is set, row indices are clamped to i_max, so no row past i_max is read from vx.
  //   x_ql - one packed-quant int per (row, k), row stride WARP_SIZE + 1;
  //   x_dm - the block's half2 dm, WARP_SIZE/QI4_1 per row plus one pad entry every QI4_1 rows.
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
      const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      GGML_CUDA_ASSUME(i_offset >= 0);
      GGML_CUDA_ASSUME(i_offset < nwarps);
      GGML_CUDA_ASSUME(k >= 0);
      GGML_CUDA_ASSUME(k < WARP_SIZE);

      const block_q4_1 * blocks = static_cast<const block_q4_1 *>(vx);

      // Pass 1: quantized values (aligned read of int k % QI4_1 from block k / QI4_1).
      const int block_col    = k / QI4_1;
      const int int_in_block = k % QI4_1;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
          const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
          const block_q4_1 * src = blocks + row*blocks_per_row + block_col;
          x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(src->qs, int_in_block);
      }

      // Pass 2: per-block dm; each step covers QI4_1 rows per warp.
      const int dm_per_row = WARP_SIZE / QI4_1;
      const int dm_col     = k % dm_per_row;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI4_1) {
          const int row_unclamped = row0 + i_offset * QI4_1 + k / dm_per_row;
          const int row = need_check ? min(row_unclamped, i_max) : row_unclamped;
          const block_q4_1 * src = blocks + row*blocks_per_row + dm_col;
          x_dm[row * dm_per_row + row / QI4_1 + dm_col] = src->dm;
      }
  }
  // Partial dot product of q4_1 tile row i (x) with q8_1 tile column j (y) at tile position k.
  // Both x_dm and y_ds are read as half2; x_qh and x_sc are unused for this type.
  static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      const int * y_row = y_qs + j * WARP_SIZE;
      const int y_base  = QI8_1 * (k / (QI8_1/2)) + k % (QI8_1/2);

      // Two y ints per x int, QI4_1 apart, wrapping within the WARP_SIZE-wide y row.
      int y_vals[2*VDR_Q4_1_Q8_1_MMQ];
  #pragma unroll
      for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
          y_vals[2*l + 0] = y_row[(y_base + l) % WARP_SIZE];
          y_vals[2*l + 1] = y_row[(y_base + l + QI4_1) % WARP_SIZE];
      }

      const half2 x_dm_val = x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1];
      const half2 y_ds_val = y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)];

      return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>(&x_ql[i * (WARP_SIZE + 1) + k], y_vals, x_dm_val, y_ds_val);
  }
  // Reserves the static shared-memory tiles used by the q5_0 MMQ path.
  //   x_ql - sized for mmq_y rows of 2*WARP_SIZE + 1 ints (two expanded ints per tile column plus a pad);
  //   x_dm - float scales (WARP_SIZE/QI5_0 per row plus mmq_y/QI5_0 pad entries), handed out as half2 *.
  // x_qh and x_sc are not used by this type and are left untouched.
  template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      __shared__ int   ql_tile[mmq_y * (2*WARP_SIZE + 1)];
      __shared__ float d_tile[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];

      *x_dm = reinterpret_cast<half2 *>(d_tile);
      *x_ql = ql_tile;
  }
  // Loads one mmq_y-row tile of q5_0 data from vx into the shared-memory tiles.
  // Each tile column k expands into two ints of 4 bytes each: the 4-bit quants from qs with the matching
  // bit from qh spliced in as bit 4 of every byte, then 16 subtracted per byte (saturating, via __vsubss4).
  // Rows are visited as row0 + i_offset (0 <= i_offset < nwarps) and clamped to i_max when need_check is set.
  //   x_ql - row stride 2*WARP_SIZE + 1, entries 2*k and 2*k + 1 for tile column k;
  //   x_dm - reinterpreted as float: one block d per q5_0 block, plus one pad entry every QI5_0 rows.
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
      const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      GGML_CUDA_ASSUME(i_offset >= 0);
      GGML_CUDA_ASSUME(i_offset < nwarps);
      GGML_CUDA_ASSUME(k >= 0);
      GGML_CUDA_ASSUME(k < WARP_SIZE);

      const block_q5_0 * blocks = static_cast<const block_q5_0 *>(vx);

      const int block_col    = k / QI5_0;
      const int int_in_block = k % QI5_0;

      // Pass 1: splice the high bit into each quant and store two expanded ints per tile column.
  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
          const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
          const block_q5_0 * src = blocks + row*blocks_per_row + block_col;

          const int low4  = get_int_from_uint8(src->qs, int_in_block);
          const int high1 = get_int_from_uint8(src->qh, 0) >> (4 * int_in_block);
          int * dst = x_ql + row * (2*WARP_SIZE + 1) + 2*k;

          int packed_lo = low4 & 0x0F0F0F0F;
          packed_lo |= (high1 <<  4) & 0x00000010; // 0 -> 4
          packed_lo |= (high1 << 11) & 0x00001000; // 1 -> 12
          packed_lo |= (high1 << 18) & 0x00100000; // 2 -> 20
          packed_lo |= (high1 << 25) & 0x10000000; // 3 -> 28
          dst[0] = __vsubss4(packed_lo, 0x10101010); // subtract 16

          int packed_hi = (low4 >> 4) & 0x0F0F0F0F;
          packed_hi |= (high1 >> 12) & 0x00000010; // 16 -> 4
          packed_hi |= (high1 >>  5) & 0x00001000; // 17 -> 12
          packed_hi |= (high1 <<  2) & 0x00100000; // 18 -> 20
          packed_hi |= (high1 <<  9) & 0x10000000; // 19 -> 28
          dst[1] = __vsubss4(packed_hi, 0x10101010); // subtract 16
      }

      // Pass 2: block scales stored as float; each step covers QI5_0 rows per warp.
      const int scales_per_row = WARP_SIZE / QI5_0;
      const int scale_col      = k % scales_per_row;
      float * scales = reinterpret_cast<float *>(x_dm);

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI5_0) {
          const int row_unclamped = row0 + i_offset * QI5_0 + k / scales_per_row;
          const int row = need_check ? min(row_unclamped, i_max) : row_unclamped;
          const block_q5_0 * src = blocks + row*blocks_per_row + scale_col;
          scales[row * scales_per_row + row / QI5_0 + scale_col] = src->d;
      }
  }
  // Partial dot product of q5_0 tile row i (x) with q8_1 tile column j (y) at tile position k.
  // x_ql holds values already expanded to signed bytes (row stride 2*WARP_SIZE + 1), so the
  // q8_0 x q8_1 kernel is reused with QR5_0*VDR_Q5_0_Q8_1_MMQ ints. Both x_dm and y_ds are read as float.
  static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      const float * x_d = reinterpret_cast<const float *>(x_dm);
      const float * y_d = reinterpret_cast<const float *>(y_ds);
      const int * y_row = y_qs + j * WARP_SIZE;

      const int y_base = QI8_1 * (k / (QI8_1/2)) + k % (QI8_1/2);

      // Two y ints per step, QI5_0 apart, wrapping within the WARP_SIZE-wide y row.
      int y_vals[2*VDR_Q5_0_Q8_1_MMQ];
  #pragma unroll
      for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
          y_vals[2*l + 0] = y_row[(y_base + l) % WARP_SIZE];
          y_vals[2*l + 1] = y_row[(y_base + l + QI5_0) % WARP_SIZE];
      }

      const float x_scale = x_d[i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0];
      const float y_scale = y_d[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)];

      return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>(&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], y_vals, x_scale, y_scale);
  }
  // Reserves the static shared-memory tiles used by the q5_1 MMQ path.
  //   x_ql - sized for mmq_y rows of 2*WARP_SIZE + 1 ints (two expanded ints per tile column plus a pad);
  //   x_dm - one half2 per q5_1 block (WARP_SIZE/QI5_1 per row plus mmq_y/QI5_1 pad entries).
  // x_qh and x_sc are not used by this type and are left untouched.
  template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      __shared__ int   ql_tile[mmq_y * (2*WARP_SIZE + 1)];
      __shared__ half2 dm_tile[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];

      *x_dm = dm_tile;
      *x_ql = ql_tile;
  }
  // Loads one mmq_y-row tile of q5_1 data from vx into the shared-memory tiles.
  // Each tile column k expands into two ints: the 4-bit quants from qs with the matching bit from qh
  // spliced in as bit 4 of every byte (no offset is subtracted for this type).
  // Rows are visited as row0 + i_offset (0 <= i_offset < nwarps) and clamped to i_max when need_check is set.
  //   x_ql - row stride 2*WARP_SIZE + 1, entries 2*k and 2*k + 1 for tile column k;
  //   x_dm - the block's half2 dm, plus one pad entry every QI5_1 rows.
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
      const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      GGML_CUDA_ASSUME(i_offset >= 0);
      GGML_CUDA_ASSUME(i_offset < nwarps);
      GGML_CUDA_ASSUME(k >= 0);
      GGML_CUDA_ASSUME(k < WARP_SIZE);

      const block_q5_1 * blocks = static_cast<const block_q5_1 *>(vx);

      const int block_col    = k / QI5_1;
      const int int_in_block = k % QI5_1;

      // Pass 1: splice the high bit into each quant and store two expanded ints per tile column.
  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
          const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
          const block_q5_1 * src = blocks + row*blocks_per_row + block_col;

          const int low4  = get_int_from_uint8_aligned(src->qs, int_in_block);
          const int high1 = get_int_from_uint8_aligned(src->qh, 0) >> (4 * int_in_block);
          int * dst = x_ql + row * (2*WARP_SIZE + 1) + 2*k;

          int packed_lo = low4 & 0x0F0F0F0F;
          packed_lo |= (high1 <<  4) & 0x00000010; // 0 -> 4
          packed_lo |= (high1 << 11) & 0x00001000; // 1 -> 12
          packed_lo |= (high1 << 18) & 0x00100000; // 2 -> 20
          packed_lo |= (high1 << 25) & 0x10000000; // 3 -> 28
          dst[0] = packed_lo;

          int packed_hi = (low4 >> 4) & 0x0F0F0F0F;
          packed_hi |= (high1 >> 12) & 0x00000010; // 16 -> 4
          packed_hi |= (high1 >>  5) & 0x00001000; // 17 -> 12
          packed_hi |= (high1 <<  2) & 0x00100000; // 18 -> 20
          packed_hi |= (high1 <<  9) & 0x10000000; // 19 -> 28
          dst[1] = packed_hi;
      }

      // Pass 2: per-block dm; each step covers QI5_1 rows per warp.
      const int dm_per_row = WARP_SIZE / QI5_1;
      const int dm_col     = k % dm_per_row;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI5_1) {
          const int row_unclamped = row0 + i_offset * QI5_1 + k / dm_per_row;
          const int row = need_check ? min(row_unclamped, i_max) : row_unclamped;
          const block_q5_1 * src = blocks + row*blocks_per_row + dm_col;
          x_dm[row * dm_per_row + row / QI5_1 + dm_col] = src->dm;
      }
  }
  // Partial dot product of q5_1 tile row i (x) with q8_1 tile column j (y) at tile position k.
  // x_ql holds expanded bytes (row stride 2*WARP_SIZE + 1), so the q8_1 x q8_1 kernel is reused
  // with QR5_1*VDR_Q5_1_Q8_1_MMQ ints. Both x_dm and y_ds are read as half2.
  static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      const int * y_row = y_qs + j * WARP_SIZE;
      const int y_base  = QI8_1 * (k / (QI8_1/2)) + k % (QI8_1/2);

      // Two y ints per step, QI5_1 apart, wrapping within the WARP_SIZE-wide y row.
      int y_vals[2*VDR_Q5_1_Q8_1_MMQ];
  #pragma unroll
      for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
          y_vals[2*l + 0] = y_row[(y_base + l) % WARP_SIZE];
          y_vals[2*l + 1] = y_row[(y_base + l + QI5_1) % WARP_SIZE];
      }

      const half2 x_dm_val = x_dm[i * (WARP_SIZE/QI5_1) + i/QI5_1 + k/QI5_1];
      const half2 y_ds_val = y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)];

      return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>(&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], y_vals, x_dm_val, y_ds_val);
  }
  // Reserves the static shared-memory tiles used by the q8_0 MMQ path.
  //   x_ql - sized for mmq_y rows of WARP_SIZE + 1 ints;
  //   x_dm - float scales (WARP_SIZE/QI8_0 per row plus mmq_y/QI8_0 pad entries), handed out as half2 *.
  // x_qh and x_sc are not used by this type and are left untouched.
  template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      __shared__ int   qs_tile[mmq_y * (WARP_SIZE + 1)];
      __shared__ float d_tile[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];

      *x_dm = reinterpret_cast<half2 *>(d_tile);
      *x_ql = qs_tile;
  }
  // Loads one mmq_y-row tile of q8_0 data from vx into the shared-memory tiles.
  // Rows are visited as row0 + i_offset (0 <= i_offset < nwarps) and clamped to i_max when need_check is set.
  //   x_ql - one int of signed 8-bit quants per (row, k), row stride WARP_SIZE + 1;
  //   x_dm - reinterpreted as float: one block d per q8_0 block, plus one pad entry every QI8_0 rows.
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
      const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      GGML_CUDA_ASSUME(i_offset >= 0);
      GGML_CUDA_ASSUME(i_offset < nwarps);
      GGML_CUDA_ASSUME(k >= 0);
      GGML_CUDA_ASSUME(k < WARP_SIZE);

      const block_q8_0 * blocks = static_cast<const block_q8_0 *>(vx);
      float * scales = reinterpret_cast<float *>(x_dm);

      // Pass 1: quantized values from block k / QI8_0, index k % QI8_0 of its qs.
      const int block_col    = k / QI8_0;
      const int int_in_block = k % QI8_0;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
          const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
          const block_q8_0 * src = blocks + row*blocks_per_row + block_col;
          x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_int8(src->qs, int_in_block);
      }

      // Pass 2: block scales; each step covers QI8_0 rows per warp.
      const int scales_per_row = WARP_SIZE / QI8_0;
      const int scale_col      = k % scales_per_row;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI8_0) {
          const int row_unclamped = row0 + i_offset * QI8_0 + k / scales_per_row;
          const int row = need_check ? min(row_unclamped, i_max) : row_unclamped;
          const block_q8_0 * src = blocks + row*blocks_per_row + scale_col;
          scales[row * scales_per_row + row / QI8_0 + scale_col] = src->d;
      }
  }
  // Partial dot product of q8_0 tile row i (x) with q8_1 tile column j (y) at tile position k.
  // x and y ints are consumed in place, starting at column k; both x_dm and y_ds are read as float.
  static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
      GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

      const float * x_d = reinterpret_cast<const float *>(x_dm);
      const float * y_d = reinterpret_cast<const float *>(y_ds);

      const int * x_vals = x_ql + i * (WARP_SIZE + 1) + k;
      const int * y_vals = y_qs + j * WARP_SIZE + k;

      const float x_scale = x_d[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0];
      const float y_scale = y_d[j * (WARP_SIZE/QI8_1) + k/QI8_1];

      return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>(x_vals, y_vals, x_scale, y_scale);
  }
  // Reserves the static shared-memory tiles used by the q2_K MMQ path.
  //   x_ql - sized for mmq_y rows of WARP_SIZE + 1 ints;
  //   x_dm - one half2 per q2_K block (WARP_SIZE/QI2_K per row plus mmq_y/QI2_K pad entries);
  //   x_sc - WARP_SIZE/4 scale ints per row plus mmq_y/4 pad entries.
  // x_qh is not used by this type and is left untouched.
  template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
      GGML_UNUSED(x_qh);

      __shared__ int   ql_tile[mmq_y * (WARP_SIZE + 1)];
      __shared__ half2 dm_tile[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
      __shared__ int   sc_tile[mmq_y * (WARP_SIZE/4) + mmq_y/4];

      *x_sc = sc_tile;
      *x_dm = dm_tile;
      *x_ql = ql_tile;
  }
  // Loads one mmq_y-row tile of q2_K data from vx into the shared-memory tiles, in three passes:
  // packed quants (x_ql), per-block dm (x_dm) and packed scale bytes (x_sc).
  // Rows are visited as row0 + i_offset (0 <= i_offset < nwarps) and clamped to i_max when need_check is set.
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
      const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
      GGML_UNUSED(x_qh);

      GGML_CUDA_ASSUME(i_offset >= 0);
      GGML_CUDA_ASSUME(i_offset < nwarps);
      GGML_CUDA_ASSUME(k >= 0);
      GGML_CUDA_ASSUME(k < WARP_SIZE);

      const block_q2_K * blocks = static_cast<const block_q2_K *>(vx);

      // Pass 1: packed quants, one int per (row, k), row stride WARP_SIZE + 1.
      const int block_col    = k / QI2_K;
      const int int_in_block = k % QI2_K;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
          const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
          const block_q2_K * src = blocks + row*blocks_per_row + block_col;
          x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(src->qs, int_in_block);
      }

      // Pass 2: per-block dm. The row index wraps modulo mmq_y, presumably because one step of
      // nwarps * QI2_K rows can be taller than the tile.
      const int dm_per_row = WARP_SIZE / QI2_K;
      const int dm_col     = k % dm_per_row;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI2_K) {
          const int row_wrapped = (row0 + i_offset * QI2_K + k / dm_per_row) % mmq_y;
          const int row = need_check ? min(row_wrapped, i_max) : row_wrapped;
          const block_q2_K * src = blocks + row*blocks_per_row + dm_col;
          x_dm[row * dm_per_row + row / QI2_K + dm_col] = src->dm;
      }

      // Pass 3: scale bytes, WARP_SIZE/4 ints per tile row; each step covers 4 rows per warp.
      const int sc_per_row = WARP_SIZE / 4;
      const int sc_col     = k % sc_per_row;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps * 4) {
          const int row_unclamped = row0 + i_offset * 4 + k / sc_per_row;
          const int row = need_check ? min(row_unclamped, i_max) : row_unclamped;
          const block_q2_K * src = blocks + row*blocks_per_row + sc_col / (QI2_K/4);
          x_sc[row * sc_per_row + row / 4 + sc_col] = get_int_from_uint8_aligned(src->scales, k % (QI2_K/4));
      }
  }
  // Partial dot product of q2_K tile row i (x) with q8_1 tile column j (y) at tile position k.
  // The 2-bit values are pulled out of the packed x_ql ints with a fixed shift and mask, then combined
  // with the scale bytes from x_sc, the block dm from x_dm and the float y scale (y_ds read as float).
  static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
      GGML_UNUSED(x_qh);

      const int blk = k / QI2_K;
      const int ky  = (k % QI2_K) * QR2_K;

      // Packed-int position and bit shift selecting the 2-bit field for this k.
      const int ql_base  = i * (WARP_SIZE + 1) + blk*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
      const int ql_shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));

      int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
  #pragma unroll
      for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
          v[l] = (x_ql[ql_base + l] >> ql_shift) & 0x03030303;
      }

      const uint8_t * scales = reinterpret_cast<const uint8_t *>(&x_sc[i * (WARP_SIZE/4) + i/4 + blk*4]) + ky/4;
      const float * y_d = reinterpret_cast<const float *>(y_ds);
      const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;

      return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + blk], y_d[index_y/QI8_1]);
  }
  // Reserves the static shared-memory tiles used by the q3_K MMQ path (all four tiles are used).
  //   x_ql - sized for mmq_y rows of WARP_SIZE + 1 ints;
  //   x_dm - WARP_SIZE/QI3_K entries per row plus mmq_y/QI3_K pad entries (readers treat it as float);
  //   x_qh - WARP_SIZE/2 ints per row plus mmq_y/2 pad entries;
  //   x_sc - WARP_SIZE/4 ints per row plus mmq_y/4 pad entries.
  template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
      __shared__ int   ql_tile[mmq_y * (WARP_SIZE + 1)];
      __shared__ half2 dm_tile[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K];
      __shared__ int   qh_tile[mmq_y * (WARP_SIZE/2) + mmq_y/2];
      __shared__ int   sc_tile[mmq_y * (WARP_SIZE/4) + mmq_y/4];

      *x_sc = sc_tile;
      *x_qh = qh_tile;
      *x_dm = dm_tile;
      *x_ql = ql_tile;
  }
  // Loads one mmq_y-row tile of q3_K data from vx into the shared-memory tiles, in four passes:
  // low quant bits (x_ql), block d as float (x_dm), bit-inverted hmask (x_qh) and decoded scales (x_sc).
  // Rows are visited as row0 + i_offset (0 <= i_offset < nwarps) and clamped to i_max when need_check is set.
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
      const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
      GGML_CUDA_ASSUME(i_offset >= 0);
      GGML_CUDA_ASSUME(i_offset < nwarps);
      GGML_CUDA_ASSUME(k >= 0);
      GGML_CUDA_ASSUME(k < WARP_SIZE);

      const block_q3_K * blocks = static_cast<const block_q3_K *>(vx);

      // Pass 1: packed quants from qs, one int per (row, k), row stride WARP_SIZE + 1.
      const int block_col    = k / QI3_K;
      const int int_in_block = k % QI3_K;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
          const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
          const block_q3_K * src = blocks + row*blocks_per_row + block_col;
          x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8(src->qs, int_in_block);
      }

      // Pass 2: block d stored as float; the row index wraps modulo mmq_y, presumably because one step
      // of nwarps * QI3_K rows can be taller than the tile.
      const int d_per_row = WARP_SIZE / QI3_K;
      const int d_col     = k % d_per_row;
      float * x_d = reinterpret_cast<float *>(x_dm);

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI3_K) {
          const int row_wrapped = (row0 + i_offset * QI3_K + k / d_per_row) % mmq_y;
          const int row = need_check ? min(row_wrapped, i_max) : row_wrapped;
          const block_q3_K * src = blocks + row*blocks_per_row + d_col;
          x_d[row * d_per_row + row / QI3_K + d_col] = src->d;
      }

      // Pass 3: hmask, WARP_SIZE/2 ints per tile row; each step covers 2 rows per warp.
      const int qh_per_row = WARP_SIZE / 2;
      const int qh_col     = k % qh_per_row;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps * 2) {
          const int row_unclamped = row0 + i_offset * 2 + k / qh_per_row;
          const int row = need_check ? min(row_unclamped, i_max) : row_unclamped;
          const block_q3_K * src = blocks + row*blocks_per_row + qh_col / (QI3_K/2);
          // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
          x_qh[row * qh_per_row + row / 2 + qh_col] = ~get_int_from_uint8(src->hmask, k % (QI3_K/2));
      }

      // Pass 4: scales. Per byte: low 4 bits from scales[lo_idx] (shifted by lo_shift), bits 4-5 from
      // scales[hi_idx] (shifted by hi_shift), then 32 subtracted with saturation. These indices depend only on k.
      const int sc_per_row = WARP_SIZE / 4;
      const int sc_col     = k % sc_per_row;
      const int sc_idx     = k % (QI3_K/4);
      const int lo_idx     = sc_idx % (QI3_K/8);
      const int lo_shift   = 4 * (sc_idx / (QI3_K/8));
      const int hi_idx     = QI3_K/8;
      const int hi_shift   = 2 * sc_idx;

  #pragma unroll
      for (int row0 = 0; row0 < mmq_y; row0 += nwarps * 4) {
          const int row_unclamped = row0 + i_offset * 4 + k / sc_per_row;
          const int row = need_check ? min(row_unclamped, i_max) : row_unclamped;
          const block_q3_K * src = blocks + row*blocks_per_row + sc_col / (QI3_K/4);

          const int lo = (get_int_from_uint8(src->scales, lo_idx) >> lo_shift) & 0x0F0F0F0F;
          const int hi = ((get_int_from_uint8(src->scales, hi_idx) >> hi_shift) << 4) & 0x30303030;

          x_sc[row * sc_per_row + row / 4 + sc_col] = __vsubss4(lo | hi, 0x20202020);
      }
  }
  // Partial dot product of q3_K tile row i (x) with q8_1 tile column j (y) at tile position k.
  // Each 3-bit value is rebuilt as the 2-bit field from x_ql minus 4 when the selected x_qh bit is set
  // (x_qh is expected to hold the bit-inverted hmask written by load_tiles_q3_K).
  // x_dm and y_ds are read as float; scales come from x_sc as signed bytes.
  static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
      const int blk = k / QI3_K;
      const int ky  = (k % QI3_K) * QR3_K;

      const float * x_d = reinterpret_cast<const float *>(x_dm);
      const float * y_d = reinterpret_cast<const float *>(y_ds);

      const int8_t * scales = reinterpret_cast<const int8_t *>(x_sc + i * (WARP_SIZE/4) + i/4 + blk*4) + ky/4;

      // These depend only on i, blk and ky, so they are computed once outside the loop.
      const int ql_base  = i * (WARP_SIZE + 1) + blk*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
      const int ql_shift = 2 * ((ky % 32) / 8);
      const int qh_base  = i * (WARP_SIZE/2) + i/2 + blk * (QI3_K/2);

      int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
  #pragma unroll
      for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
          const int low2   = (x_ql[ql_base + l] >> ql_shift) & 0x03030303;
          const int hbits  = x_qh[qh_base + (ky + l) % 8] >> ((ky + l) / 8);
          const int minus4 = (hbits << 2) & 0x04040404;
          v[l] = __vsubss4(low2, minus4);
      }

      const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
      return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_d[i * (WARP_SIZE/QI3_K) + i/QI3_K + blk], y_d[index_y/QI8_1]);
  }
  470. template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
  471. GGML_UNUSED(x_qh);
  472. __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
  473. __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
  474. __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
  475. *x_ql = tile_x_ql;
  476. *x_dm = tile_x_dm;
  477. *x_sc = tile_x_sc;
  478. }
  479. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
  480. const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  481. int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
  482. GGML_UNUSED(x_qh);
  483. GGML_CUDA_ASSUME(i_offset >= 0);
  484. GGML_CUDA_ASSUME(i_offset < nwarps);
  485. GGML_CUDA_ASSUME(k >= 0);
  486. GGML_CUDA_ASSUME(k < WARP_SIZE);
  487. const int kbx = k / QI4_K; // == 0 if QK_K == 256
  488. const int kqsx = k % QI4_K; // == k if QK_K == 256
  489. const block_q4_K * bx0 = (const block_q4_K *) vx;
  490. #pragma unroll
  491. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  492. int i = i0 + i_offset;
  493. if (need_check) {
  494. i = min(i, i_max);
  495. }
  496. const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
  497. x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
  498. }
  499. const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
  500. const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
  501. #pragma unroll
  502. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
  503. int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
  504. if (need_check) {
  505. i = min(i, i_max);
  506. }
  507. const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
  508. #if QK_K == 256
  509. x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
  510. #else
  511. x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
  512. #endif
  513. }
  514. #pragma unroll
  515. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
  516. int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
  517. if (need_check) {
  518. i = min(i, i_max);
  519. }
  520. const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
  521. const int * scales = (const int *) bxi->scales;
  522. const int ksc = k % (WARP_SIZE/8);
  523. // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
  524. int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
  525. scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
  526. x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
  527. }
  528. }
  529. static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
  530. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  531. const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
  532. GGML_UNUSED(x_qh);
  533. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
  534. const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
  535. return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
  536. x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
  537. }
  538. template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
  539. GGML_UNUSED(x_qh);
  540. __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
  541. __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
  542. __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
  543. *x_ql = tile_x_ql;
  544. *x_dm = tile_x_dm;
  545. *x_sc = tile_x_sc;
  546. }
  547. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
  548. const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  549. int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
  550. GGML_UNUSED(x_qh);
  551. GGML_CUDA_ASSUME(i_offset >= 0);
  552. GGML_CUDA_ASSUME(i_offset < nwarps);
  553. GGML_CUDA_ASSUME(k >= 0);
  554. GGML_CUDA_ASSUME(k < WARP_SIZE);
  555. const int kbx = k / QI5_K; // == 0 if QK_K == 256
  556. const int kqsx = k % QI5_K; // == k if QK_K == 256
  557. const block_q5_K * bx0 = (const block_q5_K *) vx;
  558. #pragma unroll
  559. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  560. int i = i0 + i_offset;
  561. if (need_check) {
  562. i = min(i, i_max);
  563. }
  564. const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
  565. const int ky = QR5_K*kqsx;
  566. const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
  567. const int ql0 = (ql >> 0) & 0x0F0F0F0F;
  568. const int ql1 = (ql >> 4) & 0x0F0F0F0F;
  569. const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
  570. const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
  571. const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
  572. const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
  573. const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
  574. x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
  575. x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
  576. }
  577. const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
  578. const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
  579. #pragma unroll
  580. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
  581. int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
  582. if (need_check) {
  583. i = min(i, i_max);
  584. }
  585. const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
  586. #if QK_K == 256
  587. x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
  588. #endif
  589. }
  590. #pragma unroll
  591. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
  592. int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
  593. if (need_check) {
  594. i = min(i, i_max);
  595. }
  596. const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
  597. const int * scales = (const int *) bxi->scales;
  598. const int ksc = k % (WARP_SIZE/8);
  599. // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
  600. int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
  601. scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
  602. x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
  603. }
  604. }
  605. static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
  606. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  607. const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
  608. GGML_UNUSED(x_qh);
  609. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
  610. const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
  611. const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
  612. return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
  613. x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
  614. }
  615. template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
  616. GGML_UNUSED(x_qh);
  617. __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
  618. __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
  619. __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
  620. *x_ql = tile_x_ql;
  621. *x_dm = tile_x_dm;
  622. *x_sc = tile_x_sc;
  623. }
  624. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
  625. const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  626. int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
  627. GGML_UNUSED(x_qh);
  628. GGML_CUDA_ASSUME(i_offset >= 0);
  629. GGML_CUDA_ASSUME(i_offset < nwarps);
  630. GGML_CUDA_ASSUME(k >= 0);
  631. GGML_CUDA_ASSUME(k < WARP_SIZE);
  632. const int kbx = k / QI6_K; // == 0 if QK_K == 256
  633. const int kqsx = k % QI6_K; // == k if QK_K == 256
  634. const block_q6_K * bx0 = (const block_q6_K *) vx;
  635. #pragma unroll
  636. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  637. int i = i0 + i_offset;
  638. if (need_check) {
  639. i = min(i, i_max);
  640. }
  641. const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
  642. const int ky = QR6_K*kqsx;
  643. const int ql = get_int_from_uint8(bxi->ql, kqsx);
  644. const int ql0 = (ql >> 0) & 0x0F0F0F0F;
  645. const int ql1 = (ql >> 4) & 0x0F0F0F0F;
  646. const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
  647. const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
  648. const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;
  649. const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
  650. const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
  651. x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
  652. x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
  653. }
  654. const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
  655. const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
  656. float * x_dmf = (float *) x_dm;
  657. #pragma unroll
  658. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
  659. int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
  660. if (need_check) {
  661. i = min(i, i_max);
  662. }
  663. const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
  664. x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
  665. }
  666. #pragma unroll
  667. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
  668. int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
  669. if (need_check) {
  670. i = min(i, i_max);
  671. }
  672. const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
  673. x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
  674. }
  675. }
  676. static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
  677. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  678. const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
  679. GGML_UNUSED(x_qh);
  680. const float * x_dmf = (const float *) x_dm;
  681. const float * y_df = (const float *) y_ds;
  682. const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
  683. const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
  684. const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;
  685. return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
  686. }
  687. #define MMQ_X_Q4_0_RDNA2 64
  688. #define MMQ_Y_Q4_0_RDNA2 128
  689. #define NWARPS_Q4_0_RDNA2 8
  690. #define MMQ_X_Q4_0_RDNA1 64
  691. #define MMQ_Y_Q4_0_RDNA1 64
  692. #define NWARPS_Q4_0_RDNA1 8
  693. #if defined(CUDA_USE_TENSOR_CORES)
  694. #define MMQ_X_Q4_0_AMPERE 4
  695. #define MMQ_Y_Q4_0_AMPERE 32
  696. #define NWARPS_Q4_0_AMPERE 4
  697. #else
  698. #define MMQ_X_Q4_0_AMPERE 64
  699. #define MMQ_Y_Q4_0_AMPERE 128
  700. #define NWARPS_Q4_0_AMPERE 4
  701. #endif
  702. #define MMQ_X_Q4_0_PASCAL 64
  703. #define MMQ_Y_Q4_0_PASCAL 64
  704. #define NWARPS_Q4_0_PASCAL 8
  705. template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
  706. allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
  707. static __device__ __forceinline__ void mul_mat_q(
  708. const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
  709. const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  710. const block_q_t * x = (const block_q_t *) vx;
  711. const block_q8_1 * y = (const block_q8_1 *) vy;
  712. const int blocks_per_row_x = ncols_x / qk;
  713. const int blocks_per_col_y = nrows_y / QK8_1;
  714. const int blocks_per_warp = WARP_SIZE / qi;
  715. const int & ncols_dst = ncols_y;
  716. const int row_dst_0 = blockIdx.x*mmq_y;
  717. const int & row_x_0 = row_dst_0;
  718. const int col_dst_0 = blockIdx.y*mmq_x;
  719. const int & col_y_0 = col_dst_0;
  720. int * tile_x_ql = nullptr;
  721. half2 * tile_x_dm = nullptr;
  722. int * tile_x_qh = nullptr;
  723. int * tile_x_sc = nullptr;
  724. allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
  725. __shared__ int tile_y_qs[mmq_x * WARP_SIZE];
  726. __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
  727. float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
  728. for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
  729. load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
  730. threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
  731. #pragma unroll
  732. for (int ir = 0; ir < qr; ++ir) {
  733. const int kqs = ir*WARP_SIZE + threadIdx.x;
  734. const int kbxd = kqs / QI8_1;
  735. #pragma unroll
  736. for (int i = 0; i < mmq_x; i += nwarps) {
  737. const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
  738. const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
  739. const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
  740. tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
  741. }
  742. #pragma unroll
  743. for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
  744. const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
  745. const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
  746. const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
  747. // if the sum is not needed it's faster to transform the scale to f32 ahead of time
  748. const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
  749. half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
  750. if (need_sum) {
  751. *dsi_dst = *dsi_src;
  752. } else {
  753. float * dfi_dst = (float *) dsi_dst;
  754. *dfi_dst = __low2float(*dsi_src);
  755. }
  756. }
  757. __syncthreads();
  758. // #pragma unroll // unrolling this loop causes too much register pressure
  759. for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
  760. #pragma unroll
  761. for (int j = 0; j < mmq_x; j += nwarps) {
  762. #pragma unroll
  763. for (int i = 0; i < mmq_y; i += WARP_SIZE) {
  764. sum[i/WARP_SIZE][j/nwarps] += vec_dot(
  765. tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
  766. threadIdx.x + i, threadIdx.y + j, k);
  767. }
  768. }
  769. }
  770. __syncthreads();
  771. }
  772. }
  773. #pragma unroll
  774. for (int j = 0; j < mmq_x; j += nwarps) {
  775. const int col_dst = col_dst_0 + j + threadIdx.y;
  776. if (col_dst >= ncols_dst) {
  777. return;
  778. }
  779. #pragma unroll
  780. for (int i = 0; i < mmq_y; i += WARP_SIZE) {
  781. const int row_dst = row_dst_0 + threadIdx.x + i;
  782. if (row_dst >= nrows_dst) {
  783. continue;
  784. }
  785. dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
  786. }
  787. }
  788. }
  789. template <bool need_check> static __global__ void
  790. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  791. #if defined(RDNA3) || defined(RDNA2)
  792. __launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2)
  793. #endif // defined(RDNA3) || defined(RDNA2)
  794. #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  795. mul_mat_q4_0(
  796. const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
  797. const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  798. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  799. #if defined(RDNA3) || defined(RDNA2)
  800. const int mmq_x = MMQ_X_Q4_0_RDNA2;
  801. const int mmq_y = MMQ_Y_Q4_0_RDNA2;
  802. const int nwarps = NWARPS_Q4_0_RDNA2;
  803. #else
  804. const int mmq_x = MMQ_X_Q4_0_RDNA1;
  805. const int mmq_y = MMQ_Y_Q4_0_RDNA1;
  806. const int nwarps = NWARPS_Q4_0_RDNA1;
  807. #endif // defined(RDNA3) || defined(RDNA2)
  808. mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
  809. load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
  810. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  811. #elif __CUDA_ARCH__ >= CC_VOLTA
  812. const int mmq_x = MMQ_X_Q4_0_AMPERE;
  813. const int mmq_y = MMQ_Y_Q4_0_AMPERE;
  814. const int nwarps = NWARPS_Q4_0_AMPERE;
  815. mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
  816. load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
  817. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  818. #elif __CUDA_ARCH__ >= MIN_CC_DP4A
  819. const int mmq_x = MMQ_X_Q4_0_PASCAL;
  820. const int mmq_y = MMQ_Y_Q4_0_PASCAL;
  821. const int nwarps = NWARPS_Q4_0_PASCAL;
  822. mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
  823. load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
  824. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  825. #else
  826. GGML_UNUSED(vec_dot_q4_0_q8_1_mul_mat);
  827. NO_DEVICE_CODE;
  828. #endif // __CUDA_ARCH__ >= CC_VOLTA
  829. }
  830. #define MMQ_X_Q4_1_RDNA2 64
  831. #define MMQ_Y_Q4_1_RDNA2 128
  832. #define NWARPS_Q4_1_RDNA2 8
  833. #define MMQ_X_Q4_1_RDNA1 64
  834. #define MMQ_Y_Q4_1_RDNA1 64
  835. #define NWARPS_Q4_1_RDNA1 8
  836. #if defined(CUDA_USE_TENSOR_CORES)
  837. #define MMQ_X_Q4_1_AMPERE 4
  838. #define MMQ_Y_Q4_1_AMPERE 32
  839. #define NWARPS_Q4_1_AMPERE 4
  840. #else
  841. #define MMQ_X_Q4_1_AMPERE 64
  842. #define MMQ_Y_Q4_1_AMPERE 128
  843. #define NWARPS_Q4_1_AMPERE 4
  844. #endif
  845. #define MMQ_X_Q4_1_PASCAL 64
  846. #define MMQ_Y_Q4_1_PASCAL 64
  847. #define NWARPS_Q4_1_PASCAL 8
  848. template <bool need_check> static __global__ void
  849. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  850. #if defined(RDNA3) || defined(RDNA2)
  851. __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2)
  852. #endif // defined(RDNA3) || defined(RDNA2)
  853. #elif __CUDA_ARCH__ < CC_VOLTA
  854. __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
  855. #endif // __CUDA_ARCH__ < CC_VOLTA
  856. mul_mat_q4_1(
  857. const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
  858. const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  859. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  860. #if defined(RDNA3) || defined(RDNA2)
  861. const int mmq_x = MMQ_X_Q4_1_RDNA2;
  862. const int mmq_y = MMQ_Y_Q4_1_RDNA2;
  863. const int nwarps = NWARPS_Q4_1_RDNA2;
  864. #else
  865. const int mmq_x = MMQ_X_Q4_1_RDNA1;
  866. const int mmq_y = MMQ_Y_Q4_1_RDNA1;
  867. const int nwarps = NWARPS_Q4_1_RDNA1;
  868. #endif // defined(RDNA3) || defined(RDNA2)
  869. mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
  870. load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
  871. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  872. #elif __CUDA_ARCH__ >= CC_VOLTA
  873. const int mmq_x = MMQ_X_Q4_1_AMPERE;
  874. const int mmq_y = MMQ_Y_Q4_1_AMPERE;
  875. const int nwarps = NWARPS_Q4_1_AMPERE;
  876. mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
  877. load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
  878. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  879. #elif __CUDA_ARCH__ >= MIN_CC_DP4A
  880. const int mmq_x = MMQ_X_Q4_1_PASCAL;
  881. const int mmq_y = MMQ_Y_Q4_1_PASCAL;
  882. const int nwarps = NWARPS_Q4_1_PASCAL;
  883. mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
  884. load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
  885. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  886. #else
  887. GGML_UNUSED(vec_dot_q4_1_q8_1_mul_mat);
  888. NO_DEVICE_CODE;
  889. #endif // __CUDA_ARCH__ >= CC_VOLTA
  890. }
  891. #define MMQ_X_Q5_0_RDNA2 64
  892. #define MMQ_Y_Q5_0_RDNA2 128
  893. #define NWARPS_Q5_0_RDNA2 8
  894. #define MMQ_X_Q5_0_RDNA1 64
  895. #define MMQ_Y_Q5_0_RDNA1 64
  896. #define NWARPS_Q5_0_RDNA1 8
  897. #if defined(CUDA_USE_TENSOR_CORES)
  898. #define MMQ_X_Q5_0_AMPERE 4
  899. #define MMQ_Y_Q5_0_AMPERE 32
  900. #define NWARPS_Q5_0_AMPERE 4
  901. #else
  902. #define MMQ_X_Q5_0_AMPERE 128
  903. #define MMQ_Y_Q5_0_AMPERE 64
  904. #define NWARPS_Q5_0_AMPERE 4
  905. #endif
  906. #define MMQ_X_Q5_0_PASCAL 64
  907. #define MMQ_Y_Q5_0_PASCAL 64
  908. #define NWARPS_Q5_0_PASCAL 8
  909. template <bool need_check> static __global__ void
  910. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  911. #if defined(RDNA3) || defined(RDNA2)
  912. __launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2)
  913. #endif // defined(RDNA3) || defined(RDNA2)
  914. #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  915. mul_mat_q5_0(
  916. const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
  917. const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  918. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  919. #if defined(RDNA3) || defined(RDNA2)
  920. const int mmq_x = MMQ_X_Q5_0_RDNA2;
  921. const int mmq_y = MMQ_Y_Q5_0_RDNA2;
  922. const int nwarps = NWARPS_Q5_0_RDNA2;
  923. #else
  924. const int mmq_x = MMQ_X_Q5_0_RDNA1;
  925. const int mmq_y = MMQ_Y_Q5_0_RDNA1;
  926. const int nwarps = NWARPS_Q5_0_RDNA1;
  927. #endif // defined(RDNA3) || defined(RDNA2)
  928. mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
  929. load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
  930. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  931. #elif __CUDA_ARCH__ >= CC_VOLTA
  932. const int mmq_x = MMQ_X_Q5_0_AMPERE;
  933. const int mmq_y = MMQ_Y_Q5_0_AMPERE;
  934. const int nwarps = NWARPS_Q5_0_AMPERE;
  935. mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
  936. load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
  937. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  938. #elif __CUDA_ARCH__ >= MIN_CC_DP4A
  939. const int mmq_x = MMQ_X_Q5_0_PASCAL;
  940. const int mmq_y = MMQ_Y_Q5_0_PASCAL;
  941. const int nwarps = NWARPS_Q5_0_PASCAL;
  942. mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
  943. load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
  944. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  945. #else
  946. GGML_UNUSED(vec_dot_q5_0_q8_1_mul_mat);
  947. NO_DEVICE_CODE;
  948. #endif // __CUDA_ARCH__ >= CC_VOLTA
  949. }
  950. #define MMQ_X_Q5_1_RDNA2 64
  951. #define MMQ_Y_Q5_1_RDNA2 128
  952. #define NWARPS_Q5_1_RDNA2 8
  953. #define MMQ_X_Q5_1_RDNA1 64
  954. #define MMQ_Y_Q5_1_RDNA1 64
  955. #define NWARPS_Q5_1_RDNA1 8
  956. #if defined(CUDA_USE_TENSOR_CORES)
  957. #define MMQ_X_Q5_1_AMPERE 4
  958. #define MMQ_Y_Q5_1_AMPERE 32
  959. #define NWARPS_Q5_1_AMPERE 4
  960. #else
  961. #define MMQ_X_Q5_1_AMPERE 128
  962. #define MMQ_Y_Q5_1_AMPERE 64
  963. #define NWARPS_Q5_1_AMPERE 4
  964. #endif
  965. #define MMQ_X_Q5_1_PASCAL 64
  966. #define MMQ_Y_Q5_1_PASCAL 64
  967. #define NWARPS_Q5_1_PASCAL 8
  968. template <bool need_check> static __global__ void
  969. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  970. #if defined(RDNA3) || defined(RDNA2)
  971. __launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2)
  972. #endif // defined(RDNA3) || defined(RDNA2)
  973. #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  974. mul_mat_q5_1(
  975. const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
  976. const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  977. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  978. #if defined(RDNA3) || defined(RDNA2)
  979. const int mmq_x = MMQ_X_Q5_1_RDNA2;
  980. const int mmq_y = MMQ_Y_Q5_1_RDNA2;
  981. const int nwarps = NWARPS_Q5_1_RDNA2;
  982. #else
  983. const int mmq_x = MMQ_X_Q5_1_RDNA1;
  984. const int mmq_y = MMQ_Y_Q5_1_RDNA1;
  985. const int nwarps = NWARPS_Q5_1_RDNA1;
  986. #endif // defined(RDNA3) || defined(RDNA2)
  987. mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
  988. load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
  989. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  990. #elif __CUDA_ARCH__ >= CC_VOLTA
  991. const int mmq_x = MMQ_X_Q5_1_AMPERE;
  992. const int mmq_y = MMQ_Y_Q5_1_AMPERE;
  993. const int nwarps = NWARPS_Q5_1_AMPERE;
  994. mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
  995. load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
  996. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  997. #elif __CUDA_ARCH__ >= MIN_CC_DP4A
  998. const int mmq_x = MMQ_X_Q5_1_PASCAL;
  999. const int mmq_y = MMQ_Y_Q5_1_PASCAL;
  1000. const int nwarps = NWARPS_Q5_1_PASCAL;
  1001. mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
  1002. load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
  1003. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  1004. #else
  1005. GGML_UNUSED(vec_dot_q5_1_q8_1_mul_mat);
  1006. NO_DEVICE_CODE;
  1007. #endif // __CUDA_ARCH__ >= CC_VOLTA
  1008. }
  1009. #define MMQ_X_Q8_0_RDNA2 64
  1010. #define MMQ_Y_Q8_0_RDNA2 128
  1011. #define NWARPS_Q8_0_RDNA2 8
  1012. #define MMQ_X_Q8_0_RDNA1 64
  1013. #define MMQ_Y_Q8_0_RDNA1 64
  1014. #define NWARPS_Q8_0_RDNA1 8
  1015. #if defined(CUDA_USE_TENSOR_CORES)
  1016. #define MMQ_X_Q8_0_AMPERE 4
  1017. #define MMQ_Y_Q8_0_AMPERE 32
  1018. #define NWARPS_Q8_0_AMPERE 4
  1019. #else
  1020. #define MMQ_X_Q8_0_AMPERE 128
  1021. #define MMQ_Y_Q8_0_AMPERE 64
  1022. #define NWARPS_Q8_0_AMPERE 4
  1023. #endif
  1024. #define MMQ_X_Q8_0_PASCAL 64
  1025. #define MMQ_Y_Q8_0_PASCAL 64
  1026. #define NWARPS_Q8_0_PASCAL 8
  1027. template <bool need_check> static __global__ void
  1028. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  1029. #if defined(RDNA3) || defined(RDNA2)
  1030. __launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2)
  1031. #endif // defined(RDNA3) || defined(RDNA2)
  1032. #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  1033. mul_mat_q8_0(
  1034. const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
  1035. const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  1036. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  1037. #if defined(RDNA3) || defined(RDNA2)
  1038. const int mmq_x = MMQ_X_Q8_0_RDNA2;
  1039. const int mmq_y = MMQ_Y_Q8_0_RDNA2;
  1040. const int nwarps = NWARPS_Q8_0_RDNA2;
  1041. #else
  1042. const int mmq_x = MMQ_X_Q8_0_RDNA1;
  1043. const int mmq_y = MMQ_Y_Q8_0_RDNA1;
  1044. const int nwarps = NWARPS_Q8_0_RDNA1;
  1045. #endif // defined(RDNA3) || defined(RDNA2)
  1046. mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
  1047. load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
  1048. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  1049. #elif __CUDA_ARCH__ >= CC_VOLTA
  1050. const int mmq_x = MMQ_X_Q8_0_AMPERE;
  1051. const int mmq_y = MMQ_Y_Q8_0_AMPERE;
  1052. const int nwarps = NWARPS_Q8_0_AMPERE;
  1053. mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
  1054. load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
  1055. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  1056. #elif __CUDA_ARCH__ >= MIN_CC_DP4A
  1057. const int mmq_x = MMQ_X_Q8_0_PASCAL;
  1058. const int mmq_y = MMQ_Y_Q8_0_PASCAL;
  1059. const int nwarps = NWARPS_Q8_0_PASCAL;
  1060. mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
  1061. load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
  1062. (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  1063. #else
  1064. GGML_UNUSED(vec_dot_q8_0_q8_1_mul_mat);
  1065. NO_DEVICE_CODE;
  1066. #endif // __CUDA_ARCH__ >= CC_VOLTA
  1067. }
  // ---- Q2_K MMQ tile configuration -------------------------------------------------------
  // MMQ_X: columns of y per thread block; MMQ_Y: rows of x per thread block;
  // NWARPS: warps per thread block (blocks are WARP_SIZE x NWARPS threads).
  // Tier used by the kernel (compile time) / by ggml_mul_mat_q2_K_q8_1_cuda (runtime cc):
  //   RDNA2  : HIP/AMD with RDNA2 or RDNA3 defined  / cc >= CC_RDNA2
  //   RDNA1  : any other HIP/AMD target             / cc >= CC_OFFSET_AMD
  //   AMPERE : __CUDA_ARCH__ >= CC_VOLTA            / cc >= CC_VOLTA
  //   PASCAL : __CUDA_ARCH__ >= MIN_CC_DP4A         / cc >= MIN_CC_DP4A
  // Kernel and launcher read the same macros, so the launch shape matches the compiled tiles.
  #define  MMQ_X_Q2_K_RDNA2  64
  #define  MMQ_Y_Q2_K_RDNA2  128
  #define NWARPS_Q2_K_RDNA2  8
  #define  MMQ_X_Q2_K_RDNA1  128
  #define  MMQ_Y_Q2_K_RDNA1  32
  #define NWARPS_Q2_K_RDNA1  8
  #ifdef CUDA_USE_TENSOR_CORES
  #define  MMQ_X_Q2_K_AMPERE 4
  #define  MMQ_Y_Q2_K_AMPERE 32
  #define NWARPS_Q2_K_AMPERE 4
  #else
  #define  MMQ_X_Q2_K_AMPERE 64
  #define  MMQ_Y_Q2_K_AMPERE 128
  #define NWARPS_Q2_K_AMPERE 4
  #endif // CUDA_USE_TENSOR_CORES
  #define  MMQ_X_Q2_K_PASCAL 64
  #define  MMQ_Y_Q2_K_PASCAL 64
  #define NWARPS_Q2_K_PASCAL 8

  // Q2_K matrix-multiplication kernel: instantiates the generic mul_mat_q with the Q2_K tile
  // helpers and the tile sizes of the tier selected for the current compilation target.
  // need_check is forwarded to load_tiles_q2_K; the host launcher sets it when nrows_x is not a
  // multiple of MMQ_Y. Compilation passes matching no tier (e.g. __CUDA_ARCH__ below
  // MIN_CC_DP4A, or no __CUDA_ARCH__ at all) compile to NO_DEVICE_CODE.
  // NOTE(review): unlike mul_mat_q3_K/q4_K/q6_K there is no pre-Volta __launch_bounds__ here;
  // confirm whether that is intentional.
  template <bool need_check> static __global__ void
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
      __launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2)
  #endif
  mul_mat_q2_K(
      const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
      const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
      // Step 1: pick the tile sizes for this compilation target.
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
      constexpr int mmq_x  = MMQ_X_Q2_K_RDNA2;
      constexpr int mmq_y  = MMQ_Y_Q2_K_RDNA2;
      constexpr int nwarps = NWARPS_Q2_K_RDNA2;
  #elif defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
      constexpr int mmq_x  = MMQ_X_Q2_K_RDNA1;
      constexpr int mmq_y  = MMQ_Y_Q2_K_RDNA1;
      constexpr int nwarps = NWARPS_Q2_K_RDNA1;
  #elif __CUDA_ARCH__ >= CC_VOLTA
      constexpr int mmq_x  = MMQ_X_Q2_K_AMPERE;
      constexpr int mmq_y  = MMQ_Y_Q2_K_AMPERE;
      constexpr int nwarps = NWARPS_Q2_K_AMPERE;
  #elif __CUDA_ARCH__ >= MIN_CC_DP4A
      constexpr int mmq_x  = MMQ_X_Q2_K_PASCAL;
      constexpr int mmq_y  = MMQ_Y_Q2_K_PASCAL;
      constexpr int nwarps = NWARPS_Q2_K_PASCAL;
  #endif

      // Step 2: run the generic kernel whenever one of the tiers above applied.
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
      mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
                load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
          (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  #else
      GGML_UNUSED(vec_dot_q2_K_q8_1_mul_mat);
      NO_DEVICE_CODE;
  #endif // supported target
  }
  // ---- Q3_K MMQ tile configuration -------------------------------------------------------
  // MMQ_X: columns of y per thread block; MMQ_Y: rows of x per thread block;
  // NWARPS: warps per thread block (blocks are WARP_SIZE x NWARPS threads).
  // Tier used by the kernel (compile time) / by ggml_mul_mat_q3_K_q8_1_cuda (runtime cc):
  //   RDNA2  : HIP/AMD with RDNA2 or RDNA3 defined  / cc >= CC_RDNA2
  //   RDNA1  : any other HIP/AMD target             / cc >= CC_OFFSET_AMD
  //   AMPERE : __CUDA_ARCH__ >= CC_VOLTA            / cc >= CC_VOLTA
  //   PASCAL : __CUDA_ARCH__ >= MIN_CC_DP4A         / cc >= MIN_CC_DP4A
  // Kernel and launcher read the same macros, so the launch shape matches the compiled tiles.
  #define  MMQ_X_Q3_K_RDNA2  128
  #define  MMQ_Y_Q3_K_RDNA2  64
  #define NWARPS_Q3_K_RDNA2  8
  #define  MMQ_X_Q3_K_RDNA1  32
  #define  MMQ_Y_Q3_K_RDNA1  128
  #define NWARPS_Q3_K_RDNA1  8
  #ifdef CUDA_USE_TENSOR_CORES
  #define  MMQ_X_Q3_K_AMPERE 4
  #define  MMQ_Y_Q3_K_AMPERE 32
  #define NWARPS_Q3_K_AMPERE 4
  #else
  #define  MMQ_X_Q3_K_AMPERE 128
  #define  MMQ_Y_Q3_K_AMPERE 128
  #define NWARPS_Q3_K_AMPERE 4
  #endif // CUDA_USE_TENSOR_CORES
  #define  MMQ_X_Q3_K_PASCAL 64
  #define  MMQ_Y_Q3_K_PASCAL 64
  #define NWARPS_Q3_K_PASCAL 8

  // Q3_K matrix-multiplication kernel: instantiates the generic mul_mat_q with the Q3_K tile
  // helpers and the tile sizes of the tier selected for the current compilation target.
  // need_check is forwarded to load_tiles_q3_K; the host launcher sets it when nrows_x is not a
  // multiple of MMQ_Y. Compilation passes matching no tier compile to NO_DEVICE_CODE.
  // __launch_bounds__ (min 2 blocks) is applied for the RDNA2 tier on HIP, and with the PASCAL
  // warp count on non-HIP passes where __CUDA_ARCH__ < CC_VOLTA.
  template <bool need_check> static __global__ void
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
      __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2)
  #elif !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ < CC_VOLTA
      __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
  #endif
  mul_mat_q3_K(
      const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
      const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
      // Step 1: pick the tile sizes for this compilation target.
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
      constexpr int mmq_x  = MMQ_X_Q3_K_RDNA2;
      constexpr int mmq_y  = MMQ_Y_Q3_K_RDNA2;
      constexpr int nwarps = NWARPS_Q3_K_RDNA2;
  #elif defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
      constexpr int mmq_x  = MMQ_X_Q3_K_RDNA1;
      constexpr int mmq_y  = MMQ_Y_Q3_K_RDNA1;
      constexpr int nwarps = NWARPS_Q3_K_RDNA1;
  #elif __CUDA_ARCH__ >= CC_VOLTA
      constexpr int mmq_x  = MMQ_X_Q3_K_AMPERE;
      constexpr int mmq_y  = MMQ_Y_Q3_K_AMPERE;
      constexpr int nwarps = NWARPS_Q3_K_AMPERE;
  #elif __CUDA_ARCH__ >= MIN_CC_DP4A
      constexpr int mmq_x  = MMQ_X_Q3_K_PASCAL;
      constexpr int mmq_y  = MMQ_Y_Q3_K_PASCAL;
      constexpr int nwarps = NWARPS_Q3_K_PASCAL;
  #endif

      // Step 2: run the generic kernel whenever one of the tiers above applied.
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
      mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
                load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
          (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  #else
      GGML_UNUSED(vec_dot_q3_K_q8_1_mul_mat);
      NO_DEVICE_CODE;
  #endif // supported target
  }
  // ---- Q4_K MMQ tile configuration -------------------------------------------------------
  // MMQ_X: columns of y per thread block; MMQ_Y: rows of x per thread block;
  // NWARPS: warps per thread block (blocks are WARP_SIZE x NWARPS threads).
  // Tier used by the kernel (compile time) / by ggml_mul_mat_q4_K_q8_1_cuda (runtime cc):
  //   RDNA2  : HIP/AMD with RDNA2 or RDNA3 defined  / cc >= CC_RDNA2
  //   RDNA1  : any other HIP/AMD target             / cc >= CC_OFFSET_AMD
  //   AMPERE : __CUDA_ARCH__ >= CC_VOLTA            / cc >= CC_VOLTA
  //   PASCAL : __CUDA_ARCH__ >= MIN_CC_DP4A         / cc >= MIN_CC_DP4A
  // Kernel and launcher read the same macros, so the launch shape matches the compiled tiles.
  #define  MMQ_X_Q4_K_RDNA2  64
  #define  MMQ_Y_Q4_K_RDNA2  128
  #define NWARPS_Q4_K_RDNA2  8
  #define  MMQ_X_Q4_K_RDNA1  32
  #define  MMQ_Y_Q4_K_RDNA1  64
  #define NWARPS_Q4_K_RDNA1  8
  #ifdef CUDA_USE_TENSOR_CORES
  #define  MMQ_X_Q4_K_AMPERE 4
  #define  MMQ_Y_Q4_K_AMPERE 32
  #define NWARPS_Q4_K_AMPERE 4
  #else
  #define  MMQ_X_Q4_K_AMPERE 64
  #define  MMQ_Y_Q4_K_AMPERE 128
  #define NWARPS_Q4_K_AMPERE 4
  #endif // CUDA_USE_TENSOR_CORES
  #define  MMQ_X_Q4_K_PASCAL 64
  #define  MMQ_Y_Q4_K_PASCAL 64
  #define NWARPS_Q4_K_PASCAL 8

  // Q4_K matrix-multiplication kernel: instantiates the generic mul_mat_q (with its fourth
  // template flag set to true for this type) using the Q4_K tile helpers and the tile sizes
  // of the tier selected for the current compilation target.
  // need_check is forwarded to load_tiles_q4_K; the host launcher sets it when nrows_x is not a
  // multiple of MMQ_Y. Compilation passes matching no tier compile to NO_DEVICE_CODE.
  // __launch_bounds__ (min 2 blocks) is applied for the RDNA2 tier on HIP, and with the PASCAL
  // warp count on non-HIP passes where __CUDA_ARCH__ < CC_VOLTA.
  template <bool need_check> static __global__ void
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
      __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2)
  #elif !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ < CC_VOLTA
      __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
  #endif
  mul_mat_q4_K(
      const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
      const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
      // Step 1: pick the tile sizes for this compilation target.
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
      constexpr int mmq_x  = MMQ_X_Q4_K_RDNA2;
      constexpr int mmq_y  = MMQ_Y_Q4_K_RDNA2;
      constexpr int nwarps = NWARPS_Q4_K_RDNA2;
  #elif defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
      constexpr int mmq_x  = MMQ_X_Q4_K_RDNA1;
      constexpr int mmq_y  = MMQ_Y_Q4_K_RDNA1;
      constexpr int nwarps = NWARPS_Q4_K_RDNA1;
  #elif __CUDA_ARCH__ >= CC_VOLTA
      constexpr int mmq_x  = MMQ_X_Q4_K_AMPERE;
      constexpr int mmq_y  = MMQ_Y_Q4_K_AMPERE;
      constexpr int nwarps = NWARPS_Q4_K_AMPERE;
  #elif __CUDA_ARCH__ >= MIN_CC_DP4A
      constexpr int mmq_x  = MMQ_X_Q4_K_PASCAL;
      constexpr int mmq_y  = MMQ_Y_Q4_K_PASCAL;
      constexpr int nwarps = NWARPS_Q4_K_PASCAL;
  #endif

      // Step 2: run the generic kernel whenever one of the tiers above applied.
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
      mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
                load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
          (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  #else
      GGML_UNUSED(vec_dot_q4_K_q8_1_mul_mat);
      NO_DEVICE_CODE;
  #endif // supported target
  }
  // ---- Q5_K MMQ tile configuration -------------------------------------------------------
  // MMQ_X: columns of y per thread block; MMQ_Y: rows of x per thread block;
  // NWARPS: warps per thread block (blocks are WARP_SIZE x NWARPS threads).
  // Tier used by the kernel (compile time) / by ggml_mul_mat_q5_K_q8_1_cuda (runtime cc):
  //   RDNA2  : HIP/AMD with RDNA2 or RDNA3 defined  / cc >= CC_RDNA2
  //   RDNA1  : any other HIP/AMD target             / cc >= CC_OFFSET_AMD
  //   AMPERE : __CUDA_ARCH__ >= CC_VOLTA            / cc >= CC_VOLTA
  //   PASCAL : __CUDA_ARCH__ >= MIN_CC_DP4A         / cc >= MIN_CC_DP4A
  // Kernel and launcher read the same macros, so the launch shape matches the compiled tiles.
  #define  MMQ_X_Q5_K_RDNA2  64
  #define  MMQ_Y_Q5_K_RDNA2  128
  #define NWARPS_Q5_K_RDNA2  8
  #define  MMQ_X_Q5_K_RDNA1  32
  #define  MMQ_Y_Q5_K_RDNA1  64
  #define NWARPS_Q5_K_RDNA1  8
  #ifdef CUDA_USE_TENSOR_CORES
  #define  MMQ_X_Q5_K_AMPERE 4
  #define  MMQ_Y_Q5_K_AMPERE 32
  #define NWARPS_Q5_K_AMPERE 4
  #else
  #define  MMQ_X_Q5_K_AMPERE 64
  #define  MMQ_Y_Q5_K_AMPERE 128
  #define NWARPS_Q5_K_AMPERE 4
  #endif // CUDA_USE_TENSOR_CORES
  #define  MMQ_X_Q5_K_PASCAL 64
  #define  MMQ_Y_Q5_K_PASCAL 64
  #define NWARPS_Q5_K_PASCAL 8

  // Q5_K matrix-multiplication kernel: instantiates the generic mul_mat_q (with its fourth
  // template flag set to true for this type) using the Q5_K tile helpers and the tile sizes
  // of the tier selected for the current compilation target.
  // need_check is forwarded to load_tiles_q5_K; the host launcher sets it when nrows_x is not a
  // multiple of MMQ_Y. Compilation passes matching no tier compile to NO_DEVICE_CODE.
  // NOTE(review): unlike mul_mat_q3_K/q4_K/q6_K there is no pre-Volta __launch_bounds__ here;
  // confirm whether that is intentional.
  template <bool need_check> static __global__ void
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
      __launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2)
  #endif
  mul_mat_q5_K(
      const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
      const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
      // Step 1: pick the tile sizes for this compilation target.
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
      constexpr int mmq_x  = MMQ_X_Q5_K_RDNA2;
      constexpr int mmq_y  = MMQ_Y_Q5_K_RDNA2;
      constexpr int nwarps = NWARPS_Q5_K_RDNA2;
  #elif defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
      constexpr int mmq_x  = MMQ_X_Q5_K_RDNA1;
      constexpr int mmq_y  = MMQ_Y_Q5_K_RDNA1;
      constexpr int nwarps = NWARPS_Q5_K_RDNA1;
  #elif __CUDA_ARCH__ >= CC_VOLTA
      constexpr int mmq_x  = MMQ_X_Q5_K_AMPERE;
      constexpr int mmq_y  = MMQ_Y_Q5_K_AMPERE;
      constexpr int nwarps = NWARPS_Q5_K_AMPERE;
  #elif __CUDA_ARCH__ >= MIN_CC_DP4A
      constexpr int mmq_x  = MMQ_X_Q5_K_PASCAL;
      constexpr int mmq_y  = MMQ_Y_Q5_K_PASCAL;
      constexpr int nwarps = NWARPS_Q5_K_PASCAL;
  #endif

      // Step 2: run the generic kernel whenever one of the tiers above applied.
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
      mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
                load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
          (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  #else
      GGML_UNUSED(vec_dot_q5_K_q8_1_mul_mat);
      NO_DEVICE_CODE;
  #endif // supported target
  }
  // ---- Q6_K MMQ tile configuration -------------------------------------------------------
  // MMQ_X: columns of y per thread block; MMQ_Y: rows of x per thread block;
  // NWARPS: warps per thread block (blocks are WARP_SIZE x NWARPS threads).
  // Tier used by the kernel (compile time) / by ggml_mul_mat_q6_K_q8_1_cuda (runtime cc):
  //   RDNA2  : HIP/AMD with RDNA2 or RDNA3 defined  / cc >= CC_RDNA2
  //   RDNA1  : any other HIP/AMD target             / cc >= CC_OFFSET_AMD
  //   AMPERE : __CUDA_ARCH__ >= CC_VOLTA            / cc >= CC_VOLTA
  //   PASCAL : __CUDA_ARCH__ >= MIN_CC_DP4A         / cc >= MIN_CC_DP4A
  // Kernel and launcher read the same macros, so the launch shape matches the compiled tiles.
  #define  MMQ_X_Q6_K_RDNA2  64
  #define  MMQ_Y_Q6_K_RDNA2  128
  #define NWARPS_Q6_K_RDNA2  8
  #define  MMQ_X_Q6_K_RDNA1  32
  #define  MMQ_Y_Q6_K_RDNA1  64
  #define NWARPS_Q6_K_RDNA1  8
  #ifdef CUDA_USE_TENSOR_CORES
  #define  MMQ_X_Q6_K_AMPERE 4
  #define  MMQ_Y_Q6_K_AMPERE 32
  #define NWARPS_Q6_K_AMPERE 4
  #else
  #define  MMQ_X_Q6_K_AMPERE 64
  #define  MMQ_Y_Q6_K_AMPERE 64
  #define NWARPS_Q6_K_AMPERE 4
  #endif // CUDA_USE_TENSOR_CORES
  #define  MMQ_X_Q6_K_PASCAL 64
  #define  MMQ_Y_Q6_K_PASCAL 64
  #define NWARPS_Q6_K_PASCAL 8

  // Q6_K matrix-multiplication kernel: instantiates the generic mul_mat_q with the Q6_K tile
  // helpers and the tile sizes of the tier selected for the current compilation target.
  // need_check is forwarded to load_tiles_q6_K; the host launcher sets it when nrows_x is not a
  // multiple of MMQ_Y. Compilation passes matching no tier compile to NO_DEVICE_CODE.
  // __launch_bounds__ (min 2 blocks) is applied for the RDNA2 tier on HIP, and with the PASCAL
  // warp count on non-HIP passes where __CUDA_ARCH__ < CC_VOLTA.
  template <bool need_check> static __global__ void
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
      __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2)
  #elif !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ < CC_VOLTA
      __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
  #endif
  mul_mat_q6_K(
      const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
      const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
      // Step 1: pick the tile sizes for this compilation target.
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
      constexpr int mmq_x  = MMQ_X_Q6_K_RDNA2;
      constexpr int mmq_y  = MMQ_Y_Q6_K_RDNA2;
      constexpr int nwarps = NWARPS_Q6_K_RDNA2;
  #elif defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
      constexpr int mmq_x  = MMQ_X_Q6_K_RDNA1;
      constexpr int mmq_y  = MMQ_Y_Q6_K_RDNA1;
      constexpr int nwarps = NWARPS_Q6_K_RDNA1;
  #elif __CUDA_ARCH__ >= CC_VOLTA
      constexpr int mmq_x  = MMQ_X_Q6_K_AMPERE;
      constexpr int mmq_y  = MMQ_Y_Q6_K_AMPERE;
      constexpr int nwarps = NWARPS_Q6_K_AMPERE;
  #elif __CUDA_ARCH__ >= MIN_CC_DP4A
      constexpr int mmq_x  = MMQ_X_Q6_K_PASCAL;
      constexpr int mmq_y  = MMQ_Y_Q6_K_PASCAL;
      constexpr int nwarps = NWARPS_Q6_K_PASCAL;
  #endif

      // Step 2: run the generic kernel whenever one of the tiers above applied.
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
      mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
                load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
          (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  #else
      GGML_UNUSED(vec_dot_q6_K_q8_1_mul_mat);
      NO_DEVICE_CODE;
  #endif // supported target
  }
  // Host launcher for mul_mat_q4_0, enqueued asynchronously on `stream`.
  // The tile tier comes from the current device's compute capability (first match wins):
  // >= CC_RDNA2 -> RDNA2, >= CC_OFFSET_AMD -> RDNA1, >= CC_VOLTA -> AMPERE,
  // >= MIN_CC_DP4A -> PASCAL; anything else trips GGML_ASSERT. The grid is
  // ceil(nrows_x / MMQ_Y) x ceil(ncols_y / MMQ_X) blocks of WARP_SIZE x NWARPS threads.
  // The need_check = true kernel is used only when nrows_x is not a multiple of MMQ_Y.
  // No launch-error check is performed here.
  static void ggml_mul_mat_q4_0_q8_1_cuda(
      const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
      const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
      int device;
      CUDA_CHECK(cudaGetDevice(&device));
      const int cc = ggml_cuda_info().devices[device].cc;

      int tile_x; // columns of y per block
      int tile_y; // rows of x per block
      int warps;  // warps per block
      if (cc >= CC_RDNA2) {
          tile_x = MMQ_X_Q4_0_RDNA2;  tile_y = MMQ_Y_Q4_0_RDNA2;  warps = NWARPS_Q4_0_RDNA2;
      } else if (cc >= CC_OFFSET_AMD) {
          tile_x = MMQ_X_Q4_0_RDNA1;  tile_y = MMQ_Y_Q4_0_RDNA1;  warps = NWARPS_Q4_0_RDNA1;
      } else if (cc >= CC_VOLTA) {
          tile_x = MMQ_X_Q4_0_AMPERE; tile_y = MMQ_Y_Q4_0_AMPERE; warps = NWARPS_Q4_0_AMPERE;
      } else if (cc >= MIN_CC_DP4A) {
          tile_x = MMQ_X_Q4_0_PASCAL; tile_y = MMQ_Y_Q4_0_PASCAL; warps = NWARPS_Q4_0_PASCAL;
      } else {
          GGML_ASSERT(false);
      }

      const dim3 grid_dims((nrows_x + tile_y - 1) / tile_y, (ncols_y + tile_x - 1) / tile_x, 1);
      const dim3 thread_dims(WARP_SIZE, warps, 1);

      if (nrows_x % tile_y != 0) {
          // The last row tile is partial: use the need_check variant.
          mul_mat_q4_0<true><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      } else {
          mul_mat_q4_0<false><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      }
  }
  // Host launcher for mul_mat_q4_1, enqueued asynchronously on `stream`.
  // The tile tier comes from the current device's compute capability (first match wins):
  // >= CC_RDNA2 -> RDNA2, >= CC_OFFSET_AMD -> RDNA1, >= CC_VOLTA -> AMPERE,
  // >= MIN_CC_DP4A -> PASCAL; anything else trips GGML_ASSERT. The grid is
  // ceil(nrows_x / MMQ_Y) x ceil(ncols_y / MMQ_X) blocks of WARP_SIZE x NWARPS threads.
  // The need_check = true kernel is used only when nrows_x is not a multiple of MMQ_Y.
  // No launch-error check is performed here.
  static void ggml_mul_mat_q4_1_q8_1_cuda(
      const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
      const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
      int device;
      CUDA_CHECK(cudaGetDevice(&device));
      const int cc = ggml_cuda_info().devices[device].cc;

      int tile_x; // columns of y per block
      int tile_y; // rows of x per block
      int warps;  // warps per block
      if (cc >= CC_RDNA2) {
          tile_x = MMQ_X_Q4_1_RDNA2;  tile_y = MMQ_Y_Q4_1_RDNA2;  warps = NWARPS_Q4_1_RDNA2;
      } else if (cc >= CC_OFFSET_AMD) {
          tile_x = MMQ_X_Q4_1_RDNA1;  tile_y = MMQ_Y_Q4_1_RDNA1;  warps = NWARPS_Q4_1_RDNA1;
      } else if (cc >= CC_VOLTA) {
          tile_x = MMQ_X_Q4_1_AMPERE; tile_y = MMQ_Y_Q4_1_AMPERE; warps = NWARPS_Q4_1_AMPERE;
      } else if (cc >= MIN_CC_DP4A) {
          tile_x = MMQ_X_Q4_1_PASCAL; tile_y = MMQ_Y_Q4_1_PASCAL; warps = NWARPS_Q4_1_PASCAL;
      } else {
          GGML_ASSERT(false);
      }

      const dim3 grid_dims((nrows_x + tile_y - 1) / tile_y, (ncols_y + tile_x - 1) / tile_x, 1);
      const dim3 thread_dims(WARP_SIZE, warps, 1);

      if (nrows_x % tile_y != 0) {
          // The last row tile is partial: use the need_check variant.
          mul_mat_q4_1<true><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      } else {
          mul_mat_q4_1<false><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      }
  }
  // Host launcher for mul_mat_q5_0, enqueued asynchronously on `stream`.
  // The tile tier comes from the current device's compute capability (first match wins):
  // >= CC_RDNA2 -> RDNA2, >= CC_OFFSET_AMD -> RDNA1, >= CC_VOLTA -> AMPERE,
  // >= MIN_CC_DP4A -> PASCAL; anything else trips GGML_ASSERT. The grid is
  // ceil(nrows_x / MMQ_Y) x ceil(ncols_y / MMQ_X) blocks of WARP_SIZE x NWARPS threads.
  // The need_check = true kernel is used only when nrows_x is not a multiple of MMQ_Y.
  // No launch-error check is performed here.
  static void ggml_mul_mat_q5_0_q8_1_cuda(
      const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
      const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
      int device;
      CUDA_CHECK(cudaGetDevice(&device));
      const int cc = ggml_cuda_info().devices[device].cc;

      int tile_x; // columns of y per block
      int tile_y; // rows of x per block
      int warps;  // warps per block
      if (cc >= CC_RDNA2) {
          tile_x = MMQ_X_Q5_0_RDNA2;  tile_y = MMQ_Y_Q5_0_RDNA2;  warps = NWARPS_Q5_0_RDNA2;
      } else if (cc >= CC_OFFSET_AMD) {
          tile_x = MMQ_X_Q5_0_RDNA1;  tile_y = MMQ_Y_Q5_0_RDNA1;  warps = NWARPS_Q5_0_RDNA1;
      } else if (cc >= CC_VOLTA) {
          tile_x = MMQ_X_Q5_0_AMPERE; tile_y = MMQ_Y_Q5_0_AMPERE; warps = NWARPS_Q5_0_AMPERE;
      } else if (cc >= MIN_CC_DP4A) {
          tile_x = MMQ_X_Q5_0_PASCAL; tile_y = MMQ_Y_Q5_0_PASCAL; warps = NWARPS_Q5_0_PASCAL;
      } else {
          GGML_ASSERT(false);
      }

      const dim3 grid_dims((nrows_x + tile_y - 1) / tile_y, (ncols_y + tile_x - 1) / tile_x, 1);
      const dim3 thread_dims(WARP_SIZE, warps, 1);

      if (nrows_x % tile_y != 0) {
          // The last row tile is partial: use the need_check variant.
          mul_mat_q5_0<true><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      } else {
          mul_mat_q5_0<false><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      }
  }
  // Host launcher for mul_mat_q5_1, enqueued asynchronously on `stream`.
  // The tile tier comes from the current device's compute capability (first match wins):
  // >= CC_RDNA2 -> RDNA2, >= CC_OFFSET_AMD -> RDNA1, >= CC_VOLTA -> AMPERE,
  // >= MIN_CC_DP4A -> PASCAL; anything else trips GGML_ASSERT. The grid is
  // ceil(nrows_x / MMQ_Y) x ceil(ncols_y / MMQ_X) blocks of WARP_SIZE x NWARPS threads.
  // The need_check = true kernel is used only when nrows_x is not a multiple of MMQ_Y.
  // No launch-error check is performed here.
  static void ggml_mul_mat_q5_1_q8_1_cuda(
      const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
      const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
      int device;
      CUDA_CHECK(cudaGetDevice(&device));
      const int cc = ggml_cuda_info().devices[device].cc;

      int tile_x; // columns of y per block
      int tile_y; // rows of x per block
      int warps;  // warps per block
      if (cc >= CC_RDNA2) {
          tile_x = MMQ_X_Q5_1_RDNA2;  tile_y = MMQ_Y_Q5_1_RDNA2;  warps = NWARPS_Q5_1_RDNA2;
      } else if (cc >= CC_OFFSET_AMD) {
          tile_x = MMQ_X_Q5_1_RDNA1;  tile_y = MMQ_Y_Q5_1_RDNA1;  warps = NWARPS_Q5_1_RDNA1;
      } else if (cc >= CC_VOLTA) {
          tile_x = MMQ_X_Q5_1_AMPERE; tile_y = MMQ_Y_Q5_1_AMPERE; warps = NWARPS_Q5_1_AMPERE;
      } else if (cc >= MIN_CC_DP4A) {
          tile_x = MMQ_X_Q5_1_PASCAL; tile_y = MMQ_Y_Q5_1_PASCAL; warps = NWARPS_Q5_1_PASCAL;
      } else {
          GGML_ASSERT(false);
      }

      const dim3 grid_dims((nrows_x + tile_y - 1) / tile_y, (ncols_y + tile_x - 1) / tile_x, 1);
      const dim3 thread_dims(WARP_SIZE, warps, 1);

      if (nrows_x % tile_y != 0) {
          // The last row tile is partial: use the need_check variant.
          mul_mat_q5_1<true><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      } else {
          mul_mat_q5_1<false><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      }
  }
  // Host launcher for mul_mat_q8_0, enqueued asynchronously on `stream`.
  // The tile tier comes from the current device's compute capability (first match wins):
  // >= CC_RDNA2 -> RDNA2, >= CC_OFFSET_AMD -> RDNA1, >= CC_VOLTA -> AMPERE,
  // >= MIN_CC_DP4A -> PASCAL; anything else trips GGML_ASSERT. The grid is
  // ceil(nrows_x / MMQ_Y) x ceil(ncols_y / MMQ_X) blocks of WARP_SIZE x NWARPS threads.
  // The need_check = true kernel is used only when nrows_x is not a multiple of MMQ_Y.
  // No launch-error check is performed here.
  static void ggml_mul_mat_q8_0_q8_1_cuda(
      const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
      const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
      int device;
      CUDA_CHECK(cudaGetDevice(&device));
      const int cc = ggml_cuda_info().devices[device].cc;

      int tile_x; // columns of y per block
      int tile_y; // rows of x per block
      int warps;  // warps per block
      if (cc >= CC_RDNA2) {
          tile_x = MMQ_X_Q8_0_RDNA2;  tile_y = MMQ_Y_Q8_0_RDNA2;  warps = NWARPS_Q8_0_RDNA2;
      } else if (cc >= CC_OFFSET_AMD) {
          tile_x = MMQ_X_Q8_0_RDNA1;  tile_y = MMQ_Y_Q8_0_RDNA1;  warps = NWARPS_Q8_0_RDNA1;
      } else if (cc >= CC_VOLTA) {
          tile_x = MMQ_X_Q8_0_AMPERE; tile_y = MMQ_Y_Q8_0_AMPERE; warps = NWARPS_Q8_0_AMPERE;
      } else if (cc >= MIN_CC_DP4A) {
          tile_x = MMQ_X_Q8_0_PASCAL; tile_y = MMQ_Y_Q8_0_PASCAL; warps = NWARPS_Q8_0_PASCAL;
      } else {
          GGML_ASSERT(false);
      }

      const dim3 grid_dims((nrows_x + tile_y - 1) / tile_y, (ncols_y + tile_x - 1) / tile_x, 1);
      const dim3 thread_dims(WARP_SIZE, warps, 1);

      if (nrows_x % tile_y != 0) {
          // The last row tile is partial: use the need_check variant.
          mul_mat_q8_0<true><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      } else {
          mul_mat_q8_0<false><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      }
  }
  // Host launcher for mul_mat_q2_K, enqueued asynchronously on `stream`.
  // The tile tier comes from the current device's compute capability (first match wins):
  // >= CC_RDNA2 -> RDNA2, >= CC_OFFSET_AMD -> RDNA1, >= CC_VOLTA -> AMPERE,
  // >= MIN_CC_DP4A -> PASCAL; anything else trips GGML_ASSERT. The grid is
  // ceil(nrows_x / MMQ_Y) x ceil(ncols_y / MMQ_X) blocks of WARP_SIZE x NWARPS threads.
  // The need_check = true kernel is used only when nrows_x is not a multiple of MMQ_Y.
  // No launch-error check is performed here.
  static void ggml_mul_mat_q2_K_q8_1_cuda(
      const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
      const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
      int device;
      CUDA_CHECK(cudaGetDevice(&device));
      const int cc = ggml_cuda_info().devices[device].cc;

      int tile_x; // columns of y per block
      int tile_y; // rows of x per block
      int warps;  // warps per block
      if (cc >= CC_RDNA2) {
          tile_x = MMQ_X_Q2_K_RDNA2;  tile_y = MMQ_Y_Q2_K_RDNA2;  warps = NWARPS_Q2_K_RDNA2;
      } else if (cc >= CC_OFFSET_AMD) {
          tile_x = MMQ_X_Q2_K_RDNA1;  tile_y = MMQ_Y_Q2_K_RDNA1;  warps = NWARPS_Q2_K_RDNA1;
      } else if (cc >= CC_VOLTA) {
          tile_x = MMQ_X_Q2_K_AMPERE; tile_y = MMQ_Y_Q2_K_AMPERE; warps = NWARPS_Q2_K_AMPERE;
      } else if (cc >= MIN_CC_DP4A) {
          tile_x = MMQ_X_Q2_K_PASCAL; tile_y = MMQ_Y_Q2_K_PASCAL; warps = NWARPS_Q2_K_PASCAL;
      } else {
          GGML_ASSERT(false);
      }

      const dim3 grid_dims((nrows_x + tile_y - 1) / tile_y, (ncols_y + tile_x - 1) / tile_x, 1);
      const dim3 thread_dims(WARP_SIZE, warps, 1);

      if (nrows_x % tile_y != 0) {
          // The last row tile is partial: use the need_check variant.
          mul_mat_q2_K<true><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      } else {
          mul_mat_q2_K<false><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      }
  }
  // Host launcher for mul_mat_q3_K, enqueued asynchronously on `stream`.
  // The tile tier comes from the current device's compute capability (first match wins):
  // >= CC_RDNA2 -> RDNA2, >= CC_OFFSET_AMD -> RDNA1, >= CC_VOLTA -> AMPERE,
  // >= MIN_CC_DP4A -> PASCAL; anything else trips GGML_ASSERT. The grid is
  // ceil(nrows_x / MMQ_Y) x ceil(ncols_y / MMQ_X) blocks of WARP_SIZE x NWARPS threads.
  // The need_check = true kernel is used only when nrows_x is not a multiple of MMQ_Y.
  // No launch-error check is performed here.
  //
  // The Q3_K launch is only compiled when QK_K == 256. For any other QK_K, nothing would be
  // launched and dst would be left unwritten, so the function aborts via GGML_ASSERT instead
  // of silently returning.
  static void ggml_mul_mat_q3_K_q8_1_cuda(
      const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
      const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  #if QK_K == 256
      int device;
      CUDA_CHECK(cudaGetDevice(&device));
      const int cc = ggml_cuda_info().devices[device].cc;

      int tile_x; // columns of y per block
      int tile_y; // rows of x per block
      int warps;  // warps per block
      if (cc >= CC_RDNA2) {
          tile_x = MMQ_X_Q3_K_RDNA2;  tile_y = MMQ_Y_Q3_K_RDNA2;  warps = NWARPS_Q3_K_RDNA2;
      } else if (cc >= CC_OFFSET_AMD) {
          tile_x = MMQ_X_Q3_K_RDNA1;  tile_y = MMQ_Y_Q3_K_RDNA1;  warps = NWARPS_Q3_K_RDNA1;
      } else if (cc >= CC_VOLTA) {
          tile_x = MMQ_X_Q3_K_AMPERE; tile_y = MMQ_Y_Q3_K_AMPERE; warps = NWARPS_Q3_K_AMPERE;
      } else if (cc >= MIN_CC_DP4A) {
          tile_x = MMQ_X_Q3_K_PASCAL; tile_y = MMQ_Y_Q3_K_PASCAL; warps = NWARPS_Q3_K_PASCAL;
      } else {
          GGML_ASSERT(false);
      }

      const dim3 grid_dims((nrows_x + tile_y - 1) / tile_y, (ncols_y + tile_x - 1) / tile_x, 1);
      const dim3 thread_dims(WARP_SIZE, warps, 1);

      if (nrows_x % tile_y != 0) {
          // The last row tile is partial: use the need_check variant.
          mul_mat_q3_K<true><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      } else {
          mul_mat_q3_K<false><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      }
  #else
      // No Q3_K MMQ launch exists for this QK_K: fail loudly rather than leave dst unwritten.
      GGML_UNUSED(vx);
      GGML_UNUSED(vy);
      GGML_UNUSED(dst);
      GGML_UNUSED(ncols_x);
      GGML_UNUSED(nrows_x);
      GGML_UNUSED(ncols_y);
      GGML_UNUSED(nrows_y);
      GGML_UNUSED(nrows_dst);
      GGML_UNUSED(stream);
      GGML_ASSERT(false && "Q3_K MMQ requires QK_K == 256");
  #endif // QK_K == 256
  }
  // Host launcher for mul_mat_q4_K, enqueued asynchronously on `stream`.
  // The tile tier comes from the current device's compute capability (first match wins):
  // >= CC_RDNA2 -> RDNA2, >= CC_OFFSET_AMD -> RDNA1, >= CC_VOLTA -> AMPERE,
  // >= MIN_CC_DP4A -> PASCAL; anything else trips GGML_ASSERT. The grid is
  // ceil(nrows_x / MMQ_Y) x ceil(ncols_y / MMQ_X) blocks of WARP_SIZE x NWARPS threads.
  // The need_check = true kernel is used only when nrows_x is not a multiple of MMQ_Y.
  // No launch-error check is performed here.
  static void ggml_mul_mat_q4_K_q8_1_cuda(
      const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
      const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
      int device;
      CUDA_CHECK(cudaGetDevice(&device));
      const int cc = ggml_cuda_info().devices[device].cc;

      int tile_x; // columns of y per block
      int tile_y; // rows of x per block
      int warps;  // warps per block
      if (cc >= CC_RDNA2) {
          tile_x = MMQ_X_Q4_K_RDNA2;  tile_y = MMQ_Y_Q4_K_RDNA2;  warps = NWARPS_Q4_K_RDNA2;
      } else if (cc >= CC_OFFSET_AMD) {
          tile_x = MMQ_X_Q4_K_RDNA1;  tile_y = MMQ_Y_Q4_K_RDNA1;  warps = NWARPS_Q4_K_RDNA1;
      } else if (cc >= CC_VOLTA) {
          tile_x = MMQ_X_Q4_K_AMPERE; tile_y = MMQ_Y_Q4_K_AMPERE; warps = NWARPS_Q4_K_AMPERE;
      } else if (cc >= MIN_CC_DP4A) {
          tile_x = MMQ_X_Q4_K_PASCAL; tile_y = MMQ_Y_Q4_K_PASCAL; warps = NWARPS_Q4_K_PASCAL;
      } else {
          GGML_ASSERT(false);
      }

      const dim3 grid_dims((nrows_x + tile_y - 1) / tile_y, (ncols_y + tile_x - 1) / tile_x, 1);
      const dim3 thread_dims(WARP_SIZE, warps, 1);

      if (nrows_x % tile_y != 0) {
          // The last row tile is partial: use the need_check variant.
          mul_mat_q4_K<true><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      } else {
          mul_mat_q4_K<false><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      }
  }
  // Host launcher for mul_mat_q5_K, enqueued asynchronously on `stream`.
  // The tile tier comes from the current device's compute capability (first match wins):
  // >= CC_RDNA2 -> RDNA2, >= CC_OFFSET_AMD -> RDNA1, >= CC_VOLTA -> AMPERE,
  // >= MIN_CC_DP4A -> PASCAL; anything else trips GGML_ASSERT. The grid is
  // ceil(nrows_x / MMQ_Y) x ceil(ncols_y / MMQ_X) blocks of WARP_SIZE x NWARPS threads.
  // The need_check = true kernel is used only when nrows_x is not a multiple of MMQ_Y.
  // No launch-error check is performed here.
  static void ggml_mul_mat_q5_K_q8_1_cuda(
      const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
      const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
      int device;
      CUDA_CHECK(cudaGetDevice(&device));
      const int cc = ggml_cuda_info().devices[device].cc;

      int tile_x; // columns of y per block
      int tile_y; // rows of x per block
      int warps;  // warps per block
      if (cc >= CC_RDNA2) {
          tile_x = MMQ_X_Q5_K_RDNA2;  tile_y = MMQ_Y_Q5_K_RDNA2;  warps = NWARPS_Q5_K_RDNA2;
      } else if (cc >= CC_OFFSET_AMD) {
          tile_x = MMQ_X_Q5_K_RDNA1;  tile_y = MMQ_Y_Q5_K_RDNA1;  warps = NWARPS_Q5_K_RDNA1;
      } else if (cc >= CC_VOLTA) {
          tile_x = MMQ_X_Q5_K_AMPERE; tile_y = MMQ_Y_Q5_K_AMPERE; warps = NWARPS_Q5_K_AMPERE;
      } else if (cc >= MIN_CC_DP4A) {
          tile_x = MMQ_X_Q5_K_PASCAL; tile_y = MMQ_Y_Q5_K_PASCAL; warps = NWARPS_Q5_K_PASCAL;
      } else {
          GGML_ASSERT(false);
      }

      const dim3 grid_dims((nrows_x + tile_y - 1) / tile_y, (ncols_y + tile_x - 1) / tile_x, 1);
      const dim3 thread_dims(WARP_SIZE, warps, 1);

      if (nrows_x % tile_y != 0) {
          // The last row tile is partial: use the need_check variant.
          mul_mat_q5_K<true><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      } else {
          mul_mat_q5_K<false><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      }
  }
  // Host launcher for mul_mat_q6_K, enqueued asynchronously on `stream`.
  // The tile tier comes from the current device's compute capability (first match wins):
  // >= CC_RDNA2 -> RDNA2, >= CC_OFFSET_AMD -> RDNA1, >= CC_VOLTA -> AMPERE,
  // >= MIN_CC_DP4A -> PASCAL; anything else trips GGML_ASSERT. The grid is
  // ceil(nrows_x / MMQ_Y) x ceil(ncols_y / MMQ_X) blocks of WARP_SIZE x NWARPS threads.
  // The need_check = true kernel is used only when nrows_x is not a multiple of MMQ_Y.
  // No launch-error check is performed here.
  static void ggml_mul_mat_q6_K_q8_1_cuda(
      const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
      const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
      int device;
      CUDA_CHECK(cudaGetDevice(&device));
      const int cc = ggml_cuda_info().devices[device].cc;

      int tile_x; // columns of y per block
      int tile_y; // rows of x per block
      int warps;  // warps per block
      if (cc >= CC_RDNA2) {
          tile_x = MMQ_X_Q6_K_RDNA2;  tile_y = MMQ_Y_Q6_K_RDNA2;  warps = NWARPS_Q6_K_RDNA2;
      } else if (cc >= CC_OFFSET_AMD) {
          tile_x = MMQ_X_Q6_K_RDNA1;  tile_y = MMQ_Y_Q6_K_RDNA1;  warps = NWARPS_Q6_K_RDNA1;
      } else if (cc >= CC_VOLTA) {
          tile_x = MMQ_X_Q6_K_AMPERE; tile_y = MMQ_Y_Q6_K_AMPERE; warps = NWARPS_Q6_K_AMPERE;
      } else if (cc >= MIN_CC_DP4A) {
          tile_x = MMQ_X_Q6_K_PASCAL; tile_y = MMQ_Y_Q6_K_PASCAL; warps = NWARPS_Q6_K_PASCAL;
      } else {
          GGML_ASSERT(false);
      }

      const dim3 grid_dims((nrows_x + tile_y - 1) / tile_y, (ncols_y + tile_x - 1) / tile_x, 1);
      const dim3 thread_dims(WARP_SIZE, warps, 1);

      if (nrows_x % tile_y != 0) {
          // The last row tile is partial: use the need_check variant.
          mul_mat_q6_K<true><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      } else {
          mul_mat_q6_K<false><<<grid_dims, thread_dims, 0, stream>>>
              (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
      }
  }
  // Quantized matrix multiplication (MMQ) for one slice [row_low, row_high) of src0 rows.
  //
  // Dispatches on src0->type to the matching ggml_mul_mat_<type>_q8_1_cuda launcher and
  // enqueues it on `stream`. The launcher multiplies the quantized src0 slice
  // (src0_dd_i) by src1_ddq_i and writes float results into dst_dd_i.
  //
  // Parameters:
  //   ctx                  - backend context; ctx.device is compared against the current
  //                          device to pick the destination row stride.
  //   src0, src1, dst      - tensors; only src0->type and the ne[0] extents are read here.
  //   src0_dd_i            - data of the src0 row slice handled by this call.
  //   src1_ddf_i           - not used by this path.
  //   src1_ddq_i           - src1 data passed to the launchers (presumably q8_1-quantized,
  //                          given the *_q8_1_cuda launcher names; verify against caller).
  //   dst_dd_i             - output buffer.
  //   row_low, row_high    - src0 rows handled here; row_high - row_low rows are computed.
  //   src1_ncols           - number of src1 columns processed.
  //   src1_padded_row_size - src1 row size (after padding), forwarded to the launcher.
  //   stream               - stream the kernel is launched on.
  //
  // Aborts via GGML_ASSERT if ne10 is not a multiple of QK8_1 or if src0->type has no
  // MMQ launcher.
  void ggml_cuda_op_mul_mat_q(
      ggml_backend_cuda_context & ctx,
      const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
      const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
      const int64_t src1_padded_row_size, cudaStream_t stream) {
      const int64_t ne00 = src0->ne[0];

      const int64_t ne10 = src1->ne[0];
      GGML_ASSERT(ne10 % QK8_1 == 0);

      const int64_t ne0 = dst->ne[0];

      const int64_t row_diff = row_high - row_low;

      // Nothing to compute for an empty slice. Returning here also avoids a launch whose
      // grid has a zero dimension: the launchers size the grid by ceil-division of these
      // counts, and CUDA rejects zero-sized grids as an invalid configuration.
      if (row_diff == 0 || src1_ncols == 0) {
          return;
      }

      const int id = ggml_cuda_get_device();

      // the main device has a larger memory buffer to hold the results from all GPUs
      // nrows_dst == nrows of the matrix that the kernel writes into
      const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;

      switch (src0->type) {
          case GGML_TYPE_Q4_0:
              ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
              break;
          case GGML_TYPE_Q4_1:
              ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
              break;
          case GGML_TYPE_Q5_0:
              ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
              break;
          case GGML_TYPE_Q5_1:
              ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
              break;
          case GGML_TYPE_Q8_0:
              ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
              break;
          case GGML_TYPE_Q2_K:
              ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
              break;
          case GGML_TYPE_Q3_K:
              ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
              break;
          case GGML_TYPE_Q4_K:
              ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
              break;
          case GGML_TYPE_Q5_K:
              ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
              break;
          case GGML_TYPE_Q6_K:
              ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
              break;
          default:
              // No MMQ kernel for this quantization type.
              GGML_ASSERT(false);
              break;
      }

      // src1 and dst are read above; only the float src1 pointer goes unused here
      // (presumably it is required by a shared op callback signature; confirm at call site).
      GGML_UNUSED(src1_ddf_i);
  }
  // Reports whether `type` has a quantized mat-mul (MMQ) kernel: true for the legacy
  // quant formats Q4_0/Q4_1/Q5_0/Q5_1/Q8_0 and the K-quants Q2_K..Q6_K, i.e. exactly
  // the types dispatched by ggml_cuda_op_mul_mat_q; false for every other type.
  bool ggml_cuda_supports_mmq(enum ggml_type type) {
      static const enum ggml_type mmq_types[] = {
          GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0,
          GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K,
      };

      for (const enum ggml_type candidate : mmq_types) {
          if (candidate == type) {
              return true;
          }
      }
      return false;
  }