// mmq.cuh — mul_mat_q tile loading and vec-dot kernels (code listing; extraction artifacts removed)
  1. #pragma once
  2. #include "common.cuh"
  3. #include "vecdotq.cuh"
  4. #include "mma.cuh"
  5. #include <climits>
  6. #include <cstdint>
// Row stride (in ints) of the staged q8_1 y tile: WARP_SIZE quant ints
// followed by WARP_SIZE/QI8_1 slots holding the half2 ds values.
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)

// Loads one tile of quantized x data into the tile buffers:
// x_qs = quantized values, x_dm = block scales (or scale/min pairs),
// x_sc = sub-block scales for the K-quant types that have them.
// kbx0 is the first block index along k, i_max the last valid row,
// stride the row stride of x in blocks.
typedef void (*load_tiles_mmq_t)(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);

// Accumulates the dot product of a staged x tile with the q8_1 y tile into sum,
// for the k slice starting at k0.
typedef void (*vec_dot_mmq_t)(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0);

// Writes the accumulated per-thread sums back to dst (ne0/ne1: dst dimensions).
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
// Grouping of 4 consecutive q8_1 blocks as consumed by the MMQ kernels:
// the 4 half2 ds values come first, then the 4*QK8_1 quantized values.
// The second static_assert pins this to be bit-compatible in size with
// 4 plain block_q8_1.
struct block_q8_1_mmq {
    half2  ds[4];       // ds value of each of the 4 blocks (q8_1 convention)
    int8_t qs[4*QK8_1]; // quantized values of the 4 blocks
};
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
// Sizes (in elements) of the three x tile buffers for a given quantization type.
struct tile_x_sizes {
    int qs; // quantized values
    int dm; // block scales / scale-min pairs
    int sc; // sub-block scales; 0 for types without them (see TILE_X_SIZES_* below)
};
// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row

// Maximum tile width (mmq_x) supported by the device being compiled for.
static constexpr __device__ int get_mmq_x_max_device() {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    // AMD/HIP path: fixed maximum.
    return 64;
#else
#if __CUDA_ARCH__ >= CC_VOLTA
#ifdef CUDA_USE_TENSOR_CORES
    // Tensor core path: the batch size limit caps the tile width.
    return MMQ_MAX_BATCH_SIZE;
#else
    return 128;
#endif // CUDA_USE_TENSOR_CORES
#else
    // Pre-Volta NVIDIA GPUs.
    return 64;
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row

// Tile height (mmq_y) for a given tile width; selected per target architecture.
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
static constexpr __device__ int get_mmq_y_device(int mmq_x) {
    return mmq_x >= 32 ? 128 : 64;
}
#else
#if __CUDA_ARCH__ >= CC_VOLTA
static constexpr __device__ int get_mmq_y_device(int mmq_x) {
    return mmq_x >= 32 ? 128 : 64;
}
#else
// Pre-Volta: fixed tile height, independent of mmq_x.
static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
    return 64;
}
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
// Tile buffer element counts per quantization type for a tile of height mmq_y.
// Each buffer is the raw payload (mmq_y*WARP_SIZE scaled by the per-type block
// geometry) plus a small per-row padding term (the trailing "+ mmq_y/..." parts).
// Types whose sc entry is 0 have no sub-block scales.
#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}

// Shared lookup body used by both the host and device getters below.
// Unsupported types fall through to the all-zero sizes.
#define GET_TILE_X_SIZES_BODY \
    return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
        type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 : \
        type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 : \
        type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 : \
        type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 : \
        type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K : \
        type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K : \
        type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \
        type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \
        type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \
        tile_x_sizes{0, 0, 0}
// Host-side lookup of the x tile buffer sizes for a given type and tile height.
static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
    GET_TILE_X_SIZES_BODY;
}
// Device-side (constexpr) lookup of the x tile buffer sizes; mmq_y is a
// template parameter so the sizes are compile-time constants in kernels.
template <int mmq_y>
static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) {
    GET_TILE_X_SIZES_BODY;
}
  87. // ------------------------------------------------------------
// Stages one mmq_y x WARP_SIZE tile of q4_0 data into the tile buffers:
// packed 4-bit quants go into x_qs (unpacked later by the vec_dot functions),
// the per-block float scale d goes into x_dm (reinterpreted as float storage).
// kbx0: first block index along k; i_max: last valid row, rows are clamped to
// it when need_check is set; stride: row stride of x in blocks.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_sc); // q4_0 has no sub-block scales

    const int kbx  = threadIdx.x / QI4_0; // block handled by this lane within the tile row
    const int kqsx = threadIdx.x % QI4_0; // int index within that block

    float * x_dmf = (float *) x_dm; // q4_0 stores a single float scale per block

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // clamp rows past the end of the matrix
        }

        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;

        // Row stride WARP_SIZE + 1: padded by one int, presumably to avoid
        // shared-memory bank conflicts — allocation is outside this view.
        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

    // Scales: the warp covers QI4_0 tile rows per iteration.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
        int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;

        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
    }
}
// dp4a path: accumulates the dot product of the staged q4_0 x tile with the
// q8_1 y tile into sum, for the k slice starting at k0.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_sc);

    const float * x_df = (const float *) x_dm;  // q4_0 scales are stored as float
    const int   * y_qs = (const int   *) y + 4; // y quants start after the 4 half2 ds values
    const half2 * y_ds = (const half2 *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            // y index of the first q8_1 int pairing with the low nibbles at k0.
            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));

            int u[2*VDR_Q4_0_Q8_1_MMQ];

#pragma unroll
            for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];          // pairs with low nibbles
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];  // pairs with high nibbles
            }

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
                (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}
// mma (int8 tensor core) path for q4_0 x q8_1; fragment shapes are
// I16K8 (A), J8K8 (B), I16J8 (C). Requires INT8_MMA_AVAILABLE, otherwise
// compiles to NO_DEVICE_CODE.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    GGML_UNUSED(x_sc);

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8  mma_B;
    typedef mma_int_C_I16J8 mma_C;

    const float * x_df = (const float *) x_dm;  // q4_0 scales are stored as float
    const int   * y_qs = (const int   *) y + 4; // y quants start after the 4 half2 ds values
    const half2 * y_ds = (const half2 *) y;

    mma_A A;
    float dA[mma_C::ne/2];

    const int i0 = threadIdx.y*mma_A::I; // each warp owns mma_A::I tile rows
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");

#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i     = i0 + mma_A::get_i(l);
        const int k     = k0 + mma_A::get_k(l) % QI4_0;
        const int shift = 4*(mma_A::get_k(l) / QI4_0); // low vs. high nibble of the packed int

        // Unpack 4-bit quants and subtract the q4_0 zero point of 8 per byte.
        A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
    }

#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);

        dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
    }

    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C C;
        mma_B B;
        half2 dsB[mma_C::ne/2];

#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;

            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
        }

        C.mma_K8(A, B);

        // Scale the integer accumulators: x scale times the low half (d) of y's ds.
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
  197. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
  198. const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
  199. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  200. GGML_UNUSED(x_sc);
  201. const int kbx = threadIdx.x / QI4_1;
  202. const int kqsx = threadIdx.x % QI4_1;
  203. #pragma unroll
  204. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  205. int i = i0 + threadIdx.y;
  206. if (need_check) {
  207. i = min(i, i_max);
  208. }
  209. const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
  210. x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
  211. }
  212. const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
  213. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  214. #pragma unroll
  215. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
  216. int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
  217. if (need_check) {
  218. i = min(i, i_max);
  219. }
  220. const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
  221. x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
  222. }
  223. }
// dp4a path: accumulates the dot product of the staged q4_1 x tile with the
// q8_1 y tile into sum, for the k slice starting at k0.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_sc);

    const int   * y_qs = (const int   *) y + 4; // y quants start after the 4 half2 ds values
    const half2 * y_ds = (const half2 *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            // y index of the first q8_1 int pairing with the low nibbles at k0.
            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));

            int u[2*VDR_Q4_1_Q8_1_MMQ];

#pragma unroll
            for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];          // pairs with low nibbles
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];  // pairs with high nibbles
            }

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
                (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}
  250. template <int mmq_x, int mmq_y, int nwarps>
  251. static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
  252. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  253. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  254. #ifdef INT8_MMA_AVAILABLE
  255. GGML_UNUSED(x_sc);
  256. typedef mma_int_A_I16K8 mma_A;
  257. typedef mma_int_B_J8K8 mma_B;
  258. typedef mma_int_C_I16J8 mma_C;
  259. const int * y_qs = (const int *) y + 4;
  260. const half2 * y_ds = (const half2 *) y;
  261. mma_A A;
  262. half2 dmA[mma_C::ne/2];
  263. const int i0 = threadIdx.y*mma_A::I;
  264. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  265. #pragma unroll
  266. for (int l = 0; l < mma_A::ne; ++l) {
  267. const int i = i0 + mma_A::get_i(l);
  268. const int k = k0 + mma_A::get_k(l) % QI4_0;
  269. const int shift = 4*(mma_A::get_k(l) / QI4_0);
  270. A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
  271. }
  272. #pragma unroll
  273. for (int l = 0; l < mma_C::ne/2; ++l) {
  274. const int i = i0 + mma_C::get_i(2*l);
  275. dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
  276. }
  277. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  278. mma_C C;
  279. mma_B B;
  280. half2 dsB[mma_C::ne/2];
  281. #pragma unroll
  282. for (int l = 0; l < mma_B::ne; ++l) {
  283. const int j = j0 + mma_B::get_j(l);
  284. const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
  285. B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
  286. }
  287. #pragma unroll
  288. for (int l = 0; l < mma_C::ne/2; ++l) {
  289. const int j = j0 + mma_C::get_j(l);
  290. dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
  291. }
  292. C.mma_K8(A, B);
  293. #pragma unroll
  294. for (int l = 0; l < mma_C::ne; ++l) {
  295. const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
  296. sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
  297. }
  298. }
  299. #else
  300. GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
  301. NO_DEVICE_CODE;
  302. #endif // INT8_MMA_AVAILABLE
  303. }
// Stages one tile of q5_0 data. The 5-bit quants are fully unpacked at load
// time: the high bits from qh are merged into the low/high nibble groups and
// 16 is subtracted, so x_qs holds two ready-to-use ints per source position
// (row stride 2*WARP_SIZE + 1). Per-block float scales go into x_dm.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_sc); // q5_0 has no sub-block scales

    const int kbx  = threadIdx.x / QI5_0; // block handled by this lane within the tile row
    const int kqsx = threadIdx.x % QI5_0; // int index within that block

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // clamp rows past the end of the matrix
        }

        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;

        const int ql = get_int_from_uint8(bxi->qs, kqsx);
        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));

        // Merge the 4 relevant high bits into bit 4 of each unpacked byte.
        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16

        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;

        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16

        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
    float * x_dmf = (float *) x_dm; // q5_0 stores a single float scale per block

    // Scales: the warp covers QI5_0 tile rows per iteration.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;

        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
    }
}
// dp4a path for q5_0 x q8_1. Since load_tiles_q5_0 already unpacked the
// quants, this reuses the q8_0 impl with doubled width (QR5_0*VDR).
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_sc);

    const float * x_dmf = (const float *) x_dm;  // q5_0 scales are stored as float
    const int   * y_qs  = (const int   *) y + 4; // y quants start after the 4 half2 ds values
    const float * y_df  = (const float *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            // y index of the first q8_1 int pairing with the low half at k0.
            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
            const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;

            int u[2*VDR_Q5_0_Q8_1_MMQ];

#pragma unroll
            for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE];
            }

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}
// mma (int8 tensor core) path for q5_0 x q8_1; fragment shapes are
// I16K8 (A), J8K8 (B), I16J8 (C). Requires INT8_MMA_AVAILABLE, otherwise
// compiles to NO_DEVICE_CODE.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    GGML_UNUSED(x_sc);

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8  mma_B;
    typedef mma_int_C_I16J8 mma_C;

    const float * x_df = (const float *) x_dm;  // q5_0 scales are stored as float
    const int   * y_qs = (const int   *) y + 4; // y quants start after the 4 half2 ds values
    const float * y_df = (const float *) y;

    mma_A A;
    float dA[mma_C::ne/2];

    const int i0 = threadIdx.y*mma_A::I; // each warp owns mma_A::I tile rows
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");

#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        // x_qs holds two unpacked ints per source position (see load_tiles_q5_0);
        // map the fragment's k to that interleaved layout.
        const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;

        A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
    }

#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);

        dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
    }

    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C C;
        mma_B B;
        float dB[mma_C::ne/2];

#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;

            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
        }

        C.mma_K8(A, B);

        // Scale the integer accumulators by the x and y block scales.
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
// Stages one tile of q5_1 data. Like q5_0 the 5-bit quants are fully unpacked
// at load time (high bits merged in), but no offset is subtracted — the
// additive term is carried by the per-block half2 dm values written to x_dm.
// x_qs row stride is 2*WARP_SIZE + 1.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_sc); // q5_1 has no sub-block scales

    const int kbx  = threadIdx.x / QI5_1; // block handled by this lane within the tile row
    const int kqsx = threadIdx.x % QI5_1; // int index within that block

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // clamp rows past the end of the matrix
        }

        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;

        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));

        // Merge the 4 relevant high bits into bit 4 of each unpacked byte.
        int qs0 = (ql >>  0) & 0x0F0F0F0F;
        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28

        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;

        int qs1 = (ql >>  4) & 0x0F0F0F0F;
        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28

        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

    // Block dm values: the warp covers QI5_1 tile rows per iteration.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;

        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
    }
}
// dp4a path for q5_1 x q8_1. Quants were pre-unpacked at load, so this
// reuses the q8_1 impl with doubled width (QR5_1*VDR).
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_sc);

    const int   * y_qs = (const int   *) y + 4; // y quants start after the 4 half2 ds values
    const half2 * y_ds = (const half2 *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            // y index of the first q8_1 int pairing with the low half at k0.
            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
            const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;

            int u[2*VDR_Q5_1_Q8_1_MMQ];

#pragma unroll
            for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE];
            }

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}
// mma (int8 tensor core) path for q5_1 x q8_1; fragment shapes are
// I16K8 (A), J8K8 (B), I16J8 (C). Requires INT8_MMA_AVAILABLE, otherwise
// compiles to NO_DEVICE_CODE.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    GGML_UNUSED(x_sc);

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8  mma_B;
    typedef mma_int_C_I16J8 mma_C;

    const int   * y_qs = (const int   *) y + 4; // y quants start after the 4 half2 ds values
    const half2 * y_ds = (const half2 *) y;

    mma_A A;
    half2 dmA[mma_C::ne/2];

    const int i0 = threadIdx.y*mma_A::I; // each warp owns mma_A::I tile rows
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");

#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        // x_qs holds two unpacked ints per source position (see load_tiles_q5_1);
        // map the fragment's k to that interleaved layout.
        const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;

        A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
    }

#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);

        dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
    }

    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C C;
        mma_B B;
        half2 dsB[mma_C::ne/2];

#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;

            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
        }

        C.mma_K8(A, B);

#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            // Low half of dmA*dsB scales the integer dot product; the high half
            // is the constant additive term per block pair.
            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
  546. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
  547. const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
  548. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  549. GGML_UNUSED(x_sc);
  550. const int kbx = threadIdx.x / QI8_0;
  551. const int kqsx = threadIdx.x % QI8_0;
  552. float * x_dmf = (float *) x_dm;
  553. #pragma unroll
  554. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  555. int i = i0 + threadIdx.y;
  556. if (need_check) {
  557. i = min(i, i_max);
  558. }
  559. const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
  560. x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
  561. }
  562. const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
  563. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  564. #pragma unroll
  565. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
  566. int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row;
  567. if (need_check) {
  568. i = min(i, i_max);
  569. }
  570. const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
  571. x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
  572. }
  573. }
  574. template <int mmq_x, int mmq_y, int nwarps>
  575. static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
  576. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  577. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  578. GGML_UNUSED(x_sc);
  579. const float * x_dmf = (const float *) x_dm;
  580. const int * y_qs = (const int *) y + 4;
  581. const float * y_df = (const float *) y;
  582. #pragma unroll
  583. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  584. const int j = j0 + threadIdx.y;
  585. #pragma unroll
  586. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  587. const int i = i0 + threadIdx.x;
  588. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
  589. (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
  590. y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
  591. }
  592. }
  593. }
// q8_0 × q8_1 dot product for the current k-slice using int8 tensor core MMA.
// x_qs: shared-memory tile of q8_0 quants (rows padded to WARP_SIZE + 1 ints);
// x_dm is reinterpreted as float per-block scales; x_sc is unused for q8_0.
// y: q8_1 tile (4 scale ints, then quants); sum: per-thread accumulators.
// Only compiled when INT8_MMA_AVAILABLE; otherwise NO_DEVICE_CODE.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    GGML_UNUSED(x_sc);

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    const float * x_df = (const float *) x_dm; // q8_0 scales are plain floats
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    mma_A A;
    float dA[mma_C::ne/2];

    const int i0 = threadIdx.y*mma_A::I; // this warp's first row in the x tile

    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");

#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        const int k = k0 + mma_A::get_k(l);

        A.x[l] = x_qs[i*(WARP_SIZE + 1) + k];
    }

#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);

        dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
    }

    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C C;
        mma_B B;
        float dB[mma_C::ne/2];

#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = k0 + mma_B::get_k(l);

            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
        }

        C.mma_K8(A, B);

#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            // scale the int32 dot product by both sides' block scales
            sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
// Loads a q2_K tile into shared memory. Each thread extracts its 2-bit values,
// shifts them into their target bit positions, and merges the partial results of
// 4 neighboring lanes via XOR shuffles; only one lane per group of QR2_K stores
// the assembled int. The per-position 4-bit scale/min pair is premultiplied with
// the super-block dm and stored as a half2 in x_dm.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    const int kbx = threadIdx.x / QI2_K;  // super-block index within the tile row
    const int kqsx = threadIdx.x % QI2_K; // int index inside the super-block

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // clamp rows past the end of the matrix
        }

        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;

        const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);

#pragma unroll
        for (int l = 0; l < QR2_K; ++l) {
            const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;

            // this lane's 2-bit values, shifted to their position in the merged int ...
            int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
            // ... combined with the partial ints of the lanes 1 and 2 apart
            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);

            if (kqsx % QR2_K != 0) {
                continue; // only one lane per group of QR2_K writes the merged result
            }

            x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
        }

        const int sc_m = bxi->scales[kqsx]; // packed 4-bit scale (low) / min (high)
#ifdef FAST_FP16_AVAILABLE
        const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
#else
        // without fast fp16 arithmetic, multiply in fp32 and convert back
        const float2 bxi_dmf = __half22float2(bxi->dm);
        const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
#endif // FAST_FP16_AVAILABLE

        x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
    }
}
  681. template <int mmq_x, int mmq_y, int nwarps>
  682. static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
  683. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  684. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  685. const int * y_qs = (const int *) y + 4;
  686. const float * y_df = (const float *) y;
  687. #pragma unroll
  688. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  689. const int j = j0 + threadIdx.y;
  690. #pragma unroll
  691. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  692. const int i = i0 + threadIdx.x;
  693. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
  694. &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
  695. &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
  696. }
  697. }
  698. }
// q2_K × q8_1 dot product for the current k-slice using int8 tensor core MMA.
// The 2-bit values are unpacked from x_qs on the fly; x_dm holds premultiplied
// per-position {d*scale, d*min} half2 pairs (see load_tiles_q2_K).
// The result is d·(q·y) − m·Σy per block; Σy is obtained with an extra MMA
// against an all-ones A fragment. Only compiled when INT8_MMA_AVAILABLE.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K4 mma_A;
    typedef mma_int_B_J8K4 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = threadIdx.y*mma_A::I; // this warp's first row in the x tile

    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");

    mma_A A[2];
    float dA[mma_C::ne/2][2]; // d*scale factors
    float mA[mma_C::ne/2][2]; // d*min factors

#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        const int shift = 2*mma_A::get_k(l); // select the 2-bit value inside the packed int

        A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303;
        A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303;
    }

#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);

#pragma unroll
        for (int kk = 0; kk < 2; ++kk) {
            const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]);

            dA[l][kk] = dm.x;
            mA[l][kk] = dm.y;
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { // 8 columns per step: two K4 B fragments
        mma_C Cd[2]; // dot-product terms
        mma_C Cm[2]; // Σy terms for the min correction
        mma_B B[2];
        float dB[mma_C::ne/2];

#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;

            B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
            B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
        }

        Cd[0].mma_K4(A[0], B[0]);
        Cd[1].mma_K4(A[1], B[1]);

        // all-ones fragment: the MMA with B then yields per-tile sums of y
        mma_A A1;
        A1.x[0] = 0x01010101;
        A1.x[1] = 0x01010101;
        Cm[0].mma_K4(A1, B[0]);
        Cm[1].mma_K4(A1, B[1]);

#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            // (d*scale)·(q·y) − (d*min)·Σy, scaled by the q8_1 block scale
            sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
// Loads a q3_K tile into shared memory: combines the 2-bit low quants (qs) with
// the high-bit mask (hmask) into packed values (two positions per nibble, merged
// across lane pairs via XOR shuffle; the constant 4 bias is removed later, e.g.
// via __vsubss4 in the MMA dot product), stores the float super-block scale d,
// and unpacks the 6-bit scales into x_sc with a -32 offset.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    const int kbx = threadIdx.x / QI3_K;  // super-block index within the tile row
    const int kqsx = threadIdx.x % QI3_K; // int index inside the super-block

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // clamp rows past the end of the matrix
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;

        const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
        const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));

#pragma unroll
        for (int l = 0; l < QR3_K; ++l) {
            const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8;

            const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303; // 2 low bits
            const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404; // high bit into position 2

            // pack two positions per int half and merge with the neighboring lane
            int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);

            if (kqsx % 2 != 0) {
                continue; // only even lanes store the merged int
            }

            x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
        }
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

    float * x_dmf = (float *) x_dm; // q3_K stores plain float scales

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
        int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;

        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);

        const int ksc = threadIdx.x % (QI3_K/4);

        // low 4 bits of each 6-bit scale
        const int ksc_low = ksc % (QI3_K/8);
        const int shift_low = 4 * (ksc / (QI3_K/8));
        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;

        // high 2 bits of each 6-bit scale
        const int ksc_high = QI3_K/8;
        const int shift_high = 2 * ksc;
        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;

        // combine and remove the +32 bias of the packed encoding
        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);

        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc;
    }
}
  823. template <int mmq_x, int mmq_y, int nwarps>
  824. static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
  825. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  826. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  827. const float * x_df = (const float *) x_dm;
  828. const int * y_qs = (const int *) y + 4;
  829. const float * y_df = (const float *) y;
  830. #pragma unroll
  831. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  832. const int j = j0 + threadIdx.y;
  833. #pragma unroll
  834. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  835. const int i = i0 + threadIdx.x;
  836. const int kbx = k0 / QI3_K;
  837. const int ky = (k0 % QI3_K) * QR3_K;
  838. const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
  839. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
  840. &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
  841. x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
  842. }
  843. }
  844. }
// q3_K × q8_1 dot product for the current k-slice using int8 tensor core MMA.
// Unpacks the 4-bit-packed q3_K values from x_qs, removes their +4 bias with
// __vsubss4, applies the per-group int8 scales from x_sc and the float
// super-block scale from x_dm. Only compiled when INT8_MMA_AVAILABLE.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K4 mma_A;
    typedef mma_int_B_J8K4 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    const float * x_df = (const float *) x_dm;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = threadIdx.y*mma_A::I; // this warp's first row in the x tile

    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");

    mma_A A[2];
    int scA[mma_C::ne/2][2]; // per-group int8 scales
    float dA[mma_C::ne/2];   // super-block scale d

#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        const int k = QR3_K*k0 + mma_A::get_k(l);

        // two values are packed per byte-nibble; pick the right half, then unbias
        A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0]          >> (4*(k%2))) & 0x0F0F0F0F;
        A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
        A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404);
        A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404);
    }

#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);

        const int kbx = k0 / QI3_K;
        const int ky  = (k0 % QI3_K) * QR3_K;
        const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;

        scA[l][0] = sc[0];
        scA[l][1] = sc[1];
    }

#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);

        dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K];
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { // 8 columns per step: two K4 B fragments
        mma_C C[2];
        mma_B B[2];
        float dB[mma_C::ne/2];

#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;

            B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
            B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);

            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
        }

        C[0].mma_K4(A[0], B[0]);
        C[1].mma_K4(A[1], B[1]);

#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            // (per-group scale)·(q·y), then both float scales
            sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
// Loads a q4_K tile into shared memory: the packed 4-bit quants go into x_qs
// unchanged (nibble extraction happens in the dot product), the per-super-block
// {d, dmin} half2 pairs into x_dm, and the 6-bit scales/mins are repacked into
// x_sc as 8 scale bytes followed by 8 min bytes.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    const int kbx = 0;           // threadIdx.x / QI4_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI4_K

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // clamp rows past the end of the matrix
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;

        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;

        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = threadIdx.x % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
  953. template <int mmq_x, int mmq_y, int nwarps>
  954. static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
  955. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  956. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  957. const int * y_qs = (const int *) y + 4;
  958. const half2 * y_ds = (const half2 *) y;
  959. #pragma unroll
  960. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  961. const int j = j0 + threadIdx.y;
  962. #pragma unroll
  963. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  964. const int i = i0 + threadIdx.x;
  965. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
  966. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
  967. &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
  968. x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
  969. }
  970. }
  971. }
  972. template <int mmq_x, int mmq_y, int nwarps>
  973. static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
  974. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  975. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  976. #ifdef INT8_MMA_AVAILABLE
  977. typedef mma_int_A_I16K8 mma_A;
  978. typedef mma_int_B_J8K8 mma_B;
  979. typedef mma_int_C_I16J8 mma_C;
  980. const int * y_qs = (const int *) y + 4;
  981. const half2 * y_ds = (const half2 *) y;
  982. const int i0 = threadIdx.y*mma_A::I;
  983. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  984. mma_A A[2];
  985. int scA[mma_C::ne/2][2];
  986. int mA[mma_C::ne/2][2];
  987. half2 dmA[mma_C::ne/2];
  988. #pragma unroll
  989. for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
  990. #pragma unroll
  991. for (int l = 0; l < mma_A::ne; ++l) {
  992. const int i = i0 + mma_A::get_i(l);
  993. const int k = k0 + mma_A::get_k(l);
  994. A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
  995. }
  996. #pragma unroll
  997. for (int l = 0; l < mma_C::ne/2; ++l) {
  998. const int i = i0 + mma_C::get_i(2*l);
  999. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
  1000. const uint8_t * m = sc + 8;
  1001. scA[l][kvdr/4] = sc[kvdr/4];
  1002. mA[l][kvdr/4] = m[kvdr/4];
  1003. }
  1004. }
  1005. #pragma unroll
  1006. for (int l = 0; l < mma_C::ne/2; ++l) {
  1007. const int i = i0 + mma_C::get_i(2*l);
  1008. dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
  1009. }
  1010. #pragma unroll
  1011. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  1012. float tmpd[mma_C::ne] = {0.0f};
  1013. float tmpm[mma_C::ne] = {0.0f};
  1014. #pragma unroll
  1015. for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
  1016. mma_C C;
  1017. mma_B B;
  1018. half2 dsB[mma_C::ne/2];
  1019. #pragma unroll
  1020. for (int l = 0; l < mma_B::ne; ++l) {
  1021. const int j = j0 + mma_B::get_j(l);
  1022. const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
  1023. B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
  1024. }
  1025. #pragma unroll
  1026. for (int l = 0; l < mma_C::ne/2; ++l) {
  1027. const int j = j0 + mma_C::get_j(l);
  1028. dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
  1029. }
  1030. C.mma_K8(A[kvdr/4], B);
  1031. #pragma unroll
  1032. for (int l = 0; l < mma_C::ne; ++l) {
  1033. tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
  1034. tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
  1035. }
  1036. }
  1037. #pragma unroll
  1038. for (int l = 0; l < mma_C::ne; ++l) {
  1039. sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
  1040. }
  1041. }
  1042. #else
  1043. GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
  1044. NO_DEVICE_CODE;
  1045. #endif // INT8_MMA_AVAILABLE
  1046. }
// Loads a q5_K tile into shared memory: combines the 4-bit qs nibbles with the
// 5th bit from qh into expanded values stored across double-width rows
// (2*WARP_SIZE + 1 ints), loads the per-super-block {d, dmin} half2 pair, and
// repacks the 6-bit scales/mins into x_sc as 8 scale bytes followed by 8 min bytes.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    const int kbx = 0;            // threadIdx.x / QI5_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI5_K

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // clamp rows past the end of the matrix
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
        const int ky = QR5_K*kqsx;

        // low and high nibbles of the packed 4-bit values
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // 5th bits, shifted into bit position 4 of each byte
        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;

        // destination indices in the double-width row
        const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);

        x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;

        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = threadIdx.x % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
  1097. template <int mmq_x, int mmq_y, int nwarps>
  1098. static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
  1099. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  1100. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  1101. const int * y_qs = (const int *) y + 4;
  1102. const half2 * y_ds = (const half2 *) y;
  1103. #pragma unroll
  1104. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1105. const int j = j0 + threadIdx.y;
  1106. #pragma unroll
  1107. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1108. const int i = i0 + threadIdx.x;
  1109. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
  1110. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
  1111. &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
  1112. x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
  1113. }
  1114. }
  1115. }
// q5_K × q8_1 dot product for the current k-slice using int8 tensor core MMA.
// x_qs holds the expanded q5_K values (see load_tiles_q5_K, rows of
// QR5_K*WARP_SIZE + 1 ints); x_dm the per-super-block {d, dmin}; x_sc the
// repacked 8 scales + 8 mins. Result per element:
// d·Σ(scale·(q·y)·d_y) − dmin·Σ(min·s_y). Only compiled when INT8_MMA_AVAILABLE.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE

    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;

    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    const int i0 = threadIdx.y*mma_A::I; // this warp's first row in the x tile

    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");

    mma_A A[2];
    int scA[mma_C::ne/2][2]; // per-group scales
    int mA[mma_C::ne/2][2];  // per-group mins
    half2 dmA[mma_C::ne/2];  // super-block {d, dmin}

#pragma unroll
    for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
#pragma unroll
        for (int l = 0; l < mma_A::ne; ++l) {
            const int i = i0 + mma_A::get_i(l);
            const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);

            A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k];
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + mma_C::get_i(2*l);

            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
            const uint8_t * m  = sc + 8; // mins are stored 8 bytes after the scales

            scA[l][kvdr/4] = sc[kvdr/4];
            mA[l][kvdr/4] = m[kvdr/4];
        }
    }

#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);

        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        float tmpd[mma_C::ne] = {0.0f}; // accumulated dot-product terms
        float tmpm[mma_C::ne] = {0.0f}; // accumulated min-correction terms

#pragma unroll
        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
            mma_C C;
            mma_B B;
            half2 dsB[mma_C::ne/2];

#pragma unroll
            for (int l = 0; l < mma_B::ne; ++l) {
                const int j = j0 + mma_B::get_j(l);
                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;

                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
            }

#pragma unroll
            for (int l = 0; l < mma_C::ne/2; ++l) {
                const int j = j0 + mma_C::get_j(l);

                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
            }

            C.mma_K8(A[kvdr/4], B);

#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
                tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
            }
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            // combine: d*(scaled dot products) - dmin*(min corrections)
            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
// Loads a q6_K tile into shared memory: combines the 4-bit ql and 2-bit qh parts
// into 6-bit values and subtracts the constant 32 offset up front (__vsubss4),
// stores the float super-block scale d, and copies the int8 group scales to x_sc.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    const int kbx = 0;            // threadIdx.x / QI6_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI6_K

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // clamp rows past the end of the matrix
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
        const int ky = QR6_K*kqsx;

        // low and high nibbles of the packed 4-bit part
        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // upper 2 bits, shifted into bit positions 4-5 of each byte
        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;

        // destination indices in the double-width row
        const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);

        // remove the +32 bias of the q6_K encoding while storing
        x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
        x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256

    float * x_dmf = (float *) x_dm; // q6_K stores plain float scales

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;

        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;

        x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
    }
}
  1237. template <int mmq_x, int mmq_y, int nwarps>
  1238. static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
  1239. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  1240. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  1241. const float * x_dmf = (const float *) x_dm;
  1242. const int * y_qs = (const int *) y + 4;
  1243. const float * y_df = (const float *) y;
  1244. #pragma unroll
  1245. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1246. const int j = j0 + threadIdx.y;
  1247. #pragma unroll
  1248. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1249. const int i = i0 + threadIdx.x;
  1250. const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
  1251. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
  1252. &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
  1253. x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
  1254. }
  1255. }
  1256. }
  1257. template <int mmq_x, int mmq_y, int nwarps>
  1258. static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
  1259. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  1260. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  1261. #ifdef INT8_MMA_AVAILABLE
  1262. typedef mma_int_A_I16K4 mma_A;
  1263. typedef mma_int_B_J8K4 mma_B;
  1264. typedef mma_int_C_I16J8 mma_C;
  1265. const float * x_df = (const float *) x_dm;
  1266. const int * y_qs = (const int *) y + 4;
  1267. const float * y_df = (const float *) y;
  1268. const int i0 = threadIdx.y*mma_A::I;
  1269. #ifdef INT8_MMA_AVAILABLE
  1270. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  1271. #endif // INT8_MMA_AVAILABLE
  1272. mma_A A[4];
  1273. int scA[mma_C::ne/2][4];
  1274. float dA[mma_C::ne/2];
  1275. #pragma unroll
  1276. for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
  1277. #pragma unroll
  1278. for (int l = 0; l < mma_A::ne; ++l) {
  1279. const int i = i0 + mma_A::get_i(l);
  1280. const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
  1281. A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0];
  1282. A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
  1283. }
  1284. #pragma unroll
  1285. for (int l = 0; l < mma_C::ne/2; ++l) {
  1286. const int i = i0 + mma_C::get_i(2*l);
  1287. const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
  1288. scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
  1289. scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
  1290. }
  1291. }
  1292. #pragma unroll
  1293. for (int l = 0; l < mma_C::ne/2; ++l) {
  1294. const int i = i0 + mma_C::get_i(2*l);
  1295. dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
  1296. }
  1297. #pragma unroll
  1298. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  1299. float tmp[mma_C::ne] = {0.0f};
  1300. #pragma unroll
  1301. for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
  1302. mma_C C[2];
  1303. mma_B B[2];
  1304. float dB[mma_C::ne/2];
  1305. #pragma unroll
  1306. for (int l = 0; l < mma_B::ne; ++l) {
  1307. const int j = j0 + mma_B::get_j(l);
  1308. const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
  1309. B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
  1310. B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
  1311. }
  1312. #pragma unroll
  1313. for (int l = 0; l < mma_C::ne/2; ++l) {
  1314. const int j = j0 + mma_C::get_j(l);
  1315. dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
  1316. }
  1317. C[0].mma_K4(A[kvdr/2 + 0], B[0]);
  1318. C[1].mma_K4(A[kvdr/2 + 1], B[1]);
  1319. #pragma unroll
  1320. for (int l = 0; l < mma_C::ne; ++l) {
  1321. tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
  1322. }
  1323. }
  1324. #pragma unroll
  1325. for (int l = 0; l < mma_C::ne; ++l) {
  1326. sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
  1327. }
  1328. }
  1329. #else
  1330. GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
  1331. NO_DEVICE_CODE;
  1332. #endif // INT8_MMA_AVAILABLE
  1333. }
  1334. template<int mmq_x, int mmq_y, int nwarps, bool need_check>
  1335. static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
  1336. #pragma unroll
  1337. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1338. const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
  1339. if (j >= ne1) {
  1340. return;
  1341. }
  1342. #pragma unroll
  1343. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1344. const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
  1345. if (need_check && i >= ne0) {
  1346. continue;
  1347. }
  1348. dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
  1349. }
  1350. }
  1351. }
  1352. template<int mmq_x, int mmq_y, int nwarps, bool need_check>
  1353. static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
  1354. typedef mma_int_C_I16J8 mma_C;
  1355. const int i0 = threadIdx.y*mma_C::I;
  1356. #ifdef INT8_MMA_AVAILABLE
  1357. static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
  1358. #endif // INT8_MMA_AVAILABLE
  1359. #pragma unroll
  1360. for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
  1361. #pragma unroll
  1362. for (int l = 0; l < mma_C::ne; ++l) {
  1363. const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
  1364. if (j >= ne1) {
  1365. continue;
  1366. }
  1367. const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
  1368. if (need_check && i >= ne0) {
  1369. continue;
  1370. }
  1371. dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
  1372. }
  1373. }
  1374. }
  1375. // -------------------------------------------------------------------------------------------------------------------------------------
// Compile-time dispatch table: maps a ggml quantization type to its MMQ kernel
// components: vdr (q columns consumed per vec_dot call), the shared-memory tile
// loader, and the mma / dp4a dot-product implementations.
template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
struct mmq_type_traits;

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
    static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
    static constexpr int              vdr          = VDR_Q4_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
    static constexpr int              vdr          = VDR_Q5_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
    static constexpr int              vdr          = VDR_Q5_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
    static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
    static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
    static constexpr int              vdr          = VDR_Q3_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q3_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
    static constexpr int              vdr          = VDR_Q4_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
    static constexpr int              vdr          = VDR_Q5_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
    static constexpr int              vdr          = VDR_Q6_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
  1448. static bool mmq_need_sum(const ggml_type type_x) {
  1449. switch (type_x) {
  1450. case GGML_TYPE_Q4_0:
  1451. case GGML_TYPE_Q4_1:
  1452. return true;
  1453. case GGML_TYPE_Q5_0:
  1454. return false;
  1455. case GGML_TYPE_Q5_1:
  1456. return true;
  1457. case GGML_TYPE_Q8_0:
  1458. case GGML_TYPE_Q2_K:
  1459. case GGML_TYPE_Q3_K:
  1460. return false;
  1461. case GGML_TYPE_Q4_K:
  1462. case GGML_TYPE_Q5_K:
  1463. return true;
  1464. case GGML_TYPE_Q6_K:
  1465. return false;
  1466. default:
  1467. GGML_ASSERT(false);
  1468. break;
  1469. }
  1470. return false;
  1471. }
// Quantized matrix multiplication kernel: x (quantization `type`) times y (q8_1).
// Expected launch layout: grid.x tiles the x rows in steps of mmq_y, grid.y tiles
// the y columns in steps of mmq_x; block = (WARP_SIZE, nwarps). Dynamic shared
// memory of mmq_get_shmem(type, mmq_x, mmq_y) bytes must be provided at launch.
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#else
#if __CUDA_ARCH__ >= CC_VOLTA
    __launch_bounds__(WARP_SIZE*nwarps, 1)
#else
    __launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q(
    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {

    // Skip unused template specializations for faster compilation:
    if (mmq_x > get_mmq_x_max_device()) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
    constexpr int              qr         = ggml_cuda_type_traits<type>::qr;
    constexpr int              qi         = ggml_cuda_type_traits<type>::qi;
    constexpr int              mmq_y      = get_mmq_y_device(mmq_x);
    constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;

    // Pick tensor-core or dp4a dot products at compile time.
#ifdef INT8_MMA_AVAILABLE
    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
    constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE

    constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);

    // Partition the dynamic shared memory buffer: x quants, x scales, x sub-scales, then the y tile.
    extern __shared__ char data_mul_mat_q[];
    int   * tile_x_qs = (int   *)  data_mul_mat_q;
    half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs);
    int   * tile_x_sc = (int   *) (tile_x_dm + txs.dm);
    int   * tile_y    = (int   *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]

    const int blocks_per_row_x = ne00 / qk;
    const int blocks_per_warp  = WARP_SIZE / qi;

    const int & ne1 = ne11;

    // Highest valid x row index for this block, used to clamp loads when need_check.
    const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;

    const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));

    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};

    // March over the shared dimension one warp-width of quant blocks at a time.
    for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {

        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);

#pragma unroll
        for (int kr = 0; kr < qr; ++kr) {
            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
            // Cooperative copy of the y tile into shared memory (all threads participate).
#pragma unroll
            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;

                tile_y[l] = by0[l];
            }

            // Barrier between writing and reading the shared y tile.
            __syncthreads();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
                vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0);
            }

            // Barrier before the next iteration overwrites the shared tiles.
            __syncthreads();
        }
    }

    write_back(sum, dst, ne0, ne1);
}
// Plain aggregate carrying the arguments of a mul_mat_q launch.
struct mmq_args {
    const char * x; const char * y; float * dst; // quantized x data, q8_1-packed y data, output
    int64_t ne00; int64_t ne01; int64_t stride01; // x: row length, number of rows, row stride (in quant blocks)
    int64_t ne10; int64_t ne11; int64_t stride11; // y: row length, number of columns, stride
    int64_t ne0;                                  // output row length
};
  1543. constexpr int mmq_get_nwarps(int mmq_x) {
  1544. return mmq_x >= 32 ? 8 : 4;
  1545. }
  1546. static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
  1547. const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
  1548. const int nwarps = mmq_get_nwarps(mmq_x);
  1549. const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
  1550. const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
  1551. return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
  1552. }
// Host-side launcher for one mul_mat_q template instantiation: computes the
// grid/block shape, raises the dynamic shared memory limit once per device,
// and dispatches to the bounds-checked kernel variant only when needed.
template <ggml_type type, int mmq_x, int nwarps>
static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
    const int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;
    const int mmq_y = get_mmq_y_host(cc, mmq_x);

    const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y; // ceil-div over x rows
    const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x; // ceil-div over y columns
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
    // Opt in to large dynamic shared memory via cudaFuncSetAttribute; done once
    // per device. Note: this static array is per template instantiation, so each
    // <type, mmq_x, nwarps> combination raises its own limit independently.
    static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
    if (!shmem_limit_raised[id]) {
        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
        shmem_limit_raised[id] = true;
    }
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

    if (args.ne01 % mmq_y == 0) {
        // All x rows are tile-aligned: the cheaper variant without bounds checks is safe.
        const bool need_check = false;
        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
    } else {
        const bool need_check = true;
        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
    }
}
// Chooses the tile width mmq_x for this problem and device, then dispatches to
// the matching launch_mul_mat_q instantiation. The heuristic minimizes the
// number of "waves" (kernel launches' worth of blocks per SM) subject to the
// device's shared memory limit, stopping early once a single wave suffices.
template <ggml_type type>
void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
    const int id    = ggml_cuda_get_device();
    const int nsm   = ggml_cuda_info().devices[id].nsm;
    const int cc    = ggml_cuda_info().devices[id].cc;
    const int smpbo = ggml_cuda_info().devices[id].smpbo;

    const int mmq_x_max = get_mmq_x_max_host(cc);
    const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
    const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;

    int mmq_x_best  = 0;
    int nwaves_best = INT_MAX;

    // Candidate tile widths are multiples of 8 up to the device maximum.
    for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
        const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
        const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;

        if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
            mmq_x_best  = mmq_x;
            nwaves_best = nwaves;
        }
    }

    // mmq_x is a template parameter, so the runtime choice must be mapped to a
    // compile-time constant through this switch. Hitting default means no
    // candidate fit in shared memory (mmq_x_best stayed 0) — a hard error.
    switch (mmq_x_best) {
        case   8:
            launch_mul_mat_q<type,   8, mmq_get_nwarps(  8)>(args, stream);
            break;
        case  16:
            launch_mul_mat_q<type,  16, mmq_get_nwarps( 16)>(args, stream);
            break;
        case  24:
            launch_mul_mat_q<type,  24, mmq_get_nwarps( 24)>(args, stream);
            break;
        case  32:
            launch_mul_mat_q<type,  32, mmq_get_nwarps( 32)>(args, stream);
            break;
        case  40:
            launch_mul_mat_q<type,  40, mmq_get_nwarps( 40)>(args, stream);
            break;
        case  48:
            launch_mul_mat_q<type,  48, mmq_get_nwarps( 48)>(args, stream);
            break;
        case  56:
            launch_mul_mat_q<type,  56, mmq_get_nwarps( 56)>(args, stream);
            break;
        case  64:
            launch_mul_mat_q<type,  64, mmq_get_nwarps( 64)>(args, stream);
            break;
        case  72:
            launch_mul_mat_q<type,  72, mmq_get_nwarps( 72)>(args, stream);
            break;
        case  80:
            launch_mul_mat_q<type,  80, mmq_get_nwarps( 80)>(args, stream);
            break;
        case  88:
            launch_mul_mat_q<type,  88, mmq_get_nwarps( 88)>(args, stream);
            break;
        case  96:
            launch_mul_mat_q<type,  96, mmq_get_nwarps( 96)>(args, stream);
            break;
        case 104:
            launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
            break;
        case 112:
            launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
            break;
        case 120:
            launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
            break;
        case 128:
            launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
            break;
        default:
            fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
            GGML_ASSERT(false);
            break;
    }
}
  1655. #define DECL_MMQ_CASE(type) \
  1656. template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \
// The explicit instantiations are compiled elsewhere (one translation unit per
// type); declare them here so users of this header link against them instead of
// re-instantiating the templates.
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
  1667. // -------------------------------------------------------------------------------------------------------------------------
// Backend entry point for the MMQ matrix multiplication; implemented elsewhere.
void ggml_cuda_op_mul_mat_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream);

// Returns whether the MMQ kernels support quantization type `type`; implemented elsewhere.
bool ggml_cuda_supports_mmq(enum ggml_type type);